torch.compile is PyTorch’s JIT compiler that optimizes your models for faster execution. Introduced in PyTorch 2.0, it uses TorchDynamo to capture computation graphs and TorchInductor to generate optimized kernels.

Quick Start

Optimizing a model is as simple as wrapping it with torch.compile:
import torch

# Create your model
model = MyModel().cuda()

# Compile it!
compiled_model = torch.compile(model)

# Use it like normal
output = compiled_model(input)
The first run with torch.compile will be slower due to compilation overhead. Subsequent runs will be significantly faster.

How torch.compile Works

Step 1: Graph Capture (TorchDynamo)

TorchDynamo intercepts Python bytecode and captures the computation graph:
# Your code
x = torch.randn(10, 10)
y = x + 1
z = y * 2

# TorchDynamo captures this as a graph:
# add(x, 1) -> mul(result, 2)
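You can see the captured graph for yourself by passing a custom backend to torch.compile. The sketch below (the backend name `inspect_backend` is ours, not a PyTorch API) prints the FX graph TorchDynamo records and then runs it unchanged:

```python
import torch

# A custom backend receives the captured FX GraphModule plus example
# inputs. This one prints the graph and returns it to run as-is.
def inspect_backend(gm, example_inputs):
    print(gm.graph)      # shows the recorded ops: add, mul, ...
    return gm.forward    # execute the captured graph unmodified

@torch.compile(backend=inspect_backend)
def fn(x):
    y = x + 1
    return y * 2

fn(torch.randn(10, 10))
```

Any callable with the signature `(gm, example_inputs) -> callable` works as a backend, which is also how the custom-backend option shown later in this page is used.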

Step 2: Optimization (Compiler Backend)

The graph is optimized using a backend (default: TorchInductor):
  • Operator fusion
  • Memory layout optimization
  • Kernel optimization

Step 3: Code Generation

Optimized code is generated and cached for future use:
# First run: compiles (slower)
output = compiled_model(input)  # ~100ms

# Subsequent runs: uses cached version (faster)
output = compiled_model(input)  # ~10ms

Compilation Modes

torch.compile supports different optimization modes:
# Good balance of compilation time and runtime performance
compiled_model = torch.compile(model, mode='default')

Mode Comparison

Mode             Compilation Time   Runtime Speed    Use Case
default          Moderate           Fast             General purpose
reduce-overhead  Moderate           Fast per call    Small batches; uses CUDA graphs to cut framework overhead
max-autotune     Slow               Fastest          Production inference

Compilation Backends

Choose different backends for specific use cases:
# Default: TorchInductor (recommended)
compiled = torch.compile(model, backend='inductor')

# ONNX Runtime
compiled = torch.compile(model, backend='onnxrt')

# Eager mode (no compilation, for debugging)
compiled = torch.compile(model, backend='eager')

# Custom backend
compiled = torch.compile(model, backend=my_custom_backend)
inductor: Generates optimized Triton/C++ kernels. Best for NVIDIA GPUs and CPUs.
  Pros: Fast, good GPU utilization, automatic kernel fusion
  Cons: Initial compilation overhead
onnxrt: Exports to ONNX and runs with ONNX Runtime.
  Pros: Cross-platform, good for deployment
  Cons: Limited operator support
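The set of backends registered on a given installation varies with the PyTorch build and installed packages; you can list them at runtime:

```python
import torch._dynamo as dynamo

# Names accepted by torch.compile(backend=...) on this installation.
# The exact list depends on your PyTorch build and extras.
print(dynamo.list_backends())
```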

Dynamic Shapes

By default, torch.compile assumes static input shapes. For dynamic shapes:
import torch

# Enable dynamic shapes
compiled_model = torch.compile(
    model,
    dynamic=True  # Allow varying input shapes
)

# Now works with different shapes
out1 = compiled_model(torch.randn(32, 3, 224, 224))
out2 = compiled_model(torch.randn(64, 3, 224, 224))  # Different batch size
Dynamic shapes can reduce optimization opportunities. Use static shapes when possible for best performance.

Constrained Dynamic Shapes

Backend options can be tuned alongside dynamic shapes for better optimization:
compiled_model = torch.compile(
    model,
    dynamic=True,
    options={
        "shape_padding": True,  # Pad shapes to hardware-friendly multiples
        "max_autotune": True    # Autotune kernel configurations
    }
)

Full vs Partial Graph Compilation

Control what gets compiled:

Full Graph Mode

# Compile the entire model as a single graph
compiled_model = torch.compile(model, fullgraph=True)
Requires the entire model to be compilable. Raises an error if any unsupported operation causes a graph break.

Partial Graph Mode (Default)

# Compile what's possible, fall back to eager for rest
compiled_model = torch.compile(model, fullgraph=False)
Compiles supported subgraphs and falls back to eager execution for unsupported operations.

Disabling Compilation

Selectively disable compilation for parts of your code:
import torch
import torch.nn as nn
import torch._dynamo as dynamo

@dynamo.disable
def debug_function(x):
    # This function won't be compiled
    print(f"Debug: {x.shape}")
    return x

class MyModel(nn.Module):
    def forward(self, x):
        x = self.layer1(x)
        x = debug_function(x)  # Not compiled
        x = self.layer2(x)     # Compiled
        return x

compiled_model = torch.compile(MyModel())

Compilation with Training

Use torch.compile in training loops:
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())

# Compile the model
compiled_model = torch.compile(model, mode='reduce-overhead')

for epoch in range(num_epochs):
    for batch, target in dataloader:
        optimizer.zero_grad()
        
        # Compiled forward pass
        output = compiled_model(batch)
        loss = criterion(output, target)
        
        # Backward pass is also optimized
        loss.backward()
        optimizer.step()
Both forward and backward passes are optimized by torch.compile. Gradient computation is automatically included in the compiled graph.

Distributed Training

Combine torch.compile with DDP/FSDP:
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

# Initialize distributed
dist.init_process_group(backend='nccl')

# Compile first, then wrap with DDP
model = MyModel().cuda()
compiled_model = torch.compile(model)
ddp_model = DDP(compiled_model, device_ids=[local_rank])

# Training loop
for batch, target in dataloader:
    optimizer.zero_grad()
    output = ddp_model(batch)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
Compilation Order Matters:
  1. Compile model with torch.compile
  2. Then wrap with DDP/FSDP
Wrapping in the wrong order may reduce optimization effectiveness.

Mixed Precision with torch.compile

from torch.amp import autocast, GradScaler

model = torch.compile(MyModel().cuda())
scaler = GradScaler(device='cuda')

for batch, target in dataloader:
    optimizer.zero_grad()
    
    # GradScaler pairs with float16; bfloat16 generally does not need scaling
    with autocast(device_type='cuda', dtype=torch.float16):
        output = model(batch)  # Compiled + AMP
        loss = criterion(output, target)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Debugging Compilation

View Compilation Graphs

import torch._dynamo as dynamo

# Enable verbose logging
dynamo.config.verbose = True

# Compile and see what's happening
compiled_model = torch.compile(model)
output = compiled_model(input)

Explain Compilation

See what got compiled:
explanation = torch._dynamo.explain(model)(input)
print(explanation)
Output shows:
  • Which parts were compiled
  • Which parts fell back to eager
  • Why compilation failed for certain operations

Reset Compilation Cache

import torch._dynamo as dynamo

# Clear compilation cache
dynamo.reset()

# Force recompilation
output = compiled_model(input)

Performance Tips

When to Use torch.compile

# Models with static shapes
compiled = torch.compile(vision_model)

# Repetitive operations (training loops)
for batch in dataloader:
    output = compiled_model(batch)

# Production inference
with torch.no_grad():
    predictions = compiled_model(input)

Reduce Compilation Overhead

# Use smaller models and inputs during development to keep compiles quick
model = SmallModel()
compiled = torch.compile(model)

# Switch to max performance for production
compiled = torch.compile(model, mode='max-autotune')

Optimize for Specific Hardware

# CUDA-specific optimizations
compiled = torch.compile(
    model,
    mode='max-autotune',
    options={
        'triton.cudagraphs': True,  # Enable CUDA graphs
        'shape_padding': True,       # Pad to optimal sizes
    }
)

Limitations and Workarounds

Current Limitations:
  • Some operations cause graph breaks (e.g., print(), pdb)
  • Data-dependent control flow may not compile
  • Mutation of external state can cause issues
  • First run is slower due to compilation

Handling Graph Breaks

# Avoid: Python operations that break graphs
def forward(self, x):
    print(x.shape)  # Graph break!
    return self.layer(x)

# Better: Remove debug statements or use torch operations
def forward(self, x):
    # Use torch operations instead
    return self.layer(x)

Reduce Recompilations

import torch.nn.functional as F

# Pad inputs to consistent shapes so compiled kernels are reused
def pad_to_multiple(x, multiple=8):
    batch, seq_len = x.shape
    padded_len = ((seq_len + multiple - 1) // multiple) * multiple
    return F.pad(x, (0, padded_len - seq_len))

# Use padded inputs
input_padded = pad_to_multiple(input)
output = compiled_model(input_padded)

Benchmarking

Compare compiled vs eager performance:
import time

model = MyModel().cuda()
compiled_model = torch.compile(model)
input = torch.randn(32, 3, 224, 224).cuda()

# Warmup (the first compiled calls trigger compilation)
for _ in range(10):
    _ = compiled_model(input)
    _ = model(input)

# Benchmark compiled (synchronize before and after timing on CUDA)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    output = compiled_model(input)
torch.cuda.synchronize()
compiled_time = time.time() - start

# Compare with eager
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    output = model(input)
torch.cuda.synchronize()
eager_time = time.time() - start

print(f"Speedup: {eager_time / compiled_time:.2f}x")
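For more robust numbers, torch.utils.benchmark.Timer handles CUDA synchronization and produces statistics rather than a single wall-clock delta. A sketch with a small stand-in model; the "eager" backend is used here only so the snippet runs anywhere, and on a GPU you would compile with the default inductor backend and move the tensors to CUDA:

```python
import torch
import torch.utils.benchmark as benchmark

model = torch.nn.Linear(256, 256)
compiled = torch.compile(model, backend="eager")  # stand-in backend for this sketch
x = torch.randn(64, 256)

# Timer times the statement with the given globals and reports mean/median
for label, fn in [("eager", model), ("compiled", compiled)]:
    timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    print(label, timer.timeit(100))
```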