Overview
PyTorch provides flexible methods to save and load models. This guide covers best practices for saving model weights, complete checkpoints, and managing model persistence.

Basic Model Saving
Save and Load Model Weights
import torch
import torch.nn as nn

# Define model (the architecture must match the weights you will load)
model = YourModel()

# Train model
train(model)

# Save only the learnable parameters (recommended over pickling the model)
torch.save(model.state_dict(), 'model_weights.pth')

# Load model weights.
# weights_only=True restricts unpickling to tensors and primitive containers,
# preventing arbitrary-code execution from untrusted checkpoint files
# (it is the default starting with PyTorch 2.6).
model = YourModel()  # Create model instance first
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()  # Set to evaluation mode (disables dropout / batch-norm updates)
Important: Always instantiate your model with the same architecture before loading weights. The model class definition must match the saved weights.
Save Entire Model (Not Recommended)
# Save entire model (architecture + weights) as a pickle of the Python object
torch.save(model, 'model_complete.pth')

# Load entire model.
# NOTE: a fully pickled model requires weights_only=False. Since PyTorch 2.6
# the default is weights_only=True, which refuses arbitrary pickled objects,
# so this call would fail without the explicit flag. Only use it on files you
# trust - unpickling can execute arbitrary code.
model = torch.load('model_complete.pth', weights_only=False)
model.eval()
Saving the entire model is not recommended because:
- It's less flexible (pickle ties the saved file to the exact class and module paths that existed at save time)
- May break with PyTorch version changes
- Harder to debug serialization issues
- Larger file size
Use `state_dict()` instead.

Complete Checkpoint Saving

Save everything needed to resume training:

import torch
# Save checkpoint: bundle everything needed to resume training exactly where
# it stopped - weights, optimizer momentum buffers, LR schedule, bookkeeping.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'val_accuracy': val_accuracy,
    'scheduler_state_dict': scheduler.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')

# Load checkpoint.
# map_location='cpu' makes the load device-agnostic; load_state_dict then
# copies values onto whatever device the live model/optimizer use.
# weights_only=True is safe here because the checkpoint holds only tensors
# and primitive containers.
checkpoint = torch.load('checkpoint.pth', map_location='cpu', weights_only=True)
model = YourModel()
model.load_state_dict(checkpoint['model_state_dict'])
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
best_loss = checkpoint['loss']

