Data Loading in PyTorch
PyTorch’s torch.utils.data module provides powerful tools for loading and preprocessing data efficiently. The key components are Dataset for data storage/access and DataLoader for batch loading with multiprocessing support.
Dataset Classes
The Dataset class is an abstract class representing a dataset. There are two types:
- Map-style datasets: implement __getitem__ and __len__
- Iterable-style datasets: implement __iter__
Use map-style datasets when you can access data by index (most common). Use iterable-style datasets for streaming data or when random access is expensive.
Map-Style Dataset
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Custom Dataset for loading data."""

    def __init__(self, data, labels, transform=None):
        """
        Args:
            data: Input data (numpy array or tensor)
            labels: Labels for the data
            transform: Optional transform to apply
        """
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Returns one sample at the given index."""
        sample = self.data[idx]
        label = self.labels[idx]

        # Apply transformations if any
        if self.transform:
            sample = self.transform(sample)

        return sample, label

# Usage
import numpy as np

data = np.random.randn(1000, 3, 32, 32).astype(np.float32)
labels = np.random.randint(0, 10, 1000)

dataset = CustomDataset(data, labels)
print(f"Dataset size: {len(dataset)}")

# Access individual samples
sample, label = dataset[0]
print(f"Sample shape: {sample.shape}, Label: {label}")
Iterable-Style Dataset
Useful for streaming data or large datasets that don’t fit in memory:

import torch
import math
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    """Dataset that streams data (e.g., from a file or network)."""

    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading
            iter_start = self.start
            iter_end = self.end
        else:
            # Multi-process loading: split workload
            per_worker = int(math.ceil((self.end - self.start) / worker_info.num_workers))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)

        # Generate data
        for i in range(iter_start, iter_end):
            # Simulate reading from stream
            yield torch.randn(3, 32, 32), i % 10

# Usage
dataset = StreamingDataset(start=0, end=1000)

for i, (data, label) in enumerate(dataset):
    if i >= 5:  # Print first 5
        break
    print(f"Sample {i}: data shape {data.shape}, label {label}")
DataLoader
The DataLoader combines a dataset with batching, shuffling, and parallel data loading:
Basic Usage
import torch
from torch.utils.data import DataLoader, TensorDataset

# Create dataset
data = torch.randn(1000, 3, 32, 32)
labels = torch.randint(0, 10, (1000,))
dataset = TensorDataset(data, labels)

# Create DataLoader
data_loader = DataLoader(
    dataset,
    batch_size=32,    # Batch size
    shuffle=True,     # Shuffle data each epoch
    num_workers=4,    # Parallel data loading
    pin_memory=True   # Faster GPU transfer
)

# Iterate through batches
for batch_idx, (batch_data, batch_labels) in enumerate(data_loader):
    print(f"Batch {batch_idx}: data shape {batch_data.shape}, labels shape {batch_labels.shape}")
    # batch_data: (32, 3, 32, 32)
    # batch_labels: (32,)
DataLoader Parameters
Key DataLoader Parameters
- batch_size: Number of samples per batch (default: 1)
- shuffle: Shuffle data at every epoch (default: False)
- num_workers: Number of subprocesses for data loading (default: 0)
- pin_memory: Copy tensors to CUDA pinned memory for faster GPU transfer (default: False)
- drop_last: Drop last incomplete batch (default: False)
- collate_fn: Custom function to merge samples into a batch
- sampler: Custom sampling strategy (mutually exclusive with shuffle)
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,            # Samples per batch
    shuffle=True,             # Shuffle data
    num_workers=4,            # Parallel workers
    pin_memory=True,          # Faster GPU transfer
    drop_last=True,           # Drop incomplete last batch
    persistent_workers=True   # Keep workers alive between epochs
)
Custom Samplers
Samplers define the strategy used to draw samples from the dataset:
- Sequential Sampler
- Random Sampler
- Weighted Sampler
- Custom Sampler
from torch.utils.data import DataLoader, SequentialSampler
# Sample in order (0, 1, 2, ...)
sampler = SequentialSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# Note: Can't use shuffle=True with custom sampler
from torch.utils.data import DataLoader, RandomSampler
# Sample randomly (with replacement option)
sampler = RandomSampler(dataset, replacement=False)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Sample with weights (useful for imbalanced datasets)
# Assume 1000 samples with imbalanced classes
class_counts = [800, 100, 100]          # 3 classes
weights = [1.0/800, 1.0/100, 1.0/100]   # Inverse frequency

# Assign weight to each sample based on its class
sample_weights = torch.zeros(1000)
labels = torch.randint(0, 3, (1000,))
for i, label in enumerate(labels):
    sample_weights[i] = weights[label]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
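A quick way to check that the weighting behaves as intended is to draw one epoch of indices from the sampler and count how often each class is picked; with inverse-frequency weights the counts should be roughly balanced. A sanity-check sketch, assuming the labels tensor defined above:

# Draw one epoch of indices and count the sampled classes
sampled_indices = torch.tensor(list(sampler))
print(torch.bincount(labels[sampled_indices], minlength=3))
# Expected: roughly equal counts for the three classes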
import torch
from torch.utils.data import Sampler, DataLoader

class CustomBatchSampler(Sampler):
    """Sample batches with specific properties."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        # Custom sampling logic
        indices = torch.randperm(len(self.dataset)).tolist()
        batches = [indices[i:i+self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        for batch in batches:
            yield batch

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

sampler = CustomBatchSampler(dataset, batch_size=32)
loader = DataLoader(dataset, batch_sampler=sampler)
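When batch_sampler is supplied, the loader yields one batch per index list, and batch_size, shuffle, sampler, and drop_last must be left at their defaults. A brief usage sketch, assuming the TensorDataset created in the earlier DataLoader example:

# The loader yields one batch per index list produced by the sampler
print(f"Number of batches: {len(loader)}")
for batch_data, batch_labels in loader:
    print(batch_data.shape)  # (32, ...) for every batch except possibly the last
    break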
Data Transformations
Using Transforms
Transforms are commonly used with image datasets:

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np

class ImageDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Create dataset with transforms
# Note: ToPILImage expects numpy images in (H, W, C) layout
images = np.random.randint(0, 255, (1000, 32, 32, 3), dtype=np.uint8)
labels = np.random.randint(0, 10, 1000)

dataset = ImageDataset(images, labels, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
Separate Train/Val Transforms
import torchvision.transforms as transforms

# Training: include data augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_dataset = ImageDataset(train_images, train_labels,
                             transform=train_transform)
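A matching validation pipeline usually drops the augmentation and keeps only deterministic preprocessing. A sketch of the counterpart (val_images and val_labels are assumed to be defined analogously to train_images and train_labels):

# Validation: deterministic preprocessing only, no augmentation
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# val_images / val_labels are placeholders mirroring the training arrays above
val_dataset = ImageDataset(val_images, val_labels, transform=val_transform)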
Custom Collate Function
The collate function merges samples into a batch. Custom collate functions handle special cases:

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_variable_length(batch):
    """
    Custom collate for variable-length sequences.

    Args:
        batch: List of (sequence, label) tuples

    Returns:
        Padded sequences, original lengths, and labels
    """
    # Separate sequences and labels
    sequences, labels = zip(*batch)

    # Pad sequences to the same length
    sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0)

    # Get original lengths
    lengths = torch.tensor([len(seq) for seq in sequences])

    # Stack labels
    labels = torch.tensor(labels)

    return sequences_padded, lengths, labels

# Dataset with variable-length sequences
class SequenceDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Variable-length sequence (10 to 49 timesteps)
        length = torch.randint(10, 50, (1,)).item()
        sequence = torch.randn(length, 128)  # (seq_len, features)
        label = torch.randint(0, 10, (1,)).item()
        return sequence, label

# Use custom collate function
dataset = SequenceDataset(1000)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_variable_length
)

# Iterate
for sequences, lengths, labels in loader:
    print(f"Sequences: {sequences.shape}")  # (32, max_len, 128)
    print(f"Lengths: {lengths}")            # (32,)
    print(f"Labels: {labels.shape}")        # (32,)
    break
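The lengths tensor returned by the collate function is typically consumed by pack_padded_sequence so that a recurrent layer ignores the padded timesteps. A minimal sketch using the last batch from the loop above (the LSTM and its hidden size are illustrative, not part of the original example):

import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# Illustrative LSTM over the 128 input features per timestep
lstm = nn.LSTM(input_size=128, hidden_size=64, batch_first=True)

# Pack the padded batch so padding is skipped during the recurrence
packed = pack_padded_sequence(sequences, lengths, batch_first=True, enforce_sorted=False)
output, (h_n, c_n) = lstm(packed)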
Built-in Datasets
PyTorch provides built-in datasets (requires torchvision):
Built-in datasets automatically download data on first use and provide standard train/test splits.
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Training set
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# Test set
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Classes
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
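To confirm the loaders are wired up correctly, you can peek at a single batch and map the numeric targets back to the class names (a quick inspection sketch):

# Grab one batch from the training loader
images, targets = next(iter(train_loader))
print(images.shape)                               # (64, 3, 32, 32)
print([classes[t.item()] for t in targets[:8]])   # first eight label names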
Complete Training Pipeline
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

# 1. Define Dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# 2. Create datasets
data = torch.randn(10000, 784)
labels = torch.randint(0, 10, (10000,))
full_dataset = MyDataset(data, labels)

# Split into train/val
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# 3. Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# 4. Define model
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 10)
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 5. Training loop
num_epochs = 10

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    train_correct = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pred = output.argmax(dim=1)
        train_correct += pred.eq(target).sum().item()

    train_loss /= len(train_loader)
    train_acc = 100. * train_correct / len(train_dataset)

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            val_loss += loss.item()
            pred = output.argmax(dim=1)
            val_correct += pred.eq(target).sum().item()

    val_loss /= len(val_loader)
    val_acc = 100. * val_correct / len(val_dataset)

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
Performance Optimization
- Use num_workers > 0: Parallel data loading significantly speeds up training
- Enable pin_memory: Faster data transfer to GPU when using CUDA
- Optimize batch size: Larger batches improve GPU utilization (up to memory limits)
- Persistent workers: Set persistent_workers=True to avoid worker restart overhead
- Prefetch factor: Adjust prefetch_factor to control how many batches to load ahead
Multi-Process Best Practices
from torch.utils.data import DataLoader

# Optimal settings for fast training
loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,            # Typically 4-8 workers
    pin_memory=True,          # Enable for GPU
    persistent_workers=True,  # Keep workers alive
    prefetch_factor=2         # Batches to prefetch per worker
)
Common Patterns
Concatenating Datasets
from torch.utils.data import ConcatDataset, DataLoader
# Combine multiple datasets
dataset1 = MyDataset(data1, labels1)
dataset2 = MyDataset(data2, labels2)
combined_dataset = ConcatDataset([dataset1, dataset2])
loader = DataLoader(combined_dataset, batch_size=32, shuffle=True)
print(f"Combined size: {len(combined_dataset)}")
Subset of Dataset
from torch.utils.data import Subset, DataLoader
import numpy as np
# Use only a subset of data
indices = np.random.choice(len(dataset), size=1000, replace=False)
subset = Subset(dataset, indices)
loader = DataLoader(subset, batch_size=32, shuffle=True)
print(f"Subset size: {len(subset)}")
Next Steps
- Neural Networks: Build models to train on your data
- Tensors: Understand PyTorch’s fundamental data structure