Skip to main content

Overview

PyTorch provides flexible methods to save and load models. This guide covers best practices for saving model weights, complete checkpoints, and managing model persistence.

Basic Model Saving

Save and Load Model Weights

import torch
import torch.nn as nn

# Define model
model = YourModel()

# Train model
train(model)

# Save model weights
# state_dict() is a dict of parameter/buffer name -> tensor; saving it
# (rather than the whole pickled module) is the recommended, portable format.
torch.save(model.state_dict(), 'model_weights.pth')

# Load model weights
model = YourModel()  # Create model instance first
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()  # Set to evaluation mode
1. Save state dictionary — use model.state_dict() to get a dictionary of all parameters.

2. Load state dictionary — create a model instance with the same architecture, then load the weights.

3. Set evaluation mode — call model.eval() for inference.
Important: Always instantiate your model with the same architecture before loading weights. The model class definition must match the saved weights.
# Save entire model (architecture + weights)
# This pickles the full module object, not just its tensors.
torch.save(model, 'model_complete.pth')

# Load entire model
# NOTE(review): the original class definition must be importable at load time;
# on PyTorch >= 2.6 this also requires weights_only=False — verify your version.
model = torch.load('model_complete.pth')
model.eval()
Saving the entire model is not recommended because:
  • It’s less flexible (the pickled file is tied to the exact class and module structure present at save time)
  • May break with PyTorch version changes
  • Harder to debug serialization issues
  • Larger file size
Always prefer saving state_dict() instead.

Complete Checkpoint Saving

Save everything needed to resume training:
import torch

# Save checkpoint
# Bundle everything needed to resume training exactly where it left off.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'val_accuracy': val_accuracy,
    'scheduler_state_dict': scheduler.state_dict(),
}

torch.save(checkpoint, 'checkpoint.pth')
# Load checkpoint
checkpoint = torch.load('checkpoint.pth')

# Recreate model / optimizer / scheduler with the same configuration,
# then restore each one's saved state.
model = YourModel()
model.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# Resume from the epoch after the one that was saved.
start_epoch = checkpoint['epoch'] + 1
best_loss = checkpoint['loss']

# Resume training
model.train()

Advanced Checkpoint Management

Checkpoint Manager Class

import os
import torch
from pathlib import Path

