Data Loading in PyTorch

PyTorch’s torch.utils.data module provides powerful tools for loading and preprocessing data efficiently. The key components are Dataset for data storage/access and DataLoader for batch loading with multiprocessing support.

Dataset Classes

The Dataset class is an abstract class representing a dataset. There are two types:
  1. Map-style datasets: Implement __getitem__ and __len__
  2. Iterable-style datasets: Implement __iter__
Use map-style datasets when you can access data by index (most common). Use iterable-style datasets for streaming data or when random access is expensive.

Map-Style Dataset

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Custom Dataset for loading data."""
    
    def __init__(self, data, labels, transform=None):
        """
        Args:
            data: Input data (numpy array or tensor)
            labels: Labels for the data
            transform: Optional transform to apply
        """
        self.data = data
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        """Returns the total number of samples."""
        return len(self.data)
    
    def __getitem__(self, idx):
        """Returns one sample at the given index."""
        sample = self.data[idx]
        label = self.labels[idx]
        
        # Apply transformations if any
        if self.transform:
            sample = self.transform(sample)
        
        return sample, label

# Usage
import numpy as np

data = np.random.randn(1000, 3, 32, 32).astype(np.float32)
labels = np.random.randint(0, 10, 1000)

dataset = CustomDataset(data, labels)
print(f"Dataset size: {len(dataset)}")

# Access individual samples
sample, label = dataset[0]
print(f"Sample shape: {sample.shape}, Label: {label}")

Iterable-Style Dataset

Useful for streaming data or large datasets that don’t fit in memory:
import torch
import math
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    """Dataset that streams data (e.g., from a file or network)."""
    
    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end
    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        
        if worker_info is None:
            # Single-process loading
            iter_start = self.start
            iter_end = self.end
        else:
            # Multi-process loading: split workload
            per_worker = int(math.ceil((self.end - self.start) / worker_info.num_workers))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        
        # Generate data
        for i in range(iter_start, iter_end):
            # Simulate reading from stream
            yield torch.randn(3, 32, 32), i % 10

# Usage
dataset = StreamingDataset(start=0, end=1000)
for i, (data, label) in enumerate(dataset):
    if i >= 5:  # Print first 5
        break
    print(f"Sample {i}: data shape {data.shape}, label {label}")

DataLoader

The DataLoader combines a dataset with batching, shuffling, and parallel data loading:

Basic Usage

import torch
from torch.utils.data import DataLoader, TensorDataset

# Create dataset
data = torch.randn(1000, 3, 32, 32)
labels = torch.randint(0, 10, (1000,))
dataset = TensorDataset(data, labels)

# Create DataLoader
data_loader = DataLoader(
    dataset,
    batch_size=32,      # Batch size
    shuffle=True,       # Shuffle data each epoch
    num_workers=4,      # Parallel data loading
    pin_memory=True     # Faster GPU transfer
)

# Iterate through batches
for batch_idx, (batch_data, batch_labels) in enumerate(data_loader):
    print(f"Batch {batch_idx}: data shape {batch_data.shape}, labels shape {batch_labels.shape}")
    # batch_data: (32, 3, 32, 32)
    # batch_labels: (32,)

DataLoader Parameters

  • batch_size: Number of samples per batch (default: 1)
  • shuffle: Shuffle data at every epoch (default: False)
  • num_workers: Number of subprocesses for data loading (default: 0)
  • pin_memory: Copy tensors to CUDA pinned memory for faster GPU transfer (default: False)
  • drop_last: Drop last incomplete batch (default: False)
  • collate_fn: Custom function to merge samples into batch
  • sampler: Custom sampling strategy (mutually exclusive with shuffle)
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,           # Samples per batch
    shuffle=True,            # Shuffle data
    num_workers=4,           # Parallel workers
    pin_memory=True,         # Faster GPU transfer
    drop_last=True,          # Drop incomplete last batch
    persistent_workers=True  # Keep workers alive between epochs
)
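
The effect of drop_last is easy to see from the loader length (using the 1000-sample TensorDataset from the basic example, where 1000 is not divisible by 32):
loader_keep = DataLoader(dataset, batch_size=32, drop_last=False)
loader_drop = DataLoader(dataset, batch_size=32, drop_last=True)

print(len(loader_keep))  # 32 batches; the last one has only 8 samples
print(len(loader_drop))  # 31 full batches; the incomplete batch is dropped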

Custom Samplers

Samplers define the strategy to draw samples from the dataset:
from torch.utils.data import DataLoader, SequentialSampler

# Sample in order (0, 1, 2, ...)
sampler = SequentialSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# Note: Can't use shuffle=True with custom sampler
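
Beyond SequentialSampler, other built-in samplers exist; for class-imbalanced data, WeightedRandomSampler draws samples in proportion to per-sample weights. A minimal sketch, assuming labels is the 1-D tensor of class indices from the TensorDataset example above:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# labels: 1-D tensor of class indices (as in the TensorDataset example above)
class_counts = torch.bincount(labels)                 # samples per class
sample_weights = 1.0 / class_counts[labels].float()   # rarer classes weighted higher

sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)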

Data Transformations

Using Transforms

Transforms are commonly used with image datasets:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np

class ImageDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Define transforms
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# Create dataset with transforms
images = np.random.randint(0, 256, (1000, 32, 32, 3), dtype=np.uint8)  # HWC layout, as ToPILImage expects for numpy arrays
labels = np.random.randint(0, 10, 1000)

dataset = ImageDataset(images, labels, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

Separate Train/Val Transforms

import torchvision.transforms as transforms

# Training: include data augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225])
])

# train_images / train_labels are assumed to be prepared as in the earlier example
train_dataset = ImageDataset(train_images, train_labels,
                             transform=train_transform)
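
Validation and test data typically get a deterministic transform with no augmentation, so evaluation is reproducible. A minimal sketch (val_images and val_labels are assumed to be prepared like the training arrays, and the crop size mirrors the training pipeline above):
# Validation: deterministic preprocessing only, no random augmentation
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# val_images / val_labels assumed to exist, analogous to the training arrays
val_dataset = ImageDataset(val_images, val_labels, transform=val_transform)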

Custom Collate Function

The collate function merges individual samples into a batch; by default, same-shaped tensors are simply stacked. A custom collate_fn handles cases the default cannot, such as variable-length sequences:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_variable_length(batch):
    """
    Custom collate for variable-length sequences.
    
    Args:
        batch: List of (sequence, label) tuples
    
    Returns:
        Padded sequences and labels
    """
    # Separate sequences and labels
    sequences, labels = zip(*batch)
    
    # Pad sequences to same length
    sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    
    # Get original lengths
    lengths = torch.tensor([len(seq) for seq in sequences])
    
    # Stack labels
    labels = torch.tensor(labels)
    
    return sequences_padded, lengths, labels

# Dataset with variable-length sequences
class SequenceDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # Variable-length sequence (length sampled from 10 to 49; upper bound is exclusive)
        length = torch.randint(10, 50, (1,)).item()
        sequence = torch.randn(length, 128)  # (seq_len, features)
        label = torch.randint(0, 10, (1,)).item()
        return sequence, label

# Use custom collate function
dataset = SequenceDataset(1000)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_variable_length
)

# Iterate
for sequences, lengths, labels in loader:
    print(f"Sequences: {sequences.shape}")  # (32, max_len, 128)
    print(f"Lengths: {lengths}")              # (32,)
    print(f"Labels: {labels.shape}")          # (32,)
    break
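
The lengths returned by the collate function are exactly what pack_padded_sequence needs when the batch feeds an RNN; a short sketch (the nn.LSTM here is only an illustrative consumer):
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

lstm = nn.LSTM(input_size=128, hidden_size=64, batch_first=True)

for sequences, lengths, labels in loader:
    # Packing lets the LSTM skip the padded timesteps
    packed = pack_padded_sequence(sequences, lengths,
                                  batch_first=True, enforce_sorted=False)
    output, (h_n, c_n) = lstm(packed)
    break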

Built-in Datasets

PyTorch provides built-in datasets through torchvision. They download automatically on first use and come with standard train/test splits:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Training set
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# Test set
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Classes
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
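
Pulling a single batch with next(iter(...)) is a quick way to sanity-check shapes and value ranges before training:
# Grab one batch for inspection
images, targets = next(iter(train_loader))
print(images.shape)                               # torch.Size([64, 3, 32, 32])
print(targets[:8])                                # first few class indices
print(images.min().item(), images.max().item())  # roughly -1 to 1 after Normalize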

Complete Training Pipeline

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

# 1. Define Dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# 2. Create datasets
data = torch.randn(10000, 784)
labels = torch.randint(0, 10, (10000,))
full_dataset = MyDataset(data, labels)

# Split into train/val
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
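# Optional: pass a seeded generator for a reproducible split, e.g.
# random_split(full_dataset, [train_size, val_size],
#              generator=torch.Generator().manual_seed(42))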

# 3. Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# 4. Define model
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 10)
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 5. Training loop
num_epochs = 10

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    train_correct = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1)
        train_correct += pred.eq(target).sum().item()
    
    train_loss /= len(train_loader)
    train_acc = 100. * train_correct / len(train_dataset)
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            
            val_loss += loss.item()
            pred = output.argmax(dim=1)
            val_correct += pred.eq(target).sum().item()
    
    val_loss /= len(val_loader)
    val_acc = 100. * val_correct / len(val_dataset)
    
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

Performance Optimization

  1. Use num_workers > 0: Parallel data loading can significantly speed up training
  2. Enable pin_memory: Faster data transfer to GPU when using CUDA
  3. Optimize batch size: Larger batches improve GPU utilization (up to memory limits)
  4. Persistent workers: Set persistent_workers=True to avoid worker restart overhead
  5. Prefetch factor: Adjust prefetch_factor to control how many batches to load ahead

Multi-Process Best Practices

from torch.utils.data import DataLoader

# Commonly used settings for fast data loading
loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,              # Typically 4-8 workers
    pin_memory=True,            # Enable for GPU
    persistent_workers=True,    # Keep workers alive
    prefetch_factor=2           # Batches to prefetch per worker
)
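
When pin_memory=True, host-to-GPU copies can also be issued asynchronously by passing non_blocking=True to .to(), letting the transfer overlap with computation. A minimal sketch of the pattern inside a training loop:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for data, target in loader:
    # Asynchronous copy from pinned host memory to the GPU
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    # ... forward/backward pass ...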

Common Patterns

Concatenating Datasets

from torch.utils.data import ConcatDataset, DataLoader

# Combine multiple datasets
dataset1 = MyDataset(data1, labels1)
dataset2 = MyDataset(data2, labels2)

combined_dataset = ConcatDataset([dataset1, dataset2])
loader = DataLoader(combined_dataset, batch_size=32, shuffle=True)

print(f"Combined size: {len(combined_dataset)}")

Subset of Dataset

from torch.utils.data import Subset, DataLoader
import numpy as np

# Use only a subset of data
indices = np.random.choice(len(dataset), size=1000, replace=False)
subset = Subset(dataset, indices)

loader = DataLoader(subset, batch_size=32, shuffle=True)
print(f"Subset size: {len(subset)}")

Next Steps

  • Neural Networks: Build models to train on your data
  • Tensors: Understand PyTorch’s fundamental data structure