Optimize PyTorch models with compilation for faster execution
torch.compile is PyTorch’s JIT compiler that optimizes your models for faster execution. Introduced in PyTorch 2.0, it uses TorchDynamo to capture computation graphs and TorchInductor to generate optimized kernels.
Optimizing a model is as simple as wrapping it with torch.compile:
```python
import torch

# Create your model
model = MyModel().cuda()

# Compile it!
compiled_model = torch.compile(model)

# Use it like normal
output = compiled_model(input)
```
The first run with torch.compile will be slower due to compilation overhead. Subsequent runs will be significantly faster.
Selectively disable compilation for parts of your code:
```python
import torch
import torch.nn as nn
import torch._dynamo as dynamo

@dynamo.disable
def debug_function(x):
    # This function won't be compiled
    print(f"Debug: {x.shape}")
    return x

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Example layers so the snippet is runnable
        self.layer1 = nn.Linear(16, 16)
        self.layer2 = nn.Linear(16, 16)

    def forward(self, x):
        x = self.layer1(x)
        x = debug_function(x)  # Not compiled
        x = self.layer2(x)     # Compiled
        return x

compiled_model = torch.compile(MyModel())
```
```python
import torch
import torch.nn as nn

model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()  # example loss

# Compile the model
compiled_model = torch.compile(model, mode='reduce-overhead')

for epoch in range(num_epochs):
    for batch, target in dataloader:
        optimizer.zero_grad()
        # Compiled forward pass
        output = compiled_model(batch)
        loss = criterion(output, target)
        # Backward pass is also optimized
        loss.backward()
        optimizer.step()
```
Both forward and backward passes are optimized by torch.compile. Gradient computation is automatically included in the compiled graph.
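One way to see that gradients flow through the compiled graph is to compare them with eager mode. A minimal sketch (the model and shapes are illustrative assumptions):

```python
import torch
import torch.nn as nn

# Illustrative model; the compiled wrapper shares its parameters
model = nn.Linear(8, 4)
compiled_model = torch.compile(model)

x = torch.randn(16, 8)

# Backward through the compiled forward pass
compiled_model(x).sum().backward()
compiled_grad = model.weight.grad.clone()

# Same computation in eager mode for comparison
model.zero_grad()
model(x).sum().backward()
eager_grad = model.weight.grad

print(torch.allclose(compiled_grad, eager_grad, atol=1e-5))
```

Because torch.compile returns a wrapper around the same module, the gradients land on the original parameters either way; the two results should agree up to floating-point noise.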
```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize distributed
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])  # set by torchrun

# Compile first, then wrap with DDP
model = MyModel().cuda()
compiled_model = torch.compile(model)
ddp_model = DDP(compiled_model, device_ids=[local_rank])
optimizer = torch.optim.Adam(ddp_model.parameters())
criterion = torch.nn.CrossEntropyLoss()  # example loss

# Training loop
for batch, target in dataloader:
    optimizer.zero_grad()
    output = ddp_model(batch)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
```
Compilation order matters:
1. Compile the model with torch.compile.
2. Then wrap it with DDP/FSDP.
Wrapping in the wrong order may reduce optimization effectiveness.
```python
# Use smaller models during development
model = SmallModel()
compiled = torch.compile(model, mode='reduce-overhead')

# Switch to max performance for production
compiled = torch.compile(model, mode='max-autotune')
```
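The mode strings trade compile time for runtime speed. PyTorch documents three options: 'default' (balanced), 'reduce-overhead' (cuts per-call framework overhead, e.g. via CUDA graphs, which helps small models and small batches), and 'max-autotune' (longest compilation, searches for the fastest kernels). A sketch of all three on an illustrative model:

```python
import torch
import torch.nn as nn

model = nn.Linear(64, 64)  # illustrative model

# Balanced compile time and speedup
balanced = torch.compile(model, mode='default')

# Lower per-call overhead; helps small models and batches
low_overhead = torch.compile(model, mode='reduce-overhead')

# Longest compile time, best steady-state performance
best_runtime = torch.compile(model, mode='max-autotune')
```

Compilation happens lazily on the first call, so constructing these wrappers is cheap; the chosen mode only takes effect when the model first runs.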