class CheckpointManager:
    """Rotating checkpoint store for a model/optimizer pair.

    Keeps at most ``max_checkpoints`` per-epoch checkpoints on disk (oldest
    are deleted first) plus a separate ``best_model.pth`` that is never
    rotated away.
    """

    def __init__(self, model, optimizer, save_dir, max_checkpoints=5):
        self.model = model
        self.optimizer = optimizer
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        # Checkpoint paths known to this manager, oldest first.
        self.checkpoints = []

    def save(self, epoch, metrics, is_best=False):
        """Save a checkpoint for `epoch`; update best_model.pth when is_best."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics,
        }

        # Save regular checkpoint
        checkpoint_path = self.save_dir / f'checkpoint_epoch_{epoch}.pth'
        torch.save(checkpoint, checkpoint_path)
        self.checkpoints.append(checkpoint_path)

        # Save best model (kept outside the rotation window)
        if is_best:
            best_path = self.save_dir / 'best_model.pth'
            torch.save(checkpoint, best_path)
            print(f"Saved best model at epoch {epoch}")

        # Rotate: drop oldest until within the retention limit. A while-loop
        # (not a single pop) also recovers when max_checkpoints was lowered
        # or when old files were re-discovered by load_latest() after resume.
        while len(self.checkpoints) > self.max_checkpoints:
            old_checkpoint = self.checkpoints.pop(0)
            if old_checkpoint.exists():
                old_checkpoint.unlink()

    def load_latest(self, map_location=None):
        """Load the most recent checkpoint, or return None if there is none.

        `map_location` is forwarded to torch.load (e.g. 'cpu') for
        cross-device loading; None keeps torch defaults.
        """
        if not self.checkpoints:
            # Fresh process: rebuild tracking from the files on disk so that
            # rotation keeps working after a resume (bug fix: the original
            # scanned the directory but never re-tracked the files it found).
            self.checkpoints = sorted(
                self.save_dir.glob('checkpoint_epoch_*.pth'),
                key=lambda x: int(x.stem.split('_')[-1])
            )
            if not self.checkpoints:
                return None

        return self.load(self.checkpoints[-1], map_location=map_location)

    def load_best(self, map_location=None):
        """Load best_model.pth, or return None if it was never saved."""
        best_path = self.save_dir / 'best_model.pth'
        if best_path.exists():
            return self.load(best_path, map_location=map_location)
        return None

    def load(self, checkpoint_path, map_location=None):
        """Load a checkpoint file and restore model + optimizer state.

        Returns the full checkpoint dict so callers can read epoch/metrics.
        """
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint

# Usage
manager = CheckpointManager(model, optimizer, 'checkpoints', max_checkpoints=3)

# Bug fix: best_val_loss must be initialized before the first comparison
# below, and updated whenever a new best is observed.
best_val_loss = float('inf')

for epoch in range(num_epochs):
    train_loss = train_epoch()
    val_loss, val_acc = validate()
    
    metrics = {
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_accuracy': val_acc
    }
    
    # Save checkpoint; flag it as best when validation loss improves
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
    manager.save(epoch, metrics, is_best=is_best)

# Resume training
checkpoint = manager.load_latest()
if checkpoint:
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")

Save Best Model During Training

import torch
import numpy as np

# Early stopping: stop once validation loss has not improved for `patience`
# consecutive epochs; the best model is checkpointed every time it improves.
best_val_loss = float('inf')
patience = 10
patience_counter = 0

for epoch in range(num_epochs):
    # Training
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    
    # Validation
    val_loss, val_acc = validate(model, val_loader, criterion)
    
    print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, '
          f'Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%')
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0  # improvement resets the patience window
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_accuracy': val_acc,
        }, 'best_model.pth')
        
        print(f'Saved best model with val_loss={val_loss:.4f}')
    else:
        patience_counter += 1
    
    # Early stopping
    if patience_counter >= patience:
        print(f'Early stopping at epoch {epoch+1}')
        break

Device Management

Save on GPU, Load on CPU

# Save model trained on GPU
model = model.cuda()
torch.save(model.state_dict(), 'model.pth')  # tensors are saved with their device tag

# Load on CPU
model = YourModel()
# map_location remaps the saved CUDA tensors onto the CPU at load time
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.eval()

Save on CPU, Load on GPU

# Save model
torch.save(model.state_dict(), 'model.pth')

# Load on specific GPU
device = torch.device('cuda:0')
model = YourModel()
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)  # also move the module itself; load_state_dict alone does not relocate it

General Device Handling

# Save
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
torch.save(model.state_dict(), 'model.pth')

# Load with automatic device selection
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YourModel()
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)
model.eval()  # switch to inference mode (affects dropout / batch-norm)
Always use map_location when loading models to ensure they load correctly regardless of the device they were saved on.

DataParallel and DistributedDataParallel

Saving DataParallel Models

import torch.nn as nn

# Wrap model in DataParallel
model = nn.DataParallel(model)

# Save: remove 'module.' prefix
# model.module is the underlying (unwrapped) model, so its state_dict has
# clean parameter names and loads directly into a non-parallel model.
torch.save(model.module.state_dict(), 'model.pth')

# Or save the full DataParallel state
torch.save(model.state_dict(), 'model_parallel.pth')  # keys carry the 'module.' prefix

Loading into DataParallel

# Load weights saved from DataParallel model
model = YourModel()
state_dict = torch.load('model_parallel.pth')

# Remove 'module.' prefix from keys
# DataParallel wraps the model, so every parameter key is stored as
# 'module.<name>' and would not match a plain (unwrapped) model.
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k  # remove 'module.'
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

Partial Loading

Transfer Learning

# Load pretrained weights (may not match exactly)
model = YourModel()
pretrained_dict = torch.load('pretrained.pth')
model_dict = model.state_dict()

# Filter out unnecessary keys
# Keep only weights whose name AND shape match the new model, so partially
# different architectures (e.g. a replaced classification head) still load.
pretrained_dict = {
    k: v for k, v in pretrained_dict.items() 
    if k in model_dict and v.shape == model_dict[k].shape
}

# Update current model dict
model_dict.update(pretrained_dict)

# Load updated weights
model.load_state_dict(model_dict)

Strict Loading Control

# Load with mismatched keys: strict=False skips keys that do not line up.
# Load the file once and reuse it (the original deserialized it twice).
state_dict = torch.load('model.pth')

# load_state_dict returns an IncompatibleKeys named tuple: keys the model
# expected but the file lacked, and keys the file had that the model doesn't.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

if missing_keys:
    print(f"Missing keys: {missing_keys}")
if unexpected_keys:
    print(f"Unexpected keys: {unexpected_keys}")

Saving for Production

TorchScript Export

import torch

# Method 1: Tracing
# Tracing records the ops executed for one example input; data-dependent
# control flow is NOT captured (use scripting for that).
model = YourModel()
model.eval()

example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

# Save traced model
traced_model.save('model_traced.pt')

# Load and use
loaded_model = torch.jit.load('model_traced.pt')
loaded_model.eval()
with torch.no_grad():  # disable autograd bookkeeping for inference
    output = loaded_model(example_input)
# Method 2: Scripting (for models with control flow)
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')

ONNX Export

import torch.onnx

model = YourModel()
model.eval()  # export in inference mode

# NOTE(review): assumes an image model taking (N, 3, 224, 224) input — adjust for your model
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    export_params=True,  # store the trained weights inside the file
    opset_version=11,  # ONNX operator-set version to target
    do_constant_folding=True,  # pre-compute constant subgraphs
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={  # allow a variable batch dimension at inference time
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

Cloud Storage Integration

Save to S3 (AWS)

import torch
import boto3
import io

def save_to_s3(model, bucket, key):
    """Serialize model.state_dict() into memory and upload it to S3.

    Args:
        model: torch.nn.Module whose weights are saved.
        bucket: S3 bucket name.
        key: object key under which the weights are stored.
    """
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    
    s3 = boto3.client('s3')
    s3.put_object(Bucket=bucket, Key=key, Body=buffer.getvalue())  # getvalue() ignores stream position

def load_from_s3(bucket, key, map_location=None):
    """Download a state_dict from S3 and deserialize it in memory.

    Args:
        bucket: S3 bucket name.
        key: object key written by save_to_s3.
        map_location: forwarded to torch.load (e.g. 'cpu') so weights saved
            on GPU can be loaded on CPU-only hosts; None keeps torch defaults.

    Returns:
        The deserialized state_dict.
    """
    s3 = boto3.client('s3')
    response = s3.get_object(Bucket=bucket, Key=key)
    
    buffer = io.BytesIO(response['Body'].read())
    return torch.load(buffer, map_location=map_location)

# Usage
save_to_s3(model, 'my-bucket', 'models/model.pth')
state_dict = load_from_s3('my-bucket', 'models/model.pth')
model.load_state_dict(state_dict)

Save to Google Cloud Storage

from google.cloud import storage
import torch
import io

def save_to_gcs(model, bucket_name, blob_name):
    """Serialize model.state_dict() into memory and upload it to GCS.

    Args:
        model: torch.nn.Module whose weights are saved.
        bucket_name: target GCS bucket.
        blob_name: object name under which the weights are stored.
    """
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)  # rewind so the upload reads from the start
    
    blob.upload_from_file(buffer)

def load_from_gcs(bucket_name, blob_name, map_location=None):
    """Download a checkpoint from GCS and deserialize it in memory.

    Args:
        bucket_name: source GCS bucket.
        blob_name: object name written by save_to_gcs.
        map_location: forwarded to torch.load (e.g. 'cpu') so GPU-saved
            weights can be loaded on CPU-only machines; None keeps defaults.

    Returns:
        Whatever was saved (here, a state_dict).
    """
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    
    buffer = io.BytesIO()
    blob.download_to_file(buffer)
    buffer.seek(0)  # rewind before torch.load reads the stream
    
    return torch.load(buffer, map_location=map_location)

Model Versioning

import torch
from datetime import datetime
import json

class ModelVersion:
    """Pair a model's weights with descriptive metadata for versioned saves.

    The metadata dict is copied (the caller's dict is never mutated) and gets
    a 'saved_at' ISO-8601 timestamp only if it does not already carry one, so
    the original save time survives a load/save round-trip.
    """

    def __init__(self, model, metadata=None):
        self.model = model
        # Copy to avoid mutating the caller's dict (the original aliased it).
        self.metadata = dict(metadata) if metadata else {}
        # setdefault keeps an existing timestamp (e.g. when re-wrapping a
        # loaded checkpoint) instead of clobbering it with "now".
        self.metadata.setdefault('saved_at', datetime.now().isoformat())

    def save(self, path):
        """Save the checkpoint to `path` plus a sidecar *_metadata.json file.

        Accepts str or pathlib.Path. (Bug fix: the original called
        path.replace(...), which on a Path object is a *file rename*, and
        overwrote the checkpoint itself when the path had no '.pth' suffix.)
        """
        path = str(path)
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'metadata': self.metadata
        }
        torch.save(checkpoint, path)

        # Save metadata separately as JSON for inspection without torch.
        if path.endswith('.pth'):
            metadata_path = path[:-len('.pth')] + '_metadata.json'
        else:
            metadata_path = path + '_metadata.json'
        with open(metadata_path, 'w') as f:
            json.dump(self.metadata, f, indent=2)

    @classmethod
    def load(cls, path, model_class, map_location=None):
        """Load a checkpoint saved by `save` and rebuild the model.

        Args:
            path: checkpoint file path (str or Path).
            model_class: zero-argument callable returning a model with the
                same architecture as the saved one.
            map_location: forwarded to torch.load (e.g. 'cpu') for
                cross-device loading; None keeps torch defaults.
        """
        checkpoint = torch.load(path, map_location=map_location)

        model = model_class()
        model.load_state_dict(checkpoint['model_state_dict'])

        return cls(model, checkpoint.get('metadata', {}))

# Usage
# Record enough metadata to reproduce and compare this model later:
# architecture, data, headline metric, size, and training hyperparameters.
metadata = {
    'architecture': 'ResNet50',
    'dataset': 'ImageNet',
    'accuracy': 76.5,
    'parameters': sum(p.numel() for p in model.parameters()),
    'training_epochs': 100,
    'hyperparameters': {
        'lr': 0.1,
        'batch_size': 256,
        'optimizer': 'SGD'
    }
}

version = ModelVersion(model, metadata)
version.save('models/resnet50_v1.0.0.pth')

# Load
loaded_version = ModelVersion.load('models/resnet50_v1.0.0.pth', YourModel)
print(f"Model saved at: {loaded_version.metadata['saved_at']}")
print(f"Accuracy: {loaded_version.metadata['accuracy']}%")

Best Practices

  • Always save state_dict(), not the entire model
  • Include epoch, optimizer state, and metrics in checkpoints
  • Save best model based on validation metrics
  • Keep multiple checkpoints (last N epochs)
  • Use meaningful checkpoint names with timestamps or metrics
  • Save hyperparameters and metadata with models
  • Always use map_location for device compatibility
  • Set model to eval() mode after loading for inference
  • Handle DataParallel models by removing ‘module.’ prefix
  • Use strict=False carefully and check missing/unexpected keys
  • Verify model architecture matches saved weights
  • Checkpoints already use the compressed zipfile format by default since PyTorch 1.6; pass _use_new_zipfile_serialization=True explicitly only on older versions
  • Clean up old checkpoints to save disk space
  • Use version control for model artifacts (DVC, MLflow)
  • Store models in cloud storage for team access
  • Keep separate directories for experiments

Common Patterns

# Training with checkpointing
def train_with_checkpoints(
    model,
    train_loader,
    val_loader,
    num_epochs,
    checkpoint_dir='checkpoints'
):
    """Train `model`, writing checkpoints after every epoch.

    Three kinds of files are written under `checkpoint_dir`:
    - latest.pth      overwritten each epoch (for resuming)
    - best.pth        lowest validation loss seen so far
    - epoch_{N}.pth   periodic snapshot every 10th epoch
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    lowest_val_loss = float('inf')

    for epoch in range(num_epochs):
        # One pass over the training data, then evaluate.
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc = validate(model, val_loader, criterion)

        # Everything needed to resume: progress, state, and epoch metrics.
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
        }

        # Always refresh the rolling "latest" checkpoint.
        torch.save(state, f'{checkpoint_dir}/latest.pth')

        # Keep a copy of the best-so-far model.
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            torch.save(state, f'{checkpoint_dir}/best.pth')

        # Periodic snapshot every 10 epochs.
        if (epoch + 1) % 10 == 0:
            torch.save(state, f'{checkpoint_dir}/epoch_{epoch+1}.pth')

Next Steps