Neural Networks with torch.nn
The torch.nn module provides all the building blocks for creating neural networks in PyTorch. It includes layers, loss functions, and utilities for building complex architectures.
The nn.Module Class
All neural networks in PyTorch are subclasses of torch.nn.Module. This base class provides:
- Parameter management
- Layer registration
- Easy model serialization
- Hooks for debugging
- GPU support
nn.Module automatically tracks all nn.Parameter and sub-modules defined in __init__, making it easy to access and optimize them.
Creating a Simple Network
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
    """A minimal two-layer fully connected network.

    Architecture: input_size -> hidden_size (ReLU) -> num_classes logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Two affine layers; nn.Module registers their parameters automatically.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Hidden layer with ReLU non-linearity, then raw class scores.
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
# Build the model: 784-dim inputs -> 128 hidden units -> 10 classes
model = SimpleNet(input_size=784, hidden_size=128, num_classes=10)

# Run one forward pass on a batch of 32 flattened samples
batch = torch.randn(32, 784)
output = model(batch)  # Shape: (32, 10)

param_count = sum(p.numel() for p in model.parameters())
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {param_count}")
Core Layer Types
Linear (Fully Connected) Layers
import torch
import torch.nn as nn

# A fully connected (affine) layer computing y = xW^T + b
layer = nn.Linear(in_features=20, out_features=30)

# Feed a batch of 128 feature vectors through it
inputs = torch.randn(128, 20)
output = layer(inputs)  # Output shape: (128, 30)

# Note: the weight matrix is stored transposed relative to the math notation
print(f"Weight shape: {layer.weight.shape}") # (30, 20)
print(f"Bias shape: {layer.bias.shape}") # (30,)
Convolutional Layers
Convolutional layers are essential for computer vision tasks:
- Conv2d
- Conv1d
- Conv3d
import torch
import torch.nn as nn

# 2D convolution over image-like inputs: (batch, in_channels, height, width)
conv = nn.Conv2d(
    in_channels=3,    # RGB images
    out_channels=64,  # Number of filters
    kernel_size=3,    # 3x3 kernel
    stride=1,         # Stride
    padding=1,        # Padding of 1 keeps H and W unchanged for a 3x3 kernel
)

# A batch of 32 RGB images at 224x224 resolution
images = torch.randn(32, 3, 224, 224)
output = conv(images)
# Output shape: (32, 64, 224, 224)
print(f"Output shape: {output.shape}")
import torch
import torch.nn as nn

# 1D convolution over sequence data: (batch, in_channels, length)
conv = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3)

# A batch of 64 sequences, each with 16 channels and length 100
sequences = torch.randn(64, 16, 100)
output = conv(sequences)
# No padding, so the length shrinks by kernel_size - 1: (64, 32, 98)
import torch
import torch.nn as nn

# 3D convolution for video / volumetric data:
# input layout is (batch, in_channels, depth, height, width)
conv = nn.Conv3d(in_channels=3, out_channels=16, kernel_size=3)

# A batch of 8 clips: 16 frames of 112x112 RGB each
clips = torch.randn(8, 3, 16, 112, 112)
output = conv(clips)
# Each of depth/height/width shrinks by 2 (kernel 3, no padding):
# (8, 16, 14, 110, 110)
Pooling Layers
Pooling layers reduce spatial dimensions:
import torch
import torch.nn as nn

# Max pooling keeps the largest value in each 2x2 window
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

# Average pooling takes the mean of each 2x2 window
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

# Adaptive pooling lets you choose the OUTPUT size instead of the window size
adaptive_pool = nn.AdaptiveAvgPool2d((7, 7)) # Always outputs 7x7

# Example feature map: batch of 32, 64 channels, 224x224
x = torch.randn(32, 64, 224, 224)
output = max_pool(x)
print(f"After pooling: {output.shape}") # (32, 64, 112, 112)

# Adaptive pooling to fixed size
output = adaptive_pool(x)
print(f"After adaptive pooling: {output.shape}") # (32, 64, 7, 7)
Activation Functions
# Fix: the snippet called torch.randn but never imported torch.
import torch
import torch.nn as nn
import torch.nn.functional as F

# ReLU: max(0, x)
relu = nn.ReLU()

# ReLU in-place (saves memory by overwriting the input tensor)
relu_inplace = nn.ReLU(inplace=True)

# Using the functional API
x = torch.randn(10)
output = F.relu(x)
Building Complex Architectures
Sequential Container
Chain layers together:
import torch.nn as nn
# Fix: the snippet used torch.randn but only imported torch.nn.
import torch
import torch.nn as nn

# Simple sequential model: an MLP 784 -> 256 -> 128 -> 10 with dropout
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 10)
)

x = torch.randn(32, 784)
output = model(x) # Shape: (32, 10)

# Access layers by index
first_layer = model[0]
print(first_layer.weight.shape) # (256, 784)
ModuleList and ModuleDict
- ModuleList
- ModuleDict
import torch
import torch.nn as nn
class DynamicNet(nn.Module):
    """An MLP whose depth is chosen at construction time.

    nn.ModuleList registers each layer so its parameters are tracked,
    unlike a plain Python list of modules.
    """

    def __init__(self, num_layers):
        super().__init__()
        # A stack of identical 128 -> 128 linear layers
        self.layers = nn.ModuleList(
            nn.Linear(128, 128) for _ in range(num_layers)
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        # Apply every layer in order, each followed by the activation
        for fc in self.layers:
            x = self.activation(fc(x))
        return x
# A five-layer instance of the dynamic network
model = DynamicNet(num_layers=5)
batch = torch.randn(32, 128)
output = model(batch)
import torch
import torch.nn as nn
class MultiTaskNet(nn.Module):
    """A shared trunk with one output head per task.

    nn.ModuleDict registers each head under a string key, so parameters
    are tracked and a head can be selected by name at call time.
    """

    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(128, 64)
        # One task-specific head per output type
        self.heads = nn.ModuleDict({
            'classification': nn.Linear(64, 10),
            'regression': nn.Linear(64, 1),
            'segmentation': nn.Linear(64, 21)
        })

    def forward(self, x, task='classification'):
        features = self.shared(x)
        return self.heads[task](features)
model = MultiTaskNet()
inputs = torch.randn(32, 128)

# Route the same shared features through different heads
cls_output = model(inputs, task='classification')
reg_output = model(inputs, task='regression')
ResNet-Style Residual Blocks
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """A basic ResNet-style block: two 3x3 convs with a skip connection.

    The channel count is unchanged, so the input can be added directly
    to the conv-path output without a projection.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Main path: conv -> BN -> ReLU -> conv -> BN
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        # Skip connection, then the final non-linearity
        return self.relu(out + x)
