at-dispatch-v2 Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations.
bunx add-skill pytorch/pytorch -s at-dispatch-v2
AT_DISPATCH to AT_DISPATCH_V2 Converter
This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in aten/src/ATen/Dispatch_v2.h.
When to use this skill
Use this skill when:
Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
Porting ATen kernels to use the new dispatch API
Working with files in aten/src/ATen/native/ that use dispatch macros
User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion
Quick reference
Old format:
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
// lambda body
});
New format:
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
// lambda body
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
Key transformations
Reorder arguments : scalar_type and name come first, then the lambda, then the types
Wrap the lambda : Use AT_WRAP(lambda) to handle internal commas
Expand type groups : Use AT_EXPAND(AT_ALL_TYPES) instead of implicit expansion
List individual types : Add extra types (kHalf, kBFloat16, etc.) after expanded groups
Add include : #include <ATen/Dispatch_v2.h> near other Dispatch includes
Instructions
Step 1: Add the Dispatch_v2.h include Add the v2 header near the existing #include <ATen/Dispatch.h>:
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
Keep the old Dispatch.h include for now (other code may still need it).
Step 2: Identify the old dispatch pattern Common patterns to convert:
AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)
AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)
Step 3: Map the old macro to type groups Identify which type group macro corresponds to the base types:
Old macro base            AT_DISPATCH_V2 type group
ALL_TYPES                 AT_EXPAND(AT_ALL_TYPES)
FLOATING_TYPES            AT_EXPAND(AT_FLOATING_TYPES)
INTEGRAL_TYPES            AT_EXPAND(AT_INTEGRAL_TYPES)
COMPLEX_TYPES             AT_EXPAND(AT_COMPLEX_TYPES)
ALL_TYPES_AND_COMPLEX     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)
For combined patterns, use multiple AT_EXPAND() entries:
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
Step 4: Extract the individual types From AT_DISPATCH_*_AND2(type1, type2, ...) or AT_DISPATCH_*_AND3(type1, type2, type3, ...), extract the individual types (type1, type2, etc.).
These become the trailing arguments after the type group:
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
^^^^^^^^^^^^^^^^^^^^^^^^
Individual types from AND3
Step 5: Transform to AT_DISPATCH_V2 Apply the transformation:
AT_DISPATCH_V2(
scalar_type, // 1st: The dtype expression
"name", // 2nd: The debug string
AT_WRAP(lambda), // 3rd: The lambda wrapped in AT_WRAP
type_groups, // 4th+: Type groups with AT_EXPAND()
individual_types // Last: Individual types
)
// BEFORE
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool,
iter.dtype(),
"min_values_cuda",
[&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}
);
// AFTER
AT_DISPATCH_V2(
iter.dtype(),
"min_values_cuda",
AT_WRAP([&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
kBFloat16, kHalf, kBool
);
Step 6: Handle multi-line lambdas For lambdas with internal commas or complex expressions, AT_WRAP is essential:
AT_DISPATCH_V2(
dtype,
"complex_kernel",
AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MinOps<scalar_t>{},
thrust::pair<scalar_t, int64_t>(upper_bound(), 0) // Commas inside!
);
}),
AT_EXPAND(AT_ALL_TYPES)
);
Step 7: Verify the conversion Check that the scalar type expression and name string come first, the lambda is wrapped in AT_WRAP(), every type group uses AT_EXPAND(), and all individual types from the old macro appear as trailing arguments.
Type group reference Available type group macros (use with AT_EXPAND()):
AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort
AT_FLOATING_TYPES // kDouble, kFloat
AT_COMPLEX_TYPES // kComplexDouble, kComplexFloat
AT_QINT_TYPES // kQInt8, kQUInt8, kQInt32
AT_ALL_TYPES // INTEGRAL_TYPES + FLOATING_TYPES
AT_ALL_TYPES_AND_COMPLEX // ALL_TYPES + COMPLEX_TYPES
AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + unsigned types
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
AT_FLOAT8_TYPES // Float8 variants
Common patterns
Pattern: AT_DISPATCH_ALL_TYPES_AND2 // Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
kernel<scalar_t>(data);
});
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>(data);
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
Pattern: AT_DISPATCH_FLOATING_TYPES_AND3 // Before
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
tensor.scalar_type(), "float_op", [&] {
process<scalar_t>(tensor);
});
// After
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
process<scalar_t>(tensor);
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);
Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2 // Before
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kComplexHalf, kHalf,
self.scalar_type(),
"complex_op",
[&] {
result = compute<scalar_t>(self);
}
);
// After
AT_DISPATCH_V2(
self.scalar_type(),
"complex_op",
AT_WRAP([&] {
result = compute<scalar_t>(self);
}),
AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_COMPLEX_TYPES),
kComplexHalf,
kHalf
);
Edge cases
Case 1: No extra types (rare) // Before
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));
Case 2: Many individual types (AND4, AND5, etc.) // Before
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
dtype, "float8_op", [&]() { kernel<scalar_t>(); });
// After
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
Case 3: Lambda with no captures // Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
static_kernel<scalar_t>();
});
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
static_kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
Benefits of AT_DISPATCH_V2
No arity in macro name : Don't need different macros for AND2, AND3, AND4
Composable type sets : Mix and match type groups with AT_EXPAND()
Extensible : Easy to add more types without hitting macro limits
Clearer : Type groups are explicit, not implicit in macro name
Important notes
Keep #include <ATen/Dispatch.h> - other code may need it
The AT_WRAP() is mandatory - prevents comma parsing issues in the lambda
Type groups need AT_EXPAND(), individual types don't
The v2 API is in aten/src/ATen/Dispatch_v2.h - refer to it for full docs
See the header file for the Python script to regenerate the macro implementation
Workflow When asked to convert AT_DISPATCH macros:
Read the file to identify all AT_DISPATCH uses
Add #include <ATen/Dispatch_v2.h> if not present
For each dispatch macro:
Identify the pattern and extract components
Map the base type group
Extract individual types
Construct the AT_DISPATCH_V2 call
Apply with Edit tool
Show the user the complete converted file
Explain what was changed
Do NOT compile or test the code - focus on accurate conversion only.