Overview
Datasets are the fundamental abstraction for representing data in PyTorch. All datasets inherit from torch.utils.data.Dataset and define how to access individual samples.
There are two main types:
Map-style datasets - Index-based access via __getitem__ and __len__
Iterable datasets - Stream-based access via __iter__
Dataset (Map-Style)
Base class for all map-style datasets.
class torch.utils.data.Dataset
Abstract Methods
Fetch a data sample for a given key/index
Return the size of the dataset. Expected by many samplers and DataLoader
Example - Custom Dataset
import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style dataset pairing in-memory samples with their labels."""

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        # Length is dictated by the data container's first dimension.
        return len(self.data)

    def __getitem__(self, idx):
        # Return the (sample, label) pair at the requested index.
        return self.data[idx], self.labels[idx]


# Usage
data = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
dataset = CustomDataset(data, labels)
print(len(dataset))  # 100
sample, label = dataset[0]  # Get first item
class ImageDataset(Dataset):
    """Map-style dataset that loads images from disk on demand.

    Parameters: image_paths (one path per sample), labels (one per path),
    transform (optional callable applied to each loaded image).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # NOTE(review): load_image is assumed to be provided by the
        # surrounding environment — it is not defined in this file.
        image = load_image(self.image_paths[idx])
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
IterableDataset
Base class for iterable-style datasets.
class torch.utils.data.IterableDataset
Abstract Methods
Return an iterator over data samples
Example - Stream Processing
from torch.utils.data import IterableDataset


class StreamDataset(IterableDataset):
    """Iterable-style dataset yielding consecutive integers in [start, end)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        # Every call hands back a fresh iterator over the half-open range,
        # so the dataset can be iterated more than once.
        return iter(range(self.start, self.end))


dataset = StreamDataset(0, 100)
for sample in dataset:
    print(sample)  # 0, 1, 2, ..., 99
Example - Multi-Worker Support
class WorkerAwareDataset(IterableDataset):
    """IterableDataset that shards its range across DataLoader workers.

    Each worker receives a contiguous, non-overlapping slice of
    [start, end), so samples are not duplicated across processes.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading: this iterator covers the full range.
            lo, hi = self.start, self.end
        else:
            # Multi-process loading: carve out this worker's shard.
            per_worker = int(math.ceil((self.end - self.start) / info.num_workers))
            lo = self.start + info.id * per_worker
            hi = min(lo + per_worker, self.end)
        return iter(range(lo, hi))
TensorDataset
Dataset wrapping tensors. Each sample is retrieved by indexing tensors along the first dimension.
torch.utils.data.TensorDataset( * tensors)
Parameters
Tensors that all have the same size in the first dimension; each sample is the tuple of slices taken at one index
Example
from torch.utils.data import TensorDataset, DataLoader

# Both tensors share a first dimension of 100, so each index pairs one
# image-shaped sample with one label.
data = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=10)
for batch_data, batch_labels in loader:
    print(batch_data.shape)  # torch.Size([10, 3, 32, 32])
    print(batch_labels.shape)  # torch.Size([10])
ConcatDataset
Dataset as a concatenation of multiple datasets.
torch.utils.data.ConcatDataset(datasets)
Parameters
List of datasets to be concatenated
Example
from torch.utils.data import ConcatDataset

dataset1 = TensorDataset(torch.randn(50, 3, 32, 32), torch.randint(0, 10, (50,)))
dataset2 = TensorDataset(torch.randn(50, 3, 32, 32), torch.randint(0, 10, (50,)))
combined = ConcatDataset([dataset1, dataset2])
print(len(combined))  # 100
# Can also use + operator
combined = dataset1 + dataset2
Subset
Subset of a dataset at specified indices.
torch.utils.data.Subset(dataset, indices)
Parameters
Indices in the whole dataset selected for subset
Example
from torch.utils.data import Subset

full_dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
# Create subset with first 50 samples
indices = list(range(50))
subset = Subset(full_dataset, indices)
print(len(subset))  # 50
random_split
Randomly split a dataset into non-overlapping new datasets of given lengths.
torch.utils.data.random_split(dataset, lengths, generator = None )
Parameters
Lengths or fractions of splits to be produced. Fractions should sum to 1
Generator used for random permutation
Example - Fixed Lengths
from torch.utils.data import random_split

dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
# Split into 80/20 train/val
train_dataset, val_dataset = random_split(dataset, [80, 20])
print(len(train_dataset))  # 80
print(len(val_dataset))  # 20
Example - Fractions
# Split using fractions
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [0.7, 0.2, 0.1],  # 70% train, 20% val, 10% test
)
Example - Reproducible Split
# Seeding the generator makes the split deterministic across runs.
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset,
    [0.8, 0.2],
    generator=generator,
)
ChainDataset
Dataset for chaining multiple IterableDataset objects.
torch.utils.data.ChainDataset(datasets)
Parameters
Iterable datasets to be chained together
Example
from torch.utils.data import ChainDataset, IterableDataset


class RangeDataset(IterableDataset):
    """Iterable dataset yielding the integers in [start, end)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))


dataset1 = RangeDataset(0, 5)
dataset2 = RangeDataset(5, 10)
chained = ChainDataset([dataset1, dataset2])
for item in chained:
    print(item)  # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
StackDataset
Dataset as a stacking of multiple datasets.
torch.utils.data.StackDataset( * args, ** kwargs)
Parameters
Positional datasets (*args) — samples are returned stacked as a tuple
Keyword datasets (**kwargs) — samples are returned stacked as a dict keyed by argument name
Example - Tuple Output
from torch.utils.data import StackDataset, TensorDataset
images = TensorDataset(torch.randn( 100 , 3 , 32 , 32 ))
labels = TensorDataset(torch.randint( 0 , 10 , ( 100 ,)))
stacked = StackDataset(images, labels)
item = stacked[ 0 ] # Returns tuple: (image, label)
Example - Dict Output
# Keyword arguments produce dict samples keyed by the argument names.
stacked = StackDataset(image=images, label=labels)
item = stacked[0]  # Returns dict: {"image": ..., "label": ...}
Common Patterns
Dataset with Caching
class CachedDataset(Dataset):
    """Wrap a map-style dataset and memoize every sample it serves.

    Useful when the base dataset's __getitem__ is expensive; repeated
    accesses of the same index are served from the in-memory cache.
    The cache grows without bound, so clear it periodically for
    long-running jobs over large datasets.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.cache = {}

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # EAFP: serve from the cache, populating it on the first miss.
        try:
            return self.cache[idx]
        except KeyError:
            value = self.base_dataset[idx]
            self.cache[idx] = value
            return value
Dataset with Lazy Loading
class LazyLoadDataset(Dataset):
    """Map-style dataset that defers all file loading to access time."""

    def __init__(self, file_paths):
        # Nothing is read here; only the paths are stored.
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Load data only when requested.
        # NOTE(review): load_from_file is assumed to exist in the
        # surrounding environment — it is not defined in this file.
        return load_from_file(self.file_paths[idx])
Dataset with On-the-Fly Augmentation
class AugmentedDataset(Dataset):
    """Apply a pipeline of augmentations to each sample on the fly.

    The base dataset must yield (data, label) pairs; augmentations are
    callables applied to the data in order, and labels pass through
    untouched.
    """

    def __init__(self, base_dataset, augmentations):
        self.base_dataset = base_dataset
        self.augmentations = augmentations

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        data, label = self.base_dataset[idx]
        # Chain the augmentations in the order they were given.
        for augment in self.augmentations:
            data = augment(data)
        return data, label
Multi-Modal Dataset
class MultiModalDataset(Dataset):
    """Combine an image, a text snippet, and a label into one dict sample.

    image_paths, text_data, and labels are parallel sequences; the
    number of labels defines the dataset length.
    """

    def __init__(self, image_paths, text_data, labels):
        self.image_paths = image_paths
        self.text_data = text_data
        self.labels = labels

    def __len__(self):
        # One label per sample.
        return len(self.labels)

    def __getitem__(self, idx):
        # NOTE(review): load_image is assumed to be supplied by the
        # surrounding environment — it is not defined in this file.
        return {
            'image': load_image(self.image_paths[idx]),
            'text': self.text_data[idx],
            'label': self.labels[idx],
        }
Best Practices
Don’t load all data in __init__ for large datasets
Load samples on-demand in __getitem__
Use memory mapping for very large files
Clear caches periodically if using caching
Use fixed random seeds for reproducibility
Pass generator to random_split
Document any randomness in transforms
Test __len__ returns correct value
Test __getitem__ with edge indices (0, -1, len-1)
Verify data types and shapes
Check for data leakage in splits
See Also