Overview

Datasets are the fundamental abstraction for representing data in PyTorch. All datasets inherit from torch.utils.data.Dataset and define how to access individual samples. There are two main types:
  • Map-style datasets - Index-based access via __getitem__ and __len__
  • Iterable datasets - Stream-based access via __iter__

Dataset (Map-Style)

Base class for all map-style datasets.
class torch.utils.data.Dataset

Abstract Methods

  • __getitem__(index) (required): fetch the data sample for a given key/index.
  • __len__() (optional): return the size of the dataset; expected by many samplers and by DataLoader.

Example - Custom Dataset

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Usage
data = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
dataset = CustomDataset(data, labels)

print(len(dataset))  # 100
sample, label = dataset[0]  # Get first item
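
In practice a Dataset is usually consumed through a DataLoader, which adds batching and shuffling on top of the indexing interface:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, shuffle=True)
for batch_data, batch_labels in loader:
    print(batch_data.shape)  # torch.Size([10, 3, 32, 32])
    break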

Example - With Transforms

class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # load_image is a placeholder for your own loader (e.g. PIL.Image.open)
        image = load_image(self.image_paths[idx])
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label
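
A sketch of how this is typically wired up, assuming torchvision is installed and load_image returns a PIL image (image_paths and labels stand in for your own data):

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize the PIL image
    transforms.ToTensor(),          # convert to a CHW float tensor in [0, 1]
])

dataset = ImageDataset(image_paths, labels, transform=transform)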

IterableDataset

Base class for iterable-style datasets.
class torch.utils.data.IterableDataset

Abstract Methods

  • __iter__() (required): return an iterator over the data samples.

Example - Stream Processing

from torch.utils.data import IterableDataset

class StreamDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end
    
    def __iter__(self):
        return iter(range(self.start, self.end))

dataset = StreamDataset(0, 100)
for sample in dataset:
    print(sample)  # 0, 1, 2, ..., 99
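
Because samples arrive as a stream, there is no __len__ or random access, so a DataLoader cannot shuffle an iterable dataset (passing shuffle=True raises a ValueError); batching still works:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10)  # shuffle must be left unset
for batch in loader:
    print(batch.shape)  # torch.Size([10])
    break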

Example - Multi-Worker Support

import math

class WorkerAwareDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end
    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        
        if worker_info is None:
            # Single-process loading
            iter_start = self.start
            iter_end = self.end
        else:
            # Multi-process loading: split workload
            per_worker = int(math.ceil((self.end - self.start) / worker_info.num_workers))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        
        return iter(range(iter_start, iter_end))
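
To see the sharding take effect, iterate through a DataLoader with multiple workers; each worker process runs __iter__ on its own copy of the dataset and yields only its slice:

from torch.utils.data import DataLoader

dataset = WorkerAwareDataset(0, 100)
# On platforms that spawn workers (e.g. Windows/macOS), run this under if __name__ == "__main__":
loader = DataLoader(dataset, num_workers=2)
print(sorted(int(x) for x in loader))  # each of 0..99 appears exactly once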

TensorDataset

Dataset wrapping tensors. Each sample is retrieved by indexing tensors along the first dimension.
torch.utils.data.TensorDataset(*tensors)

Parameters

  • *tensors (Tensor, required): tensors that all have the same size in their first dimension.

Example

from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))

dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=10)

for batch_data, batch_labels in loader:
    print(batch_data.shape)  # torch.Size([10, 3, 32, 32])
    print(batch_labels.shape)  # torch.Size([10])

ConcatDataset

Dataset as a concatenation of multiple datasets.
torch.utils.data.ConcatDataset(datasets)

Parameters

  • datasets (sequence, required): list of datasets to be concatenated.

Example

from torch.utils.data import ConcatDataset

dataset1 = TensorDataset(torch.randn(50, 3, 32, 32), torch.randint(0, 10, (50,)))
dataset2 = TensorDataset(torch.randn(50, 3, 32, 32), torch.randint(0, 10, (50,)))

combined = ConcatDataset([dataset1, dataset2])
print(len(combined))  # 100

# Can also use + operator
combined = dataset1 + dataset2

Subset

Subset of a dataset at specified indices.
torch.utils.data.Subset(dataset, indices)

Parameters

  • dataset (Dataset, required): the whole dataset.
  • indices (sequence, required): indices in the whole dataset selected for the subset.

Example

from torch.utils.data import Subset

full_dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))

# Create subset with first 50 samples
indices = list(range(50))
subset = Subset(full_dataset, indices)

print(len(subset))  # 50
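
Indexing remaps through the given indices, so subset[i] returns full_dataset[indices[i]]. A common use is a manual random split via a permutation:

perm = torch.randperm(len(full_dataset)).tolist()
train_subset = Subset(full_dataset, perm[:80])
val_subset = Subset(full_dataset, perm[80:])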

random_split

Randomly split a dataset into non-overlapping new datasets of given lengths.
torch.utils.data.random_split(dataset, lengths, generator=None)

Parameters

  • dataset (Dataset, required): dataset to be split.
  • lengths (sequence, required): lengths or fractions of splits to be produced; fractions must sum to 1.
  • generator (Generator, optional): generator used for the random permutation.

Example - Fixed Lengths

from torch.utils.data import random_split

dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))

# Split into 80/20 train/val
train_dataset, val_dataset = random_split(dataset, [80, 20])

print(len(train_dataset))  # 80
print(len(val_dataset))    # 20

Example - Fractions

# Split using fractions
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [0.7, 0.2, 0.1]  # 70% train, 20% val, 10% test
)

Example - Reproducible Split

generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset,
    [0.8, 0.2],
    generator=generator
)
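
Each split returned by random_split is a Subset view over the original dataset, so no data is copied; reusing the same seeded generator reproduces identical splits.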

ChainDataset

Dataset for chaining multiple IterableDataset objects. Chaining is done on-the-fly, so it is memory-efficient for assembling large streams.
torch.utils.data.ChainDataset(datasets)

Parameters

  • datasets (iterable, required): iterable datasets to be chained together.

Example

from torch.utils.data import ChainDataset, IterableDataset

class RangeDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end
    
    def __iter__(self):
        return iter(range(self.start, self.end))

dataset1 = RangeDataset(0, 5)
dataset2 = RangeDataset(5, 10)

chained = ChainDataset([dataset1, dataset2])
for item in chained:
    print(item)  # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9

StackDataset

Dataset as a stacking of multiple datasets.
torch.utils.data.StackDataset(*args, **kwargs)

Parameters

  • *args (Dataset): datasets to stack; each sample is returned as a tuple.
  • **kwargs (Dataset): datasets to stack; each sample is returned as a dict.

Example - Tuple Output

from torch.utils.data import StackDataset, TensorDataset

images = TensorDataset(torch.randn(100, 3, 32, 32))
labels = TensorDataset(torch.randint(0, 10, (100,)))

stacked = StackDataset(images, labels)
item = stacked[0]  # Tuple of per-dataset samples; since TensorDataset yields 1-tuples, this is ((image,), (label,))

Example - Dict Output

stacked = StackDataset(image=images, label=labels)
item = stacked[0]  # Returns dict: {"image": ..., "label": ...}

Common Patterns

Dataset with Caching

class CachedDataset(Dataset):
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.cache = {}
    
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        if idx not in self.cache:
            self.cache[idx] = self.base_dataset[idx]
        return self.cache[idx]
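
Note that with DataLoader(num_workers > 0) each worker process holds its own copy of the dataset, so the cache is neither shared across workers nor bounded in size; cap or clear it for large datasets.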

Dataset with Lazy Loading

class LazyLoadDataset(Dataset):
    def __init__(self, file_paths):
        self.file_paths = file_paths
    
    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, idx):
        # Load data only when requested; load_from_file is a placeholder for your own I/O
        data = load_from_file(self.file_paths[idx])
        return data

Dataset with On-the-Fly Augmentation

class AugmentedDataset(Dataset):
    def __init__(self, base_dataset, augmentations):
        self.base_dataset = base_dataset
        self.augmentations = augmentations
    
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        data, label = self.base_dataset[idx]
        
        # Apply random augmentations
        for aug in self.augmentations:
            data = aug(data)
        
        return data, label
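
A minimal sketch of how the pattern composes, using an illustrative tensor-level augmentation (the base dataset and noise function below are placeholders, not part of the API):

base = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
augs = [lambda x: x + 0.01 * torch.randn_like(x)]  # additive Gaussian noise
augmented = AugmentedDataset(base, augs)

sample, label = augmented[0]  # a freshly perturbed sample on every access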

Multi-Modal Dataset

class MultiModalDataset(Dataset):
    def __init__(self, image_paths, text_data, labels):
        self.image_paths = image_paths
        self.text_data = text_data
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        image = load_image(self.image_paths[idx])
        text = self.text_data[idx]
        label = self.labels[idx]
        
        return {
            'image': image,
            'text': text,
            'label': label
        }
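
The default collate function batches dict samples field by field, so this dataset works with a DataLoader unchanged; tensors are stacked and strings are gathered into lists (assuming load_image returns equally sized tensors):

loader = DataLoader(dataset, batch_size=4)
batch = next(iter(loader))
print(batch['image'].shape)  # images stacked along a new batch dimension
print(batch['text'][:2])     # list of raw strings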

Best Practices

Memory efficiency

  • Don't load all data in __init__ for large datasets; load samples on demand in __getitem__
  • Use memory mapping for very large files
  • Clear caches periodically if using caching

Performance

  • Optimize __getitem__, since it is called once per sample access
  • Preprocess data offline when possible
  • Use efficient file formats (HDF5, LMDB)
  • Profile your dataset to identify bottlenecks

Reproducibility

  • Use fixed random seeds
  • Pass a seeded generator to random_split
  • Document any randomness in transforms

Testing

  • Test that __len__ returns the correct value
  • Test __getitem__ with edge indices (0, -1, len-1)
  • Verify data types and shapes
  • Check for data leakage in splits

See Also