# Use in a network
model = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
ResidualBlock(64),
ResidualBlock(64),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(64, 10)
)
x = torch.randn(32, 3, 224, 224)
output = model(x) # Shape: (32, 10)
Normalization Layers
Why Normalization Matters
Normalization layers help stabilize training by normalizing activations. BatchNorm is most common for CNNs, LayerNorm for Transformers, and GroupNorm when batch sizes are small.
import torch.nn as nn

# BatchNorm variants are chosen by input dimensionality:
# BatchNorm2d for (N, C, H, W) image tensors
bn2d = nn.BatchNorm2d(num_features=64)
# BatchNorm1d for (N, C) or (N, C, L) data
bn1d = nn.BatchNorm1d(num_features=128)

# Typical CNN pattern: Conv -> BatchNorm -> ReLU
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=3),
    nn.BatchNorm2d(128),
    nn.ReLU()
)
Regularization
Dropout
import torch
import torch.nn as nn
import torch.nn.functional as F  # Fix: F.relu is used in forward() but F was never imported here

class NetWithDropout(nn.Module):
    """An MLP regularized with dropout between layers.

    Dropout is only active in training mode (model.train()); in eval
    mode (model.eval()) it becomes the identity.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.dropout1 = nn.Dropout(p=0.5) # Drop 50% of neurons
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(p=0.3) # Drop 30% of neurons
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x) # Applied only during training
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x
model = NetWithDropout()

# Training mode: dropout randomly zeroes activations
model.train()
x = torch.randn(32, 784)
output = model(x)

# Eval mode: dropout is disabled, the forward pass is deterministic
model.eval()
output = model(x)
Loss Functions
PyTorch provides many built-in loss functions:
- Classification
- Regression
- Custom Loss
import torch
import torch.nn as nn

# CrossEntropyLoss applies log-softmax internally, so the model should
# output raw logits, not probabilities.
criterion = nn.CrossEntropyLoss()

logits = torch.randn(32, 10) # 10 classes
# Targets are class indices, not one-hot vectors
targets = torch.randint(0, 10, (32,))
loss = criterion(logits, targets)
print(f"Loss: {loss.item():.4f}")

# Binary classification: BCEWithLogitsLoss fuses the sigmoid for
# numerical stability, so it also takes raw logits.
bce_loss = nn.BCEWithLogitsLoss()
logits = torch.randn(32, 1)
targets = torch.randint(0, 2, (32, 1)).float()
loss = bce_loss(logits, targets)
import torch
import torch.nn as nn

predictions = torch.randn(32, 1)
targets = torch.randn(32, 1)

# Mean Squared Error: penalizes large errors quadratically
mse_loss = nn.MSELoss()
loss = mse_loss(predictions, targets)

# Mean Absolute Error (L1 Loss): more robust to outliers
l1_loss = nn.L1Loss()
loss = l1_loss(predictions, targets)

# Smooth L1 (Huber) Loss: quadratic near zero, linear for large errors
smooth_l1 = nn.SmoothL1Loss()
loss = smooth_l1(predictions, targets)
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification.

    Down-weights easy examples (high predicted probability for the true
    class) by the factor (1 - pt) ** gamma, scaled by alpha.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Keep per-sample losses so each one can be re-weighted individually
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        ce_loss = self.ce(inputs, targets)
        # pt is the model's probability for the correct class
        pt = torch.exp(-ce_loss)
        weights = self.alpha * (1 - pt) ** self.gamma
        return (weights * ce_loss).mean()
# Use the custom loss exactly like a built-in one
criterion = FocalLoss()
logits = torch.randn(32, 10)
targets = torch.randint(0, 10, (32,))
loss = criterion(logits, targets)
Complete Training Example
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# Define model
class CNN(nn.Module):
    """A small CNN for 28x28 single-channel images (e.g. MNIST digits).

    Two conv/pool stages reduce 28x28 to 7x7 with 64 channels, then a
    fully connected head with dropout produces 10 class logits.
    """

    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 28x28 -> 14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)    # 14x14 -> 7x7
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        # Convolutional feature extractor, then the classifier head
        return self.fc_layers(self.conv_layers(x))
# Setup: pick a device, then build the model, loss, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Dummy data standing in for a real dataset
train_data = TensorDataset(
    torch.randn(1000, 1, 28, 28),
    torch.randint(0, 10, (1000,))
)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # enable dropout for training
    total_loss = 0
    # Iterate batches directly (the enumerate() index was unused)
    for data, target in train_loader:
        # Move to device
        data, target = data.to(device), target.to(device)

        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass: clear stale gradients, backprop, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Evaluation: disable dropout and gradient tracking
model.eval()
with torch.no_grad():
    # Evaluate on test data
    pass
Model Utilities
Saving and Loading Models
import torch

# Persist only the parameters (the recommended approach)
torch.save(model.state_dict(), 'model.pth')

# Restore: rebuild the architecture, then load the weights into it
model = CNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()  # switch to inference mode before using the loaded model
Model Information
import torch
import torch.nn as nn

model = CNN()

# Parameter counts (trainable = requires_grad is True)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Every registered parameter, with its shape
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

# Every submodule, including nested ones
for name, module in model.named_modules():
    print(f"{name}: {module.__class__.__name__}")
Best Practices
- Always call model.train() and model.eval(): This affects layers like Dropout and BatchNorm
- Use nn.Sequential for simple chains: For complex architectures, define custom Module classes
- Initialize weights properly: Use torch.nn.init for custom initialization
- Move models to GPU early: Call .to(device) right after creating the model
- Use appropriate loss functions: CrossEntropyLoss includes softmax, BCEWithLogitsLoss includes sigmoid
Next Steps
Data Loading
Learn efficient data loading with DataLoader
Automatic Differentiation
Understand PyTorch’s autograd system