Create custom PyTorch operations and C++/CUDA extensions
PyTorch allows you to extend its functionality by creating custom operations. This is essential when you need operations not available in PyTorch’s standard library or want to optimize performance-critical code paths.
class CustomMatMul(Function):
    """Custom autograd matmul: forward computes ``a @ b``; backward returns
    the analytic gradients for both inputs.

    NOTE(review): uses ``.t()`` in backward, so this assumes 2-D inputs —
    batched inputs would need per-batch transposes (``transpose(-2, -1)``).
    """

    @staticmethod
    def forward(ctx, a, b):
        # Save both inputs: backward needs each one to form the
        # gradient of the other.
        ctx.save_for_backward(a, b)
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = None
        # Only compute a gradient when the corresponding input requires
        # one; autograd accepts None in those positions. This avoids
        # wasted matmuls when e.g. one operand is a frozen weight.
        if ctx.needs_input_grad[0]:
            # d(a@b)/da contracted with grad_output: grad_output @ b^T
            grad_a = torch.matmul(grad_output, b.t())
        if ctx.needs_input_grad[1]:
            # d(a@b)/db contracted with grad_output: a^T @ grad_output
            grad_b = torch.matmul(a.t(), grad_output)
        # One entry per forward() input, in order.
        return grad_a, grad_b


# Usage
def custom_matmul(a, b):
    """Differentiable matmul backed by CustomMatMul."""
    return CustomMatMul.apply(a, b)
Return one gradient per input in backward(). Use None for inputs that don’t require gradients.
import torchfrom torch import Tensor# Define the custom operation@torch._custom_ops.custom_op("mylib::numpy_sin")def numpy_sin(x: Tensor) -> Tensor: # This is just a prototype, implementation comes next raise NotImplementedError# Register CPU implementation@torch._custom_ops.impl("mylib::numpy_sin", device_types="cpu")def numpy_sin_impl_cpu(x): import numpy as np return torch.from_numpy(np.sin(x.numpy()))# Register CUDA implementation@torch._custom_ops.impl("mylib::numpy_sin", device_types="cuda")def numpy_sin_impl_cuda(x): import numpy as np return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)# Usagex = torch.randn(10)result = torch.ops.mylib.numpy_sin(x)
# Abstract (meta/FakeTensor) kernel: reports the output's metadata —
# shape, dtype, device — without running the real computation, so the
# op can be traced/compiled symbolically.
@torch._custom_ops.impl_abstract("mylib::numpy_sin")
def numpy_sin_abstract(x):
    # sin is element-wise, so the output mirrors the input's metadata.
    return torch.empty_like(x)
Multiple Device Implementations
# Register a single implementation for several device types at once.
@torch._custom_ops.impl("mylib::custom_add", device_types=["cpu", "cuda"])
def custom_add_impl(x, y):
    return x + y


# ...or register device-specific implementations separately.
@torch._custom_ops.impl("mylib::optimized_conv", device_types="cuda")
def optimized_conv_cuda(input, weight):
    # Hand-tuned CUDA path (cuda_optimized_conv is defined elsewhere).
    return cuda_optimized_conv(input, weight)


@torch._custom_ops.impl("mylib::optimized_conv", device_types="cpu")
def optimized_conv_cpu(input, weight):
    # Portable fallback via the stock PyTorch convolution.
    return torch.conv2d(input, weight)
#include <torch/extension.h>torch::Tensor custom_add(torch::Tensor a, torch::Tensor b) { return a + b;}// Python bindingsPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("custom_add", &custom_add, "Custom addition");}
from torch.utils.cpp_extension import load

# JIT-compile the C++ sources at import time. Build artifacts are
# cached, so later imports reuse the compiled library.
custom_ops = load(
    name="custom_ops",
    sources=["custom_ops.cpp"],
    verbose=True,
)

# The compiled functions are available immediately on the returned module.
result = custom_ops.custom_add(a, b)
JIT compilation is convenient for development but adds startup overhead on first import. For production, build the extension ahead of time with a setuptools `setup.py` (using `CppExtension`/`CUDAExtension` and `BuildExtension`) and install it like a normal package.