Write Metal/MPS kernels for PyTorch operators. Use when adding MPS device support to operators, implementing Metal shaders, or porting CUDA kernels to Apple Silicon. Covers native_functions.yaml dispatch, host-side operators, and Metal kernel implementation.
This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.
Important: The goal of this skill is to use native Metal capabilities via the c10/metal/ infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.
Overview
This skill covers two workflows:
1. Adding new MPS support - Implementing a new operator from scratch
2. Migrating from MPSGraph - Converting existing MPSGraph-based operators to native Metal
Both workflows involve the same three steps:
1. Update dispatch in aten/src/ATen/native/native_functions.yaml
2. Write the Metal kernel in aten/src/ATen/native/mps/kernels/
3. Implement the host-side stub in aten/src/ATen/native/mps/operations/
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
// Define operation functor
struct my_op_functor {
template <typename T>
inline T operator()(const T x) {
return /* your operation */;
}
};
// Register for supported types
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);
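Before committing a functor to a .metal file, it can help to prototype its math on the CPU. The sketch below is plain C++17, not Metal. The names `square_plus_one_functor` and `apply_unary` are hypothetical, and `apply_unary` only imitates the element-wise traversal a unary kernel performs; the real expansion of REGISTER_UNARY_OP is not reproduced here.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical functor with the same shape as my_op_functor:
// a templated operator() that maps one element to one element.
struct square_plus_one_functor {
  template <typename T>
  inline T operator()(const T x) const {
    return x * x + T(1);
  }
};

// CPU stand-in for the element-wise loop of a unary kernel.
// On the GPU each thread would handle one index; here a plain loop does.
template <typename T, typename F>
std::vector<T> apply_unary(const std::vector<T>& input, F f) {
  std::vector<T> output(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    output[i] = f(input[i]);
  }
  return output;
}
```

Calling `apply_unary(std::vector<float>{0.0f, 1.0f, -2.0f}, square_plus_one_functor{})` yields a reference result that the MPS output can later be compared against.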
Binary Kernel Pattern
struct my_binary_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return /* your operation */;
}
};
REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
Binary Kernel Type Registration Macros
For binary operations, use the convenience macros defined in BinaryKernel.metal:
// Floating-point types only (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);
// Integral types with float output (for math ops like atan2, copysign)
// Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);
// Integral types with same-type output (for bitwise/logical ops)
// Registers: long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);
// Floating-point with opmath precision (for ops needing higher precision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
Common patterns:
Math functions (atan2, copysign, logaddexp): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INT2FLOAT_BINARY_OP
Comparison/logical ops (maximum, minimum): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP
Arithmetic ops (add, sub, mul): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP
Example for atan2 (supports both float and int inputs):
struct atan2_functor {
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
inline T operator()(const T a, const T b) {
return static_cast<T>(precise::atan2(float(a), float(b)));
}
template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
inline float operator()(const T a, const T b) {
return precise::atan2(float(a), float(b));
}
};
REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
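The two overloads above differ only in their `enable_if_t` condition and return type. As a rough illustration of how that selection behaves, here is a CPU analogue in standard C++17. It is a sketch under assumptions: `atan2_reference` is a hypothetical name, `std::atan2` stands in for `precise::atan2`, and the standard `<type_traits>` predicates stand in for the Metal ones.

```cpp
#include <cmath>
#include <type_traits>

// CPU analogue of atan2_functor: enable_if_t leaves exactly one
// viable overload for each input type.
struct atan2_reference {
  // Floating-point inputs: compute in float, return the input type.
  template <typename T,
            std::enable_if_t<std::is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) const {
    return static_cast<T>(
        std::atan2(static_cast<float>(a), static_cast<float>(b)));
  }

  // Integral inputs (including bool): promote to float, return float.
  template <typename T,
            std::enable_if_t<std::is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) const {
    return std::atan2(static_cast<float>(a), static_cast<float>(b));
  }
};
```

The float return type of the integral overload mirrors the `long->float`, `int->float`, ... pairs that REGISTER_INT2FLOAT_BINARY_OP is described as registering.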
With Scalar Parameter
struct my_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return a + c10::metal::mul(alpha, b);
}
};
// The functor takes two tensor elements plus alpha, so it is a binary op
REGISTER_BINARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_BINARY_ALPHA_OP(my_alpha, half, half, half);
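For checking expected values by hand, the arithmetic of `my_alpha_functor` can be mirrored on the CPU. This is a minimal sketch in plain C++: `alpha_reference` is a hypothetical name, and the built-in `*` stands in for `c10::metal::mul`, which is only available in Metal code.

```cpp
// CPU analogue of my_alpha_functor: combines two elements with a scalar.
// Plain operator* replaces c10::metal::mul here (an assumption that only
// holds for real-valued scalar types).
struct alpha_reference {
  template <typename T>
  constexpr T operator()(const T a, const T b, const T alpha) const {
    return a + alpha * b;
  }
};
```

With `alpha` equal to 1 the expression reduces to plain addition, and with `alpha` equal to -1 it reduces to subtraction.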
Type-Specialized Functor
struct special_functor {
// Floating point types
template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
inline T operator()(const T x) {
return precise::exp(x); // Use precise math
}
// Integral types
template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
inline float operator()(const T x) {
return precise::exp(float(x));
}
// Complex types (float2 for cfloat, half2 for chalf)
template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
inline T operator()(const T x) {
// x.x = real, x.y = imaginary
return T(/* real */, /* imag */);
}
};
Note on complex types: Complex numbers in Metal are represented as vector types:
c10::complex<float> maps to float2 (x = real, y = imaginary)
c10::complex<half> maps to half2
Use is_complex_v<T> to specialize for complex types in functors.
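Because a complex value is stored as a two-component vector, arithmetic written for real numbers does not carry over automatically: in Metal, `*` on a vector type such as float2 multiplies component by component, which is not a complex product. The plain C++ sketch below illustrates the difference. The `float2` struct here is a minimal stand-in for the Metal type, and `componentwise_mul` and `complex_mul` are hypothetical helper names.

```cpp
// Minimal stand-in for Metal's float2:
// x holds the real part, y holds the imaginary part.
struct float2 {
  float x;
  float y;
};

// Component-wise product, the behaviour of operator* on Metal vectors.
constexpr float2 componentwise_mul(const float2 a, const float2 b) {
  return float2{a.x * b.x, a.y * b.y};
}

// Complex product: (a.x + i*a.y) * (b.x + i*b.y).
constexpr float2 complex_mul(const float2 a, const float2 b) {
  return float2{a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};
}
```

For the imaginary unit, `complex_mul` gives i * i = -1, while `componentwise_mul` leaves the value unchanged, so a complex specialization has to spell out the real and imaginary parts explicitly.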
Stub Registration Pattern (Preferred for Native Metal)
For structured kernels that use the TensorIterator pattern:
// In BinaryKernel.mm (or appropriate file)
static void my_op_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "my_op"); // "my_op" matches the functor name in .metal
}
// Register the MPS stub - this connects to the dispatch system
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
When migrating from MPSGraph, also remove the old implementation:
Remove from BinaryOps.mm (or UnaryOps.mm):
Delete the TORCH_IMPL_FUNC(my_op_out_mps) implementation
Remove the corresponding #include <ATen/ops/my_op_native.h> header
Add to BinaryKernel.mm (or UnaryKernel.mm):
Add the static kernel function
Add the REGISTER_DISPATCH call
Step 4: Compile
After making changes, compile to verify everything builds correctly:
cd build && ninja torch_cpu
Testing
Basic operator support is already tested by test_output_match in test/test_mps.py. After implementing an operator, enable testing by removing expected failures:
1. Remove from common_mps.py
Location: torch/testing/_internal/common_mps.py
Find and remove the operator from skip/xfail lists:
# Remove entries like:
MPS_XFAILLIST = {
"my_op": ..., # Remove this line
}
MPS_SKIPLIST = {
"my_op": ..., # Remove this line
}
2. Remove from OpInfo decorators
Location: torch/testing/_internal/common_methods_invocations.py (or related files)