Distributed Training

The torch.distributed package supports distributed training across multiple processes and machines. It provides communication primitives and high-level APIs for data parallelism and model parallelism.

Initialization

init_process_group

torch.distributed.init_process_group(
    backend,
    init_method=None,
    timeout=default_pg_timeout,
    world_size=-1,
    rank=-1,
    store=None,
    group_name='',
    pg_options=None
)
Initializes the default distributed process group.
backend
str
The backend to use. Valid values include 'nccl', 'gloo', 'mpi'.
init_method
str
default:"None"
URL specifying how to initialize the process group. If not specified, 'env://' is used.
world_size
int
default:"-1"
Number of processes participating in the job. Required if store is specified.
rank
int
default:"-1"
Rank of the current process. Required if store is specified.
timeout
timedelta
default:"default_pg_timeout"
Timeout for operations executed against the process group.
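
A minimal single-process initialization sketch using the gloo backend. With 'env://', init_process_group reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment (launchers such as torchrun set these for you); here they are set by hand purely for illustration, and the port number is arbitrary.

```python
import os
import torch.distributed as dist

# Stand in for what a launcher like torchrun would normally export.
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29510'

# world_size=1 lets this sketch run as a single process.
dist.init_process_group(backend='gloo', init_method='env://',
                        world_size=1, rank=0)

assert dist.is_initialized()
dist.destroy_process_group()
```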

is_initialized

torch.distributed.is_initialized()
Checks if the default process group has been initialized.
Returns
bool
True if the default process group has been initialized, False otherwise.

get_rank

torch.distributed.get_rank(group=None)
Returns the rank of the current process in the provided group.
group
ProcessGroup
default:"None"
The process group to work on. If None, the default process group is used.
Returns
int
The rank of the current process within the group, or -1 if the process is not part of the group.

get_world_size

torch.distributed.get_world_size(group=None)
Returns the number of processes in the current process group.
group
ProcessGroup
default:"None"
The process group to work on.
Returns
int
The world size of the process group.
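
These query helpers are commonly combined to guard rank-specific work. A small sketch with two illustrative helpers (the function names are not part of the API):

```python
import torch.distributed as dist

def log_on_main(message):
    # Rank-conditional logging: only rank 0 prints, so multi-process
    # runs do not emit duplicate lines. Falls back to plain printing
    # when no process group has been initialized.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(message)

def shard_bounds(n_items):
    # Each rank computes its own [start, stop) slice of the data from
    # get_rank() and get_world_size().
    world = dist.get_world_size()
    rank = dist.get_rank()
    per_rank = (n_items + world - 1) // world  # ceil division
    start = rank * per_rank
    return start, min(start + per_rank, n_items)
```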

Communication Primitives

send

torch.distributed.send(tensor, dst, group=None, tag=0)
Sends a tensor synchronously.
tensor
Tensor
Tensor to send.
dst
int
Destination rank.
group
ProcessGroup
default:"None"
The process group to work on.

recv

torch.distributed.recv(tensor, src=None, group=None, tag=0)
Receives a tensor synchronously.
tensor
Tensor
Tensor to fill with received data.
src
int
default:"None"
Source rank. If None, will receive from any process.
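
A two-process ping-pong sketch over the CPU-friendly gloo backend; the address, port, and function name are illustrative.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def ping_pong(rank, world_size):
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29511',
                            world_size=world_size, rank=rank)
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1.0
        dist.send(tensor, dst=1)   # blocks until rank 1 has received
    else:
        dist.recv(tensor, src=0)   # blocks until the data has arrived
        assert tensor.item() == 1.0
    dist.destroy_process_group()

if __name__ == '__main__':
    # mp.spawn passes the rank as the first argument automatically.
    mp.spawn(ping_pong, args=(2,), nprocs=2)
```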

broadcast

torch.distributed.broadcast(
    tensor,
    src,
    group=None,
    async_op=False
)
Broadcasts the tensor to the whole group.
tensor
Tensor
Data to be sent if src is the rank of current process, and tensor to be used to save received data otherwise.
src
int
Source rank.
async_op
bool
default:"False"
Whether this op should be an async op.
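
A small sketch of the usual broadcast pattern (the helper name is illustrative): before the call only the source rank needs meaningful contents; afterwards every rank holds the source's data.

```python
import torch
import torch.distributed as dist

def sync_from_source(tensor, src=0):
    # A common use is synchronizing randomly initialized weights or a
    # config tensor at startup so all ranks agree on the same values.
    dist.broadcast(tensor, src=src)
    return tensor
```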

all_reduce

torch.distributed.all_reduce(
    tensor,
    op=ReduceOp.SUM,
    group=None,
    async_op=False
)
Reduces the tensor data across all machines in such a way that all get the final result.
tensor
Tensor
Input and output of the collective. The function operates in-place.
op
ReduceOp
default:"ReduceOp.SUM"
One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.
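
With async_op=True the call returns a work handle immediately instead of blocking, which allows communication to overlap with independent computation. A sketch (the function name is illustrative):

```python
import torch
import torch.distributed as dist

def overlapped_all_reduce(tensor):
    # Returns immediately; the reduction runs in the background.
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    # ... other, communication-independent work could run here ...
    work.wait()   # block until the sum has completed on this rank
    return tensor  # now holds the element-wise sum across all ranks
```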

reduce

torch.distributed.reduce(
    tensor,
    dst,
    op=ReduceOp.SUM,
    group=None,
    async_op=False
)
Reduces the tensor data across all machines; only the process with rank dst receives the final result.
tensor
Tensor
Input and output of the collective.
dst
int
Destination rank.
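
A sketch contrasting reduce with all_reduce (the helper name is illustrative): only the destination rank is guaranteed to hold the reduced value afterwards.

```python
import torch
import torch.distributed as dist

def sum_to_root(tensor, dst=0):
    # Unlike all_reduce, only rank `dst` ends up with the sum; the
    # buffers on other ranks are unspecified after the call.
    dist.reduce(tensor, dst=dst, op=dist.ReduceOp.SUM)
    return tensor
```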

all_gather

torch.distributed.all_gather(
    tensor_list,
    tensor,
    group=None,
    async_op=False
)
Gathers tensors from the whole group in a list.
tensor_list
list[Tensor]
Output list. It should contain correctly-sized tensors to be used for output of the collective.
tensor
Tensor
Tensor to be broadcast from current process.
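
A sketch of the common "every rank learns everyone's value" pattern (the helper name is illustrative):

```python
import torch
import torch.distributed as dist

def gather_ranks():
    # Each rank contributes a one-element tensor holding its own rank;
    # afterwards every rank holds the full list [0, 1, ..., world_size-1].
    world_size = dist.get_world_size()
    mine = torch.tensor([float(dist.get_rank())])
    outputs = [torch.zeros(1) for _ in range(world_size)]
    dist.all_gather(outputs, mine)
    return outputs
```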

gather

torch.distributed.gather(
    tensor,
    gather_list=None,
    dst=0,
    group=None,
    async_op=False
)
Gathers a list of tensors in a single process.
tensor
Tensor
Input tensor.
gather_list
list[Tensor]
default:"None"
List of appropriately-sized tensors to use for received data (required in the receiving process).
dst
int
default:"0"
Destination rank.

scatter

torch.distributed.scatter(
    tensor,
    scatter_list=None,
    src=0,
    group=None,
    async_op=False
)
Scatters a list of tensors to all processes in a group; each process receives exactly one tensor from scatter_list.
tensor
Tensor
Output tensor.
scatter_list
list[Tensor]
default:"None"
List of tensors to scatter (required in the source process).
src
int
default:"0"
Source rank.
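
scatter is the inverse of gather: the source rank splits data into world_size pieces and each rank receives one. A sketch distributing equal shards of a batch (helper name and shard shape are illustrative assumptions):

```python
import torch
import torch.distributed as dist

def scatter_shards(full_batch, shard_shape, src=0):
    # Non-source ranks pass scatter_list=None (so full_batch may be
    # None there) and simply receive into their output tensor.
    shard = torch.zeros(shard_shape)
    if dist.get_rank() == src:
        shards = [c.contiguous()
                  for c in full_batch.chunk(dist.get_world_size())]
        dist.scatter(shard, scatter_list=shards, src=src)
    else:
        dist.scatter(shard, src=src)
    return shard
```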

barrier

torch.distributed.barrier(group=None, async_op=False, device_ids=None)
Synchronizes all processes. This collective blocks processes until the whole group enters this function.
group
ProcessGroup
default:"None"
The process group to work on.
async_op
bool
default:"False"
Whether this op should be an async op.
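
A common barrier pattern, sketched with a hypothetical save callback: rank 0 writes a checkpoint while the other ranks wait, so no rank reads a half-written file.

```python
import torch.distributed as dist

def save_then_sync(rank, save_fn):
    # save_fn is a placeholder for whatever checkpointing routine the
    # application uses; only rank 0 runs it.
    if rank == 0:
        save_fn()
    dist.barrier()  # all ranks block here until everyone has arrived
```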

High-Level APIs

DistributedDataParallel

torch.nn.parallel.DistributedDataParallel(
    module,
    device_ids=None,
    output_device=None,
    dim=0,
    broadcast_buffers=True,
    process_group=None,
    bucket_cap_mb=25,
    find_unused_parameters=False,
    check_reduction=False,
    gradient_as_bucket_view=False,
    static_graph=False
)
Implements distributed data parallelism at the module level.
module
Module
Module to be parallelized.
device_ids
list[int or torch.device]
default:"None"
CUDA devices for single-device modules.
find_unused_parameters
bool
default:"False"
Whether to find unused parameters. Useful for models with conditional execution.
gradient_as_bucket_view
bool
default:"False"
Whether gradients should be views into the DDP bucket.
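
A minimal wrapping sketch, assuming init_process_group has already been called in this process. On the gloo backend the model can stay on CPU; with nccl you would move it to this rank's GPU first and pass device_ids=[rank].

```python
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def build_ddp(model):
    # DDP broadcasts rank 0's parameters at construction and then
    # all-reduces gradients automatically during backward(), so the
    # training loop looks exactly like single-process training.
    return DDP(model)
```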

Distributed Samplers

DistributedSampler

torch.utils.data.distributed.DistributedSampler(
    dataset,
    num_replicas=None,
    rank=None,
    shuffle=True,
    seed=0,
    drop_last=False
)
Sampler that restricts data loading to a subset of the dataset. Especially useful in conjunction with DistributedDataParallel.
dataset
Dataset
Dataset to be sampled.
num_replicas
int
default:"None"
Number of processes participating in distributed training.
rank
int
default:"None"
Rank of the current process within num_replicas.
shuffle
bool
default:"True"
If True, sampler will shuffle the indices.
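
A typical pairing with DataLoader, sketched with an illustrative helper: pass the sampler via sampler= instead of shuffle=True, and call set_epoch every epoch so shuffling uses a different ordering each time.

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_loader(dataset, epoch, batch_size=4):
    # With num_replicas/rank left as None, the sampler reads them from
    # the default process group, so a group must be initialized first.
    sampler = DistributedSampler(dataset, shuffle=True)
    sampler.set_epoch(epoch)  # reseeds the shuffle for this epoch
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```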

Example Usage

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size):
    """Initialize the distributed environment."""
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:29500',
        world_size=world_size,
        rank=rank
    )

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    setup(rank, world_size)

    # nccl drives one GPU per process; pin this process to its device
    # before allocating CUDA tensors.
    torch.cuda.set_device(rank)
    tensor = torch.ones(2, 2).to(rank)
    
    # All-reduce operation
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f'Rank {rank}: {tensor}')
    
    cleanup()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(demo_basic, args=(world_size,), nprocs=world_size)

Backends

NCCL

  • Recommended for NVIDIA GPUs
  • Best performance for GPU-to-GPU communication
  • Supports collective operations on CUDA tensors

Gloo

  • Supports both CPU and GPU
  • Good for CPU-based distributed training
  • Cross-platform support (Linux, macOS, Windows)

MPI

  • Requires MPI implementation (e.g., OpenMPI)
  • Good for HPC environments
  • Supports both CPU and GPU