# Resume training
model.train()
Advanced Checkpoint Management
Checkpoint Manager Class
import os
import torch
from pathlib import Path
class CheckpointManager:
    """Rotating checkpoint store for a model/optimizer pair.

    Keeps at most ``max_checkpoints`` per-epoch checkpoints on disk
    (oldest pruned first) plus an unrotated ``best_model.pth``.
    """

    def __init__(self, model, optimizer, save_dir, max_checkpoints=5):
        self.model = model
        self.optimizer = optimizer
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        # Fix: pick up checkpoints left by a previous run, sorted by epoch,
        # so rotation and load_latest() keep working after a restart.
        # (The original started with an empty list, never pruned pre-existing
        # files, and needed a glob fallback inside load_latest.)
        self.checkpoints = sorted(
            self.save_dir.glob('checkpoint_epoch_*.pth'),
            key=lambda p: int(p.stem.split('_')[-1]),
        )

    def save(self, epoch, metrics, is_best=False):
        """Save a checkpoint for ``epoch``; optionally also as best_model.pth."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics,
        }
        # Save regular (rotated) checkpoint
        checkpoint_path = self.save_dir / f'checkpoint_epoch_{epoch}.pth'
        torch.save(checkpoint, checkpoint_path)
        self.checkpoints.append(checkpoint_path)
        # Save best model (kept outside the rotation, never pruned)
        if is_best:
            best_path = self.save_dir / 'best_model.pth'
            torch.save(checkpoint, best_path)
            print(f"Saved best model at epoch {epoch}")
        # Prune oldest checkpoints beyond the retention limit.
        # 'while' (not 'if') handles a backlog inherited from a previous run.
        while len(self.checkpoints) > self.max_checkpoints:
            old_checkpoint = self.checkpoints.pop(0)
            if old_checkpoint.exists():
                old_checkpoint.unlink()

    def load_latest(self):
        """Load the most recent checkpoint, or return None if there is none."""
        if not self.checkpoints:
            return None
        return self.load(self.checkpoints[-1])

    def load_best(self):
        """Load best_model.pth, or return None if it was never saved."""
        best_path = self.save_dir / 'best_model.pth'
        if best_path.exists():
            return self.load(best_path)
        return None

    def load(self, checkpoint_path):
        """Restore model/optimizer state from ``checkpoint_path``; return the dict.

        map_location='cpu' keeps the load device-agnostic; load_state_dict
        then copies tensors onto the live parameters' own devices.
        """
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint
# Usage
manager = CheckpointManager(model, optimizer, 'checkpoints', max_checkpoints=3)

# Fix: best_val_loss must be initialized before the loop and updated whenever
# a new best is found. The original compared against a name it never defined
# or updated, so is_best was wrong (or raised NameError).
best_val_loss = float('inf')
for epoch in range(num_epochs):
    train_loss = train_epoch()
    val_loss, val_acc = validate()
    metrics = {
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_accuracy': val_acc
    }
    # Save checkpoint, flagging it as best on validation-loss improvement
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
    manager.save(epoch, metrics, is_best=is_best)

# Resume training
checkpoint = manager.load_latest()
if checkpoint:
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")
Save Best Model During Training
import torch
import numpy as np

# Track the best validation loss seen so far; drives both checkpointing
# and early stopping.
best_val_loss = float('inf')
patience = 10
patience_counter = 0

for epoch in range(num_epochs):
    # One optimization pass over the training set
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    # Measure generalization on the held-out set
    val_loss, val_acc = validate(model, val_loader, criterion)

    print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, '
          f'Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%')

    improved = val_loss < best_val_loss
    if improved:
        # New best: reset patience and persist a full checkpoint
        best_val_loss = val_loss
        patience_counter = 0
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_accuracy': val_acc,
        }
        torch.save(state, 'best_model.pth')
        print(f'Saved best model with val_loss={val_loss:.4f}')
    else:
        patience_counter += 1

    # Early stopping: abort once validation has stalled for `patience` epochs
    if patience_counter >= patience:
        print(f'Early stopping at epoch {epoch+1}')
        break
Device Management
Save on GPU, Load on CPU
# Save model trained on GPU (saved state_dict tensors keep their CUDA device tags)
model = model.cuda()
torch.save(model.state_dict(), 'model.pth')
# Load on CPU: map_location remaps every stored CUDA tensor into host memory,
# so this works on machines without a GPU
model = YourModel()
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.eval()  # switch dropout / batch-norm to inference behavior
Save on CPU, Load on GPU
# Save model (state_dict tensors are stored with their current device tags)
torch.save(model.state_dict(), 'model.pth')
# Load on specific GPU: map_location sends every stored tensor to cuda:0
device = torch.device('cuda:0')
model = YourModel()
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)  # also move the freshly created model's parameters to the GPU
General Device Handling
# Save
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
torch.save(model.state_dict(), 'model.pth')
# Load with automatic device selection: the same availability check picks the
# destination, so this snippet runs unchanged on both GPU and CPU machines
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YourModel()
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)
model.eval()  # inference mode
Always use `map_location` when loading models to ensure they load correctly regardless of the device they were saved on.

DataParallel and DistributedDataParallel
Saving DataParallel Models
import torch.nn as nn
# Wrap model in DataParallel (replicates the model across visible GPUs;
# the wrapped network lives at model.module)
model = nn.DataParallel(model)
# Save: model.module is the underlying network, so its state_dict keys carry
# no 'module.' prefix and load directly into a plain (non-parallel) model
torch.save(model.module.state_dict(), 'model.pth')
# Or save the full DataParallel state (keys keep the 'module.' prefix)
torch.save(model.state_dict(), 'model_parallel.pth')
Loading into DataParallel
# Load weights saved from DataParallel model
model = YourModel()
raw_state = torch.load('model_parallel.pth')
# DataParallel prefixes every parameter key with 'module.'; strip it so the
# keys match a plain (unwrapped) model. A plain dict preserves insertion
# order (Python 3.7+), so no OrderedDict is needed.
cleaned_state = {
    (key[len('module.'):] if key.startswith('module.') else key): value
    for key, value in raw_state.items()
}
model.load_state_dict(cleaned_state)
Partial Loading
Transfer Learning
# Load pretrained weights (may not match exactly)
model = YourModel()
pretrained_dict = torch.load('pretrained.pth')
model_dict = model.state_dict()
# Keep only weights whose name AND shape match the current model; this
# silently skips heads/layers that were changed for the new task.
compatible = {}
for name, weight in pretrained_dict.items():
    if name in model_dict and weight.shape == model_dict[name].shape:
        compatible[name] = weight
# Overlay the matching weights onto the model's own state, then load the
# merged dict so unmatched entries keep their freshly initialized values.
model_dict.update(compatible)
model.load_state_dict(model_dict)
Strict Loading Control
# Load with mismatched keys: strict=False skips weights absent from the model
# and ignores extra entries in the file instead of raising a RuntimeError
model.load_state_dict(torch.load('model.pth'), strict=False)
# Get missing and unexpected keys: load_state_dict returns a named tuple
# (missing_keys, unexpected_keys) describing exactly what was skipped
missing_keys, unexpected_keys = model.load_state_dict(
    torch.load('model.pth'),
    strict=False
)
# Always inspect both lists - a silent mismatch can leave layers randomly
# initialized without any error
if missing_keys:
    print(f"Missing keys: {missing_keys}")
if unexpected_keys:
    print(f"Unexpected keys: {unexpected_keys}")
Saving for Production
TorchScript Export
import torch
# Method 1: Tracing - runs the model once on an example input and records the
# executed ops; data-dependent control flow is "frozen" into that one path
model = YourModel()
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
# Save traced model (a self-contained TorchScript archive: code + weights)
traced_model.save('model_traced.pt')
# Load and use - no Python class definition is needed at load time
loaded_model = torch.jit.load('model_traced.pt')
loaded_model.eval()
with torch.no_grad():
    output = loaded_model(example_input)
# Method 2: Scripting (for models with control flow) - compiles the model's
# Python source, so if/loops that depend on input values are preserved
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')
ONNX Export
import torch.onnx
model = YourModel()
model.eval()  # export with inference-mode behavior (dropout off, BN frozen)
dummy_input = torch.randn(1, 3, 224, 224)
# Export to ONNX (the model is traced with dummy_input during export)
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    export_params=True,        # embed the trained weights in the file
    opset_version=11,          # ONNX operator-set version to target
    do_constant_folding=True,  # pre-compute constant subgraphs at export time
    input_names=['input'],
    output_names=['output'],
    # mark dim 0 as variable so any batch size works at inference time
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)
Cloud Storage Integration
Save to S3 (AWS)
import torch
import boto3
import io
def save_to_s3(model, bucket, key):
    """Serialize model.state_dict() and upload it to s3://bucket/key."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    s3 = boto3.client('s3')
    s3.put_object(Bucket=bucket, Key=key, Body=buffer.getvalue())

