Skip to main content
PyTorch allows you to extend its functionality by creating custom operations. This is essential when you need operations not available in PyTorch’s standard library or want to optimize performance-critical code paths.

Why Custom Operations?

Use cases:
  • Implement novel operations from research papers
  • Optimize critical bottlenecks with custom CUDA kernels
  • Interface with external libraries (C++, Fortran)
  • Create domain-specific operations
  • Achieve better performance than PyTorch’s generic implementations

Python Custom Operations

The simplest way to create a custom op is with PyTorch’s Python API:

Using torch.autograd.Function

Step 1: Define Forward and Backward

Create a class inheriting from torch.autograd.Function:
import torch
from torch.autograd import Function

class CustomReLU(Function):
    """Element-wise ReLU, ``max(input, 0)``, with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input):
        # Keep the input around: backward needs to know which entries were negative.
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # Pass the upstream gradient through unchanged except where the forward
        # input was strictly negative; there the ReLU was flat, so it becomes 0.
        negative_mask = saved_input < 0
        return grad_output.masked_fill(negative_mask, 0)
Step 2: Create Helper Function

Wrap the Function in a user-friendly function:
def custom_relu(input):
    """Apply ``CustomReLU`` to *input* so autograd records its custom backward."""
    result = CustomReLU.apply(input)
    return result

# Usage
x = torch.randn(10, requires_grad=True)
y = custom_relu(x)
total = y.sum()
total.backward()

Multiple Inputs and Outputs

class CustomMatMul(Function):
    """``torch.matmul`` with an explicit backward pass.

    Handles the same operand ranks as ``torch.matmul``: 1-D vectors, 2-D
    matrices, and N-D batched operands whose batch dimensions broadcast.
    """

    @staticmethod
    def forward(ctx, a, b):
        # Each operand is needed to form the other's gradient.
        ctx.save_for_backward(a, b)
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(dL/da, dL/db)``; an entry is None when that input needs no grad."""
        a, b = ctx.saved_tensors
        need_a, need_b = ctx.needs_input_grad[:2]
        if not (need_a or need_b):
            return None, None

        # torch.matmul treats a 1-D `a` as (1, k) and a 1-D `b` as (k, 1), then
        # drops the added dimension from the result. Re-insert those dimensions
        # so the matrix formulas below apply to every rank combination.
        a_mat = a.unsqueeze(0) if a.dim() == 1 else a
        b_mat = b.unsqueeze(-1) if b.dim() == 1 else b
        grad = grad_output
        if b.dim() == 1:
            grad = grad.unsqueeze(-1)
        if a.dim() == 1:
            grad = grad.unsqueeze(-2)

        grad_a = grad_b = None
        if need_a:
            # dL/dA = dL/dC @ B^T. transpose(-2, -1) works for any rank >= 2,
            # whereas .t() raises for tensors with more than two dimensions.
            grad_a = torch.matmul(grad, b_mat.transpose(-2, -1))
            # Sum over batch dims that were broadcast in forward, then drop
            # the dimension added for a 1-D operand.
            grad_a = grad_a.sum_to_size(a_mat.shape).reshape(a.shape)
        if need_b:
            # dL/dB = A^T @ dL/dC, reduced the same way.
            grad_b = torch.matmul(a_mat.transpose(-2, -1), grad)
            grad_b = grad_b.sum_to_size(b_mat.shape).reshape(b.shape)
        return grad_a, grad_b

# Usage
def custom_matmul(a, b):
    """Autograd-aware ``a @ b`` backed by :class:`CustomMatMul`."""
    return CustomMatMul.apply(a, b)
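backward() must return one gradient per input to forward() (not counting ctx), in the same order. Return None for inputs that don’t require gradients; ctx.needs_input_grad tells you which inputs do.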

Custom Op Registration API

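Recent PyTorch 2.x releases ship a prototype API, torch._custom_ops, for registering custom operations. It is private (note the leading underscore), and PyTorch 2.4+ provides the public torch.library.custom_op API, which new code should prefer. The examples below use the prototype API: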

Basic Registration

import torch
from torch import Tensor

# Define the custom operation
# NOTE(review): torch._custom_ops is a private, prototype module; consider the
# public torch.library.custom_op API where the installed PyTorch provides it.
# The decorated function is only a prototype: its annotated signature
# (Tensor -> Tensor) presumably determines the op's schema, and its body is
# never meant to run — the device-specific impls below do the work.
@torch._custom_ops.custom_op("mylib::numpy_sin")
def numpy_sin(x: Tensor) -> Tensor:
    # This is just a prototype, implementation comes next
    raise NotImplementedError

# Register CPU implementation
# Converts x to a NumPy array, applies np.sin, and wraps the result back
# into a tensor.
# NOTE(review): x.numpy() raises if x requires grad; confirm the kernel is
# only ever reached with tensors that do not.
@torch._custom_ops.impl("mylib::numpy_sin", device_types="cpu")
def numpy_sin_impl_cpu(x):
    import numpy as np
    return torch.from_numpy(np.sin(x.numpy()))

# Register CUDA implementation
# Not a real GPU kernel: copies x to host memory, computes np.sin there, and
# copies the result back to x's device, so it pays two transfers per call.
@torch._custom_ops.impl("mylib::numpy_sin", device_types="cuda")
def numpy_sin_impl_cuda(x):
    import numpy as np
    return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)

# Usage
# Registered ops are reached through the torch.ops.<namespace>.<name> path.
x = torch.randn(10)
result = torch.ops.mylib.numpy_sin(x)

Register Abstract Implementation

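Register an abstract implementation that produces an output with the right shape and dtype without computing any values (tracing tools such as torch.compile rely on it):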
@torch._custom_ops.impl_abstract("mylib::numpy_sin")
def numpy_sin_abstract(x):
    """Abstract kernel for ``mylib::numpy_sin``.

    sin is element-wise, so the output mirrors x's metadata; allocate an
    uninitialized tensor like x instead of computing anything.
    """
    out = torch.empty_like(x)
    return out
# Register for multiple devices
# NOTE(review): mylib::custom_add is not declared in this snippet; presumably
# a torch._custom_ops.custom_op prototype for it must be registered before
# this decorator runs — confirm.
@torch._custom_ops.impl("mylib::custom_add", device_types=["cpu", "cuda"])
def custom_add_impl(x, y):
    # One Python implementation shared by both listed device types.
    return x + y

# Or device-specific implementations
# NOTE(review): as above, mylib::optimized_conv must be declared first.
@torch._custom_ops.impl("mylib::optimized_conv", device_types="cuda")
def optimized_conv_cuda(input, weight):
    # Custom CUDA implementation
    # cuda_optimized_conv is not defined in this snippet; assumed to come from
    # a compiled extension.
    return cuda_optimized_conv(input, weight)

@torch._custom_ops.impl("mylib::optimized_conv", device_types="cpu")
def optimized_conv_cpu(input, weight):
    # Fallback CPU implementation
    # torch.conv2d with only input and weight: no bias, default stride/padding.
    return torch.conv2d(input, weight)

C++ Extensions

Create custom C++ operations for better performance:

Simple C++ Extension

Step 1: Write C++ Implementation

Create custom_ops.cpp:
#include <torch/extension.h>

// Element-wise sum of two tensors; identical to writing `a + b`.
torch::Tensor custom_add(torch::Tensor a, torch::Tensor b) {
    torch::Tensor sum = torch::add(a, b);
    return sum;
}

// Python bindings: exposes custom_add as <extension module>.custom_add.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("custom_add", &custom_add, "Custom addition");
}
Step 2: Create setup.py

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

# A single C++ extension module, importable as `custom_ops`, built from
# custom_ops.cpp.
custom_ops_extension = CppExtension(
    name='custom_ops',
    sources=['custom_ops.cpp'],
)

setup(
    name='custom_ops',
    ext_modules=[custom_ops_extension],
    # Use torch's build_ext subclass so the extension is compiled with the
    # flags torch extensions need.
    cmdclass={'build_ext': BuildExtension},
)
Step 3: Build and Use

# `python setup.py install` is deprecated by setuptools; let pip drive the build.
# --no-build-isolation builds against the torch already installed in this env.
pip install --no-build-isolation .
import torch

import custom_ops  # the extension built from custom_ops.cpp

# custom_add takes two tensors and returns their element-wise sum.
a = torch.randn(10)
b = torch.randn(10)
result = custom_ops.custom_add(a, b)

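JIT Compilation (load / load_inline)

Compile C++ code on the fly with torch.utils.cpp_extension.load (or load_inline, which takes the source code as a string):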
import torch
from torch.utils.cpp_extension import load

# Build custom_ops.cpp when load() is called (a cached build is reused while
# the sources are unchanged) and import the resulting module.
custom_ops = load(
    name="custom_ops",
    sources=["custom_ops.cpp"],
    verbose=True,
)

# Use immediately
a = torch.randn(10)
b = torch.randn(10)
result = custom_ops.custom_add(a, b)
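JIT compilation is convenient for development, but the first load() call compiles the extension, which adds startup overhead (later runs reuse the cached build unless the sources change). Use an installed extension (setup.py) for production.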

CUDA Extensions

Write custom CUDA kernels for maximum performance:

CUDA Kernel Example

#include <torch/extension.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

// CUDA kernel: output[i] = input[i] > 0 ? input[i] : 0 for i in [0, size).
// One thread per element; threads past the end do nothing.
__global__ void relu_kernel(const float* input, float* output, int64_t size) {
    // 64-bit index so tensors with more than 2^31 elements do not overflow.
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < size) {
        const float v = input[idx];
        output[idx] = v > 0.0f ? v : 0.0f;
    }
}

// C++ wrapper: validates the input, launches relu_kernel, returns a new tensor.
// Throws (RuntimeError in Python) for non-CUDA or non-float32 input.
torch::Tensor relu_cuda(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "relu_cuda: expected a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "relu_cuda: expected a float32 tensor");

    // Launch on the tensor's own device, not whichever device is current.
    const c10::cuda::CUDAGuard device_guard(input.device());

    // The kernel indexes memory linearly, so it needs a dense contiguous buffer.
    auto src = input.contiguous();
    auto output = torch::empty_like(src);

    const int64_t size = src.numel();
    if (size == 0) {
        return output;  // a zero-block grid is an invalid launch configuration
    }

    const int threads = 256;
    const int64_t blocks = (size + threads - 1) / threads;

    // Use PyTorch's current stream so the kernel is ordered with other torch work.
    relu_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                  at::cuda::getCurrentCUDAStream()>>>(
        src.data_ptr<float>(),
        output.data_ptr<float>(),
        size
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return output;
}

Optimized CUDA Kernel

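Use shared memory and coalesced access. This tiled kernel must be launched with TILE_SIZE × TILE_SIZE thread blocks and a grid of ceil(N / TILE_SIZE) × ceil(M / TILE_SIZE) blocks: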
// Tile edge length. Launch with blockDim = (TILE_SIZE, TILE_SIZE) and
// gridDim = (ceil(N / TILE_SIZE), ceil(M / TILE_SIZE)).
// The guard lets a build choose its own value (e.g. -DTILE_SIZE=32).
#ifndef TILE_SIZE
#define TILE_SIZE 16
#endif

// Tiled matrix multiply C = A @ B with A (M x K), B (K x N), C (M x N), all
// float32 in row-major layout (element (r, c) of an R x Cols matrix lives at
// r * Cols + c).
//
// Each block computes one TILE_SIZE x TILE_SIZE tile of C. For each step
// along K, every thread copies one element of an A tile and one element of a
// B tile into shared memory; the block then accumulates from shared memory
// instead of re-reading global memory.
__global__ void optimized_matmul_kernel(
    const float* A, const float* B, float* C,
    int M, int N, int K
) {
    __shared__ float shared_A[TILE_SIZE][TILE_SIZE];
    __shared__ float shared_B[TILE_SIZE][TILE_SIZE];

    // Output element owned by this thread.
    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;

    float sum = 0.0f;

    for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; ++tile) {
        // Load tiles into shared memory. Out-of-range elements are stored as
        // 0 so ragged edge tiles contribute nothing to the sum.
        if (row < M && tile * TILE_SIZE + threadIdx.x < K)
            shared_A[threadIdx.y][threadIdx.x] = A[row * K + tile * TILE_SIZE + threadIdx.x];
        else
            shared_A[threadIdx.y][threadIdx.x] = 0.0f;

        if (col < N && tile * TILE_SIZE + threadIdx.y < K)
            shared_B[threadIdx.y][threadIdx.x] = B[(tile * TILE_SIZE + threadIdx.y) * N + col];
        else
            shared_B[threadIdx.y][threadIdx.x] = 0.0f;

        // Both tiles must be fully written before any thread reads them.
        __syncthreads();

        // Compute partial sum
        for (int k = 0; k < TILE_SIZE; ++k) {
            sum += shared_A[threadIdx.y][k] * shared_B[k][threadIdx.x];
        }

        // All reads must finish before the next iteration overwrites the tiles.
        __syncthreads();
    }

    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}

Autograd Support for Custom Ops

Add gradient support to C++ operations:
#include <torch/extension.h>

// ReLU with a hand-written backward, built on libtorch's autograd Function API.
class CustomFunction : public torch::autograd::Function<CustomFunction> {
public:
    // y = clamp_min(x, 0); x is saved so backward can find the negative entries.
    static torch::Tensor forward(
        torch::autograd::AutogradContext* ctx,
        torch::Tensor input
    ) {
        ctx->save_for_backward({input});
        return input.clamp_min(0);
    }

    // dL/dx equals dL/dy, except it is zeroed wherever x < 0.
    static std::vector<torch::Tensor> backward(
        torch::autograd::AutogradContext* ctx,
        std::vector<torch::Tensor> grad_outputs
    ) {
        const torch::Tensor saved_input = ctx->get_saved_variables()[0];
        const torch::Tensor negative_mask = saved_input < 0;
        return {grad_outputs[0].masked_fill(negative_mask, 0)};
    }
};

// Entry point: runs CustomFunction so autograd records the custom backward.
torch::Tensor custom_relu(torch::Tensor input) {
    auto output = CustomFunction::apply(input);
    return output;
}

Triton Kernels

Use OpenAI Triton for GPU kernels without writing CUDA:
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Element-wise output = x + y over flat buffers of n_elements values.
    # Each program instance handles one chunk of BLOCK_SIZE consecutive elements.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # The last chunk may run past the end; masked lanes neither load nor store.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y

    tl.store(output_ptr + offsets, output, mask=mask)

def triton_add(x: torch.Tensor, y: torch.Tensor):
    """Return ``x + y`` computed by ``add_kernel``.

    Raises:
        ValueError: if the shapes differ (no broadcasting is supported) or the
            inputs are not CUDA tensors on the same device.
    """
    if x.shape != y.shape:
        raise ValueError(
            f"triton_add: shape mismatch {tuple(x.shape)} vs {tuple(y.shape)}"
        )
    if not (x.is_cuda and y.is_cuda) or x.device != y.device:
        raise ValueError("triton_add: x and y must be CUDA tensors on the same device")

    # The kernel walks each tensor as one flat buffer, so the memory must be
    # contiguous; otherwise it would read and write the wrong elements.
    x = x.contiguous()
    y = y.contiguous()
    output = torch.empty_like(x)
    n_elements = output.numel()
    if n_elements == 0:
        return output  # nothing to launch

    # One program per BLOCK_SIZE-element chunk.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output

# Usage
a = torch.randn(10000, device='cuda')
b = torch.randn(10000, device='cuda')
c = triton_add(a, b)
Triton automatically optimizes memory access patterns and handles many low-level details. It’s easier than CUDA but still very fast.

Testing Custom Operations

Gradient Checking

from torch.autograd import gradcheck

# Compare custom_relu's analytical gradients (from backward) with numerical
# finite-difference gradients. gradcheck expects double-precision inputs that
# require grad; it returns True on success.
input = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
test = gradcheck(custom_relu, (input,), eps=1e-6, atol=1e-4)
print(f"Gradient check passed: {test}")

Performance Benchmarking

import time

def benchmark(func, input, num_iterations=1000, warmup=10):
    """Return the average wall-clock time, in seconds, of one ``func(input)`` call.

    Args:
        func: Callable under test; it is called as ``func(input)``.
        input: Argument passed to ``func``. For a CUDA tensor the device is
            synchronized around the timed loop so asynchronous kernels are
            fully counted.
        num_iterations: Number of timed calls; must be positive.
        warmup: Number of untimed calls made first (e.g. to trigger lazy
            initialization or compilation).

    Raises:
        ValueError: if ``num_iterations`` is not positive.
    """
    if num_iterations <= 0:
        raise ValueError(f"num_iterations must be positive, got {num_iterations}")

    # Only synchronize for CUDA inputs: CPU work is already synchronous, and
    # torch.cuda.synchronize() fails on machines without CUDA.
    is_cuda = getattr(input, "is_cuda", False)

    def _sync():
        if is_cuda:
            torch.cuda.synchronize(input.device)

    # Warmup
    for _ in range(warmup):
        func(input)

    _sync()
    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(num_iterations):
        func(input)
    _sync()

    return (time.perf_counter() - start) / num_iterations

# Compare custom vs PyTorch implementation: time both ReLUs on the same CUDA input.
input = torch.randn(1000, 1000, device='cuda')

custom_time, torch_time = (benchmark(fn, input) for fn in (custom_relu, torch.relu))

print(f"Custom: {custom_time*1000:.3f}ms")
print(f"PyTorch: {torch_time*1000:.3f}ms")
print(f"Speedup: {torch_time/custom_time:.2f}x")

Best Practices

Custom Op Checklist:
  • ✓ Implement gradient checking for autograd functions
  • ✓ Handle edge cases (empty tensors, non-contiguous memory)
  • ✓ Add type checking and error handling
  • ✓ Benchmark against PyTorch’s built-in operations
  • ✓ Test on both CPU and GPU
  • ✓ Document expected input shapes and dtypes

Error Handling

// Validates that `input` is a 2-D, contiguous float32 tensor before running
// the op. Each TORCH_CHECK throws c10::Error (RuntimeError in Python) with
// the given message when its condition is false.
torch::Tensor safe_custom_op(torch::Tensor input) {
    TORCH_CHECK(input.dim() == 2, "Expected 2D tensor");
    // Compare ScalarType with ScalarType (dtype() returns a caffe2::TypeMeta).
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Expected float32");
    TORCH_CHECK(input.is_contiguous(), "Expected contiguous tensor");

    // Implementation: placeholder. A copy of the input stands in for the real
    // computation so the example compiles.
    torch::Tensor result = input.clone();
    return result;
}

Type Dispatch

// Dispatch on the runtime dtype of `input` to the matching
// custom_op_impl<scalar_t> instantiation (the floating types covered by
// AT_DISPATCH_FLOATING_TYPES) and return its result.
torch::Tensor custom_op(torch::Tensor input) {
    const auto dtype = input.scalar_type();
    return AT_DISPATCH_FLOATING_TYPES(dtype, "custom_op", [&]() {
        return custom_op_impl<scalar_t>(input);
    });
}

Integration with torch.compile

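Make custom ops work with torch.compile. torch.compile traces with fake tensors, so in addition to its CPU/CUDA kernels a custom op needs a Meta (fake) implementation that produces outputs of the right shape and dtype without real data: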
from torch.library import Library, impl

# Register custom op
# Keep a module-level reference: the registrations are tied to this Library
# object's lifetime.
my_lib = Library("mylib", "DEF")
my_lib.define("custom_relu(Tensor x) -> Tensor")

@impl(my_lib, "custom_relu", "CPU")
def custom_relu_cpu(x):
    """CPU kernel: element-wise max(x, 0)."""
    return x.clamp_min(0)

@impl(my_lib, "custom_relu", "CUDA")
def custom_relu_cuda(x):
    """CUDA kernel; ``cuda_relu`` is assumed to come from a compiled extension."""
    return cuda_relu(x)

@impl(my_lib, "custom_relu", "Meta")
def custom_relu_meta(x):
    """Shape/dtype-only kernel for tracing with meta/fake tensors.

    torch.compile traces with fake tensors, which dispatch to this kernel;
    without it, tracing the op fails. ReLU is element-wise, so the output
    mirrors x's metadata.
    """
    return torch.empty_like(x)

# Now works with torch.compile
# MyModel is assumed to be defined elsewhere (presumably it calls
# torch.ops.mylib.custom_relu in its forward).
eager_model = MyModel()
model = torch.compile(eager_model)