Overview

This guide covers the essential components of training neural networks in PyTorch: loss functions, optimizers, and the training loop.

Training Loop Structure

A typical PyTorch training loop follows this pattern:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Setup
model = YourModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
num_epochs = 10

# Training loop
model.train()  # Set to training mode
for epoch in range(num_epochs):
    running_loss = 0.0
    
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass
        loss.backward()
        
        # Update weights
        optimizer.step()
        
        running_loss += loss.item()
    
    avg_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

The key steps in each iteration are:

1. Zero gradients: call optimizer.zero_grad() to clear gradients from the previous iteration.
2. Forward pass: compute model predictions and calculate the loss.
3. Backward pass: call loss.backward() to compute gradients via backpropagation.
4. Update parameters: call optimizer.step() to update the model weights.

Loss Functions

PyTorch provides loss functions in torch.nn for various tasks.

Classification Losses

# For multi-class classification
# Combines LogSoftmax and NLLLoss
criterion = nn.CrossEntropyLoss()

# Model output: raw logits [batch_size, num_classes]
outputs = model(inputs)  # Shape: [32, 10]
targets = torch.randint(0, 10, (32,))  # Class indices

loss = criterion(outputs, targets)
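
For binary classification, a sigmoid-based loss is the usual choice. A minimal sketch, assuming the model produces a single logit per example:
# Combines a sigmoid with binary cross-entropy in a numerically stable way
criterion = nn.BCEWithLogitsLoss()

outputs = model(inputs)  # Shape: [32, 1], raw logits
targets = torch.randint(0, 2, (32, 1)).float()  # Float labels (0.0 or 1.0)

loss = criterion(outputs, targets)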

Regression Losses

# Mean Squared Error Loss
criterion = nn.MSELoss()

outputs = model(inputs)  # Predicted values
targets = torch.randn(32, 1)  # Ground truth

loss = criterion(outputs, targets)
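
If the targets contain outliers, an L1-style loss is often more robust than MSE; for example, reusing the tensors above:
# Mean Absolute Error and Smooth L1 (Huber-style) losses
l1_loss = nn.L1Loss()
smooth_l1_loss = nn.SmoothL1Loss()

loss = smooth_l1_loss(outputs, targets)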

Advanced Losses

# Negative Log Likelihood Loss
nll_loss = nn.NLLLoss()

# KL Divergence Loss (for distributions)
kl_loss = nn.KLDivLoss(reduction='batchmean')

# Cosine Embedding Loss
cos_loss = nn.CosineEmbeddingLoss()

# Triplet Margin Loss (for metric learning)
triplet_loss = nn.TripletMarginLoss(margin=1.0)
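
NLLLoss expects log-probabilities rather than raw logits, which is what the "Combines LogSoftmax and NLLLoss" comment above refers to. A short sketch of the equivalence, reusing the classification tensors from earlier:
import torch.nn.functional as F

logits = model(inputs)  # Raw logits: [32, 10]
log_probs = F.log_softmax(logits, dim=1)  # NLLLoss needs log-probabilities

# These two losses are equivalent
loss_nll = nll_loss(log_probs, targets)
loss_ce = nn.CrossEntropyLoss()(logits, targets)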

Complete Training Example

Here’s a complete training example with validation:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class Trainer:
    def __init__(self, model, train_loader, val_loader, device='cuda'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        
        # Loss and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        # Metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []
    
    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        
        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            
            # Zero gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            
            # Backward pass and optimize
            loss.backward()
            self.optimizer.step()
            
            running_loss += loss.item()
        
        avg_loss = running_loss / len(self.train_loader)
        return avg_loss
    
    def validate(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, targets in self.val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                
                running_loss += loss.item()
                
                # Calculate accuracy
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        avg_loss = running_loss / len(self.val_loader)
        accuracy = 100. * correct / total
        
        return avg_loss, accuracy
    
    def train(self, num_epochs):
        for epoch in range(num_epochs):
            # Train
            train_loss = self.train_epoch()
            self.train_losses.append(train_loss)
            
            # Validate
            val_loss, val_acc = self.validate()
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_acc)
            
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Train Loss: {train_loss:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print('-' * 50)

# Usage
trainer = Trainer(model, train_loader, val_loader)
trainer.train(num_epochs=10)

Mixed Precision Training

Use automatic mixed precision (AMP) for faster training and lower memory use:
from torch.cuda.amp import autocast, GradScaler

model = YourModel().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scaler = GradScaler()

for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        inputs, targets = inputs.cuda(), targets.cuda()
        
        optimizer.zero_grad()
        
        # Forward pass with autocast
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
Mixed precision training can provide a 2-3x speedup on modern GPUs with Tensor Cores (V100, A100, RTX series), typically with minimal impact on accuracy.
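
Newer PyTorch releases also expose these utilities under torch.amp; if you see deprecation warnings for torch.cuda.amp, the autocast context can be written as follows (a sketch assuming a recent PyTorch version):
# Newer-style autocast; functionally the same as the example above
with torch.amp.autocast(device_type='cuda'):
    outputs = model(inputs)
    loss = criterion(outputs, targets)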

Gradient Clipping

Prevent exploding gradients by clipping them:
max_grad_norm = 1.0

for inputs, targets in train_loader:
    optimizer.zero_grad()
    
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    
    optimizer.step()
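
Norm-based clipping rescales the whole gradient vector when its norm exceeds the threshold. PyTorch also offers element-wise value clipping, shown here with an arbitrary example threshold:
# Clip each gradient element to the range [-0.5, 0.5]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)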

Gradient Accumulation

Simulate larger batch sizes when GPU memory is limited:
accumulation_steps = 4  # Effective batch size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(train_loader):
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    
    # Normalize loss by accumulation steps
    loss = loss / accumulation_steps
    loss.backward()
    
    # Update weights every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
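
One edge case worth handling: if the number of batches is not divisible by accumulation_steps, the gradients from the final partial accumulation are never applied. A simple fix is to flush them right after the loop:
# After the loop: apply any leftover accumulated gradients
if (i + 1) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()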

Learning Rate Warmup

Gradually increase learning rate at the start of training:
def warmup_lr_scheduler(optimizer, warmup_epochs):
    # Linearly scale the optimizer's learning rate from 1/warmup_epochs
    # up to its full value over the warmup period, then hold it constant
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 1.0
    
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Usage
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = warmup_lr_scheduler(optimizer, warmup_epochs=5)

for epoch in range(num_epochs):
    # Training loop
    train_epoch()
    
    # Step scheduler
    scheduler.step()
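
Warmup is usually followed by a decay schedule. One way to chain the two, sketched here with a cosine decay after the warmup phase using torch.optim.lr_scheduler.SequentialLR:
warmup_epochs = 5
warmup = optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: (epoch + 1) / warmup_epochs)
# Cosine decay over the remaining epochs (assumes num_epochs > warmup_epochs)
cosine = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_epochs - warmup_epochs)

scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])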

Monitoring Training

Using TensorBoard

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/experiment_1')

for epoch in range(num_epochs):
    # Training
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        # Log training loss
        global_step = epoch * len(train_loader) + i
        writer.add_scalar('Loss/train', loss.item(), global_step)
    
    # Validation
    val_loss, val_acc = validate()
    writer.add_scalar('Loss/val', val_loss, epoch)
    writer.add_scalar('Accuracy/val', val_acc, epoch)
    
    # Log learning rate
    writer.add_scalar('Learning_rate', optimizer.param_groups[0]['lr'], epoch)

writer.close()
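
View the logged curves by launching TensorBoard against the log directory:
tensorboard --logdir=runs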

Early Stopping

class EarlyStopping:
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
    
    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

# Usage
early_stopping = EarlyStopping(patience=5)

for epoch in range(num_epochs):
    train_loss = train_epoch()
    val_loss, val_acc = validate()
    
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break
Always validate your model on a separate validation set. Don’t use training loss alone to evaluate model performance.

Best Practices

  • Always call model.train() before training and model.eval() before validation
  • Use torch.no_grad() during validation to save memory
  • Monitor both training and validation metrics to detect overfitting
  • Save checkpoints regularly (a minimal checkpointing sketch follows this list)
  • Use gradient clipping for RNNs and deep networks
  • Start with a simple baseline before adding complexity
  • Batch size: start with 32, 64, or 128 (larger batches train faster but need more memory)
  • Learning rate: 1e-3 or 1e-4 is a common starting point (or use a learning rate finder)
  • Number of epochs: determine by monitoring validation loss
  • Optimizer: Adam is a good default; SGD with momentum is common for fine-tuning
  • Weight decay: 1e-4 or 1e-5 for regularization
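
A minimal checkpointing sketch, reusing the model and optimizer names from the earlier examples:
# Save a training checkpoint (weights, optimizer state, progress)
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_loss': val_loss,
}
torch.save(checkpoint, 'checkpoint.pth')

# Restore it later to resume training or evaluate
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1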

Next Steps