def load_from_s3(bucket, key):
    """Download s3://bucket/key and return the deserialized state dict.

    map_location='cpu' lets the checkpoint load on hosts without the GPUs
    it was saved from (the original call would fail there); load_state_dict
    later copies tensors onto the target model's own devices.
    """
    s3 = boto3.client('s3')
    response = s3.get_object(Bucket=bucket, Key=key)
    buffer = io.BytesIO(response['Body'].read())
    state_dict = torch.load(buffer, map_location='cpu')
    return state_dict

# Usage
save_to_s3(model, 'my-bucket', 'models/model.pth')
state_dict = load_from_s3('my-bucket', 'models/model.pth')
model.load_state_dict(state_dict)
Save to Google Cloud Storage
from google.cloud import storage
import torch
import io
def save_to_gcs(model, bucket_name, blob_name):
    """Serialize model.state_dict() and upload it to gs://bucket_name/blob_name."""
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)  # rewind so upload_from_file reads from the start
    blob.upload_from_file(buffer)

def load_from_gcs(bucket_name, blob_name):
    """Download gs://bucket_name/blob_name and return the deserialized object.

    map_location='cpu' lets the checkpoint load on hosts without the GPUs it
    was saved from (the original call would fail there); load_state_dict
    moves tensors onto the live model's devices afterwards.
    """
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    buffer = io.BytesIO()
    blob.download_to_file(buffer)
    buffer.seek(0)
    return torch.load(buffer, map_location='cpu')
