Scale PyTorch models across multiple GPUs and machines with DDP and FSDP
PyTorch provides powerful distributed training capabilities to scale your models across multiple GPUs and machines. This guide covers DistributedDataParallel (DDP) and FullyShardedDataParallel (FSDP) for efficient multi-GPU training.
DDP is the recommended approach for multi-GPU training in PyTorch. It replicates your model across multiple processes, with each process owning a dedicated GPU.
1. Initialize the Process Group
First, initialize the distributed process group. This establishes communication between all participating processes.
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize the process group
dist.init_process_group(
    backend='nccl',        # Use 'nccl' for GPU, 'gloo' for CPU
    init_method='env://',  # Read from environment variables
    world_size=4,          # Total number of processes
    rank=0                 # Rank of this process (0 to world_size-1)
)
```
The nccl backend is optimized for GPU communication and provides the best performance for distributed GPU training.
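In practice, the world size and rank are usually injected by a launcher such as torchrun rather than hard-coded. A minimal sketch of reading them from the environment (the helper name `setup_distributed` is our own, not a PyTorch API):

```python
import os
import torch
import torch.distributed as dist

def setup_distributed():
    # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # Pick NCCL when a GPU is available, fall back to gloo on CPU
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # world_size and rank are read from the environment via init_method='env://'
    dist.init_process_group(backend=backend, init_method="env://")
    return rank, world_size, local_rank
```

Launched with, for example, `torchrun --nproc_per_node=4 train.py`, each of the four processes receives its own rank automatically.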
2. Wrap Your Model
Wrap your model with DDP to enable distributed training.
```python
# Create the model and move it to this process's GPU
model = MyModel().to(device)

# Wrap with DDP
ddp_model = DDP(
    model,
    device_ids=[local_rank],
    output_device=local_rank,
    broadcast_buffers=True,
    find_unused_parameters=False
)
```
3. Train with DDP
The training loop remains almost identical to single-GPU training; DDP synchronizes gradients automatically during the backward pass.
```python
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
```
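One detail the loop above glosses over is checkpointing: only one rank should write to disk, and the state dict should be taken from the underlying module (`ddp_model.module`) so that parameter keys are not prefixed with `module.`. A hedged sketch (the helper name `save_checkpoint` is ours):

```python
import torch
import torch.distributed as dist

def save_checkpoint(ddp_model, path):
    # Only rank 0 writes; all ranks hold identical replicas anyway
    if dist.get_rank() == 0:
        # .module unwraps the DDP wrapper, so keys have no 'module.' prefix
        torch.save(ddp_model.module.state_dict(), path)
    # Ensure the file exists before any rank proceeds (e.g. to reload it)
    dist.barrier()
```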
FSDP2 provides a more composable API than the original FSDP wrapper and is recommended for new projects. It allows per-layer sharding control.

```python
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

# Apply FSDP to specific layers
for layer in model.layers:
    fully_shard(layer)

# Apply to the entire model
fully_shard(model)

# With mixed precision
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32
)
fully_shard(model, mp_policy=mp_policy)
```
```python
import torch.distributed as dist

# Sum tensors across all processes
tensor = torch.tensor([1.0, 2.0, 3.0]).cuda()
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
# Result: each process now holds the element-wise sum of all tensors
```
```python
# Gather tensors from all processes
local_tensor = torch.randn(2, 3).cuda()
gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
dist.all_gather(gathered, local_tensor)
```
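The snippets above assume a running multi-GPU process group. To experiment with the collectives without GPUs, you can spin up a single-process gloo group on CPU; with a world size of 1, each collective operates over one rank, so all_reduce leaves the tensor unchanged and all_gather returns a one-element list (the port number below is an arbitrary choice):

```python
import os
import torch
import torch.distributed as dist

# Single-process CPU group, for demonstration only
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group("gloo", rank=0, world_size=1)

t = torch.tensor([1.0, 2.0, 3.0])
dist.all_reduce(t, op=dist.ReduceOp.SUM)   # sum over a single rank: unchanged

gathered = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, t)               # one-element list containing t

dist.destroy_process_group()
```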
```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Create a distributed sampler
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True
)

# Use the sampler in the DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=4,
    pin_memory=True
)

# Set the epoch so each epoch uses a different shuffle
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in dataloader:
        # training loop
        pass
```
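To see how DistributedSampler partitions a dataset, you can instantiate one sampler per rank by hand; no process group is needed when `num_replicas` and `rank` are passed explicitly. With `shuffle=False`, each rank receives an interleaved, non-overlapping slice of the indices (the 8-element dataset here is just an illustration):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))

# One sampler per simulated rank; together they cover every index exactly once
shards = [
    list(DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=False))
    for r in range(2)
]
print(shards)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```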
```python
# DDP automatically overlaps gradient AllReduce with the backward pass.
# You can tune the bucket size for better overlap.
ddp_model = DDP(
    model,
    bucket_cap_mb=25,             # Smaller buckets = more overlap, more launch overhead
    gradient_as_bucket_view=True  # Gradients view the bucket, saving a copy
)
```
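Another way to reduce communication is DDP's `no_sync()` context manager: when accumulating gradients over several micro-batches, it skips the AllReduce on all but the last backward pass. The context manager is standard DDP API, but the tiny model, step count, and single-process gloo group below are our own illustration:

```python
import os
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process CPU group, for demonstration only
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29513")
dist.init_process_group("gloo", rank=0, world_size=1)

ddp_model = DDP(nn.Linear(4, 2))
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
accum_steps = 4
batches = [torch.randn(8, 4) for _ in range(accum_steps)]

optimizer.zero_grad()
for i, batch in enumerate(batches):
    # Skip the gradient AllReduce on all but the last micro-batch
    ctx = ddp_model.no_sync() if i < accum_steps - 1 else nullcontext()
    with ctx:
        loss = ddp_model(batch).sum() / accum_steps
        loss.backward()
optimizer.step()

dist.destroy_process_group()
```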