PyTorch provides powerful distributed training capabilities to scale your models across multiple GPUs and machines. This guide covers DistributedDataParallel (DDP) and FullyShardedDataParallel (FSDP) for efficient multi-GPU training.

DistributedDataParallel (DDP)

DDP is the recommended approach for multi-GPU training in PyTorch. It replicates your model across multiple processes, with each process owning a dedicated GPU.

Basic DDP Setup

Step 1: Initialize Process Group

First, initialize the distributed process group. This establishes communication between processes.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize the process group. With init_method='env://', the rank and
# world size are read from the RANK and WORLD_SIZE environment variables,
# which torchrun sets for each process.
dist.init_process_group(
    backend='nccl',  # Use 'nccl' for GPU, 'gloo' for CPU
    init_method='env://'  # Read rank and world size from environment variables
)
The nccl backend is optimized for GPU communication and provides the best performance for distributed GPU training.
Step 2: Wrap Your Model

Wrap your model with DDP to enable distributed training.
import os

# torchrun sets LOCAL_RANK to this process's GPU index on its node
local_rank = int(os.environ['LOCAL_RANK'])
device = torch.device(f'cuda:{local_rank}')

# Create model and move to this process's GPU
model = MyModel().to(device)

# Wrap with DDP
ddp_model = DDP(
    model,
    device_ids=[local_rank],
    output_device=local_rank,
    broadcast_buffers=True,
    find_unused_parameters=False
)
Step 3: Train with DDP

The training loop remains similar to single-GPU training.
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

DDP Configuration Options

DDP provides several configuration options to optimize performance:
# Control gradient bucketing for AllReduce operations
ddp_model = DDP(
    model,
    bucket_cap_mb=25,  # Default bucket size in MB
    broadcast_buffers=True,
    find_unused_parameters=False
)

Launch Distributed Training

Use torchrun (which supersedes the deprecated torch.distributed.launch) to launch multiple processes:
# Launch on single node with 4 GPUs
torchrun --nproc_per_node=4 train.py

# Launch on multiple nodes
torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --master_addr="192.168.1.1" \
    --master_port=29500 \
    train.py
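Inside train.py, the process topology torchrun configures is exposed through environment variables. A small sketch of reading them, with single-process fallbacks so the script also runs standalone (dist_info is an illustrative helper, not a PyTorch API):

```python
import os


def dist_info() -> tuple:
    """Return (rank, world_size, local_rank) as set by torchrun.

    Falls back to single-process defaults when the variables are absent,
    so the same script runs with or without a launcher.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return rank, world_size, local_rank
```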

FullyShardedDataParallel (FSDP)

FSDP shards model parameters, gradients, and optimizer states across GPUs, enabling training of models that don’t fit on a single GPU.

FSDP Setup

import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Define mixed precision policy
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16
)

# Auto-wrap policy for nested modules
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=10000  # Wrap modules with 10K+ params
)

# Wrap model with FSDP
fsdp_model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device()
)

FSDP Sharding Strategies

FULL_SHARD: Shards parameters, gradients, and optimizer states. Provides maximum memory savings.
sharding_strategy=ShardingStrategy.FULL_SHARD
Memory: Lowest | Speed: Moderate

SHARD_GRAD_OP: Shards gradients and optimizer states, keeps parameters replicated.
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
Memory: Moderate | Speed: Faster than FULL_SHARD

NO_SHARD: No sharding, equivalent to DDP. Use when you don’t need memory savings.
sharding_strategy=ShardingStrategy.NO_SHARD
Memory: Highest | Speed: Fastest

HYBRID_SHARD: Shards within a node, replicates across nodes. Best for multi-node training.
sharding_strategy=ShardingStrategy.HYBRID_SHARD
Memory: Balanced | Speed: Optimized for multi-node

FSDP2 (New API)

PyTorch now offers a simplified FSDP2 API:
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

# Apply FSDP to specific layers
for layer in model.layers:
    fully_shard(layer)

# Apply to entire model
fully_shard(model)

# With mixed precision
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32
)
fully_shard(model, mp_policy=mp_policy)
FSDP2 provides a more composable API and is recommended for new projects. It allows per-layer sharding control.

Distributed Collectives

PyTorch provides communication primitives for distributed operations:

All-Reduce

import torch.distributed as dist

# Sum tensors across all processes
tensor = torch.tensor([1.0, 2.0, 3.0]).cuda()
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
# Result: each process has the sum of all tensors

Broadcast

# Broadcast tensor from rank 0 to all processes
tensor = torch.randn(3, 4).cuda()
dist.broadcast(tensor, src=0)

Gather and Scatter

# Gather tensors from all processes
local_tensor = torch.randn(2, 3).cuda()
gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
dist.all_gather(gathered, local_tensor)

Best Practices

Common Pitfalls to Avoid:
  • Don’t use find_unused_parameters=True unless necessary (it adds overhead)
  • Always use DistributedSampler to partition data across processes
  • Synchronize before validation/testing with dist.barrier()
  • Use gradient accumulation carefully with DDP (gradients auto-sync on backward)

Checkpoint Saving

# Only save from rank 0
if dist.get_rank() == 0:
    torch.save({
        'epoch': epoch,
        'model_state_dict': ddp_model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pt')

Data Loading

from torch.utils.data.distributed import DistributedSampler

# Create distributed sampler
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True
)

# Use sampler in DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=4,
    pin_memory=True
)

# Set epoch for proper shuffling
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in dataloader:
        # training loop
        pass
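The partitioning itself can be inspected without any process group, since DistributedSampler accepts num_replicas and rank explicitly. A small sketch simulating two ranks on a toy dataset:

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# A toy dataset of 100 samples, partitioned across 2 simulated ranks
dataset = TensorDataset(torch.arange(100))
samplers = [
    DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=True)
    for r in range(2)
]
for s in samplers:
    s.set_epoch(0)  # seeds the shuffle identically on every rank

indices = [list(s) for s in samplers]
# Each rank sees exactly half the data; together they cover it all once
```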

Performance Optimization

Gradient Compression

from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Apply FP16 compression to gradients
ddp_model.register_comm_hook(
    state=None,
    hook=default_hooks.fp16_compress_hook
)

Overlapping Computation and Communication

# DDP automatically overlaps gradient AllReduce with backward pass
# You can tune bucket size for better overlap
ddp_model = DDP(
    model,
    bucket_cap_mb=25,  # Smaller = more overlap, more overhead
    gradient_as_bucket_view=True
)