Model Versioning
import torch
from datetime import datetime
import json
class ModelVersion:
    """Couples a model with provenance metadata, stamped with a save time."""

    def __init__(self, model, metadata=None):
        self.model = model
        self.metadata = metadata or {}
        # Record when this version object was created
        self.metadata['saved_at'] = datetime.now().isoformat()

    def save(self, path):
        """Save weights + metadata to `path`; mirror metadata to a JSON file.

        Accepts either a str or a pathlib.Path.
        """
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'metadata': self.metadata
        }
        torch.save(checkpoint, path)
        # Fix: coerce to str before .replace(). With a pathlib.Path argument
        # the original invoked Path.replace(), which renames files on disk
        # (and raises TypeError with two arguments) instead of deriving the
        # metadata filename via str.replace().
        metadata_path = str(path).replace('.pth', '_metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(self.metadata, f, indent=2)

    @classmethod
    def load(cls, path, model_class):
        """Instantiate model_class(), restore weights, return a ModelVersion.

        NOTE(review): constructing the new ModelVersion refreshes 'saved_at'
        in memory; the original timestamp remains in the on-disk checkpoint
        and JSON sidecar.
        """
        checkpoint = torch.load(path)
        model = model_class()
        model.load_state_dict(checkpoint['model_state_dict'])
        return cls(model, checkpoint.get('metadata', {}))
# Usage: record provenance worth keeping next to the weights
metadata = {
    'architecture': 'ResNet50',
    'dataset': 'ImageNet',
    'accuracy': 76.5,
    # total parameter count, derived from the live model
    'parameters': sum(p.numel() for p in model.parameters()),
    'training_epochs': 100,
    'hyperparameters': {
        'lr': 0.1,
        'batch_size': 256,
        'optimizer': 'SGD'
    }
}
version = ModelVersion(model, metadata)
version.save('models/resnet50_v1.0.0.pth')
# Load (the second argument is the model class/factory, called with no args)
loaded_version = ModelVersion.load('models/resnet50_v1.0.0.pth', YourModel)
print(f"Model saved at: {loaded_version.metadata['saved_at']}")
print(f"Accuracy: {loaded_version.metadata['accuracy']}%")
Best Practices
Saving guidelines
- Always save `state_dict()`, not the entire model
- Include epoch, optimizer state, and metrics in checkpoints
- Save best model based on validation metrics
- Keep multiple checkpoints (last N epochs)
- Use meaningful checkpoint names with timestamps or metrics
- Save hyperparameters and metadata with models
Loading guidelines
- Always use `map_location` for device compatibility
- Set model to `eval()` mode after loading for inference
- Handle DataParallel models by removing the 'module.' prefix
- Use `strict=False` carefully and check missing/unexpected keys
- Verify model architecture matches saved weights
Storage tips
- Use the zipfile-based checkpoint format: `torch.save(obj, path, _use_new_zipfile_serialization=True)` (this is the default since PyTorch 1.6; note it packages, but does not compress, the data)
- Clean up old checkpoints to save disk space
- Use version control for model artifacts (DVC, MLflow)
- Store models in cloud storage for team access
- Keep separate directories for experiments
Common Patterns
# Training with checkpointing
def train_with_checkpoints(
    model,
    train_loader,
    val_loader,
    num_epochs,
    checkpoint_dir='checkpoints'
):
    """Train `model`, maintaining latest/best/periodic checkpoints on disk."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # One optimization pass, then measure generalization
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc = validate(model, val_loader, criterion)

        snapshot = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
        }

        # Always overwrite the rolling "latest" checkpoint
        torch.save(snapshot, f'{checkpoint_dir}/latest.pth')

        # Keep the best-so-far model by validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(snapshot, f'{checkpoint_dir}/best.pth')

        # Every 10th epoch, also keep a permanent numbered snapshot
        if (epoch + 1) % 10 == 0:
            torch.save(snapshot, f'{checkpoint_dir}/epoch_{epoch+1}.pth')
Next Steps
- Learn about building models
- Understand training loops
- Explore model optimization