Skip to main content

Neural Networks with torch.nn

The torch.nn module provides all the building blocks for creating neural networks in PyTorch. It includes layers, loss functions, and utilities for building complex architectures.

The nn.Module Class

All neural networks in PyTorch are subclasses of torch.nn.Module. This base class provides:
  • Parameter management
  • Layer registration
  • Easy model serialization
  • Hooks for debugging
  • GPU support
nn.Module automatically tracks all nn.Parameter and sub-modules defined in __init__, making it easy to access and optimize them.

Creating a Simple Network

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    """A minimal two-layer fully connected classifier.

    Architecture: Linear -> ReLU -> Linear. The output is raw logits
    (no softmax), suitable for nn.CrossEntropyLoss.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        # Zero-argument super() — the modern Python 3 form used by the
        # other examples on this page.
        super().__init__()

        # Define layers; nn.Module registers their parameters automatically
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Define forward pass: hidden layer with ReLU, then class logits
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create model
model = SimpleNet(input_size=784, hidden_size=128, num_classes=10)

# Forward pass
x = torch.randn(32, 784)  # Batch of 32 samples
output = model(x)  # Shape: (32, 10)

print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

Core Layer Types

Linear (Fully Connected) Layers

import torch
import torch.nn as nn

# An affine transformation y = xW^T + b mapping 20 features to 30.
layer = nn.Linear(20, 30)

# A batch of 128 feature vectors, one row per sample.
x = torch.randn(128, 20)
output = layer(x)
# output is (128, 30): the batch dimension is preserved,
# the feature dimension is mapped 20 -> 30.

print(f"Weight shape: {layer.weight.shape}")  # (30, 20)
print(f"Bias shape: {layer.bias.shape}")      # (30,)

Convolutional Layers

Convolutional layers are essential for computer vision tasks:
import torch
import torch.nn as nn

# A 2-D convolution over image-like tensors shaped
# (batch, in_channels, height, width): 64 filters, each 3x3,
# sliding over 3 input channels. stride=1 with padding=1 keeps
# the spatial size unchanged.
conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

# A batch of 32 RGB images at 224x224 resolution.
x = torch.randn(32, 3, 224, 224)
output = conv(x)
# Spatial dims survive (stride=1, padding=1): (32, 64, 224, 224)

print(f"Output shape: {output.shape}")

Pooling Layers

Pooling layers reduce spatial dimensions:
import torch
import torch.nn as nn

# Max pooling: keep the largest value in each 2x2 window.
max_pool = nn.MaxPool2d(2, stride=2)

# Average pooling: take the mean over each 2x2 window.
avg_pool = nn.AvgPool2d(2, stride=2)

# Adaptive pooling: you choose the OUTPUT size; PyTorch picks the window.
adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))  # Always outputs 7x7

# Halve the spatial resolution of a feature map.
x = torch.randn(32, 64, 224, 224)
output = max_pool(x)
print(f"After pooling: {output.shape}")  # (32, 64, 112, 112)

# Collapse any input resolution down to exactly 7x7.
output = adaptive_pool(x)
print(f"After adaptive pooling: {output.shape}")  # (32, 64, 7, 7)

Activation Functions

import torch
import torch.nn as nn
import torch.nn.functional as F

# ReLU: max(0, x)
relu = nn.ReLU()

# ReLU in-place (saves memory by overwriting the input tensor)
relu_inplace = nn.ReLU(inplace=True)

# The functional API applies the activation directly, without
# constructing a module object.
x = torch.randn(10)
output = F.relu(x)

Building Complex Architectures

Sequential Container

Chain layers together:
import torch
import torch.nn as nn

# Simple sequential model: layers are applied in the order given.
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 10)
)

x = torch.randn(32, 784)
output = model(x)  # Shape: (32, 10)

# Access layers by index, like a Python list
first_layer = model[0]
print(first_layer.weight.shape)  # (256, 784)

ModuleList and ModuleDict

import torch
import torch.nn as nn

class DynamicNet(nn.Module):
    """A stack of identically sized Linear layers, depth chosen at runtime.

    nn.ModuleList (unlike a plain Python list) registers each contained
    layer so its parameters are visible to the optimizer.
    """

    def __init__(self, num_layers):
        super().__init__()

        # num_layers hidden layers, each mapping 128 -> 128.
        self.layers = nn.ModuleList(
            nn.Linear(128, 128) for _ in range(num_layers)
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        # Apply every layer in order, with a ReLU after each one.
        for layer in self.layers:
            x = self.activation(layer(x))
        return x

model = DynamicNet(num_layers=5)
x = torch.randn(32, 128)
output = model(x)

ResNet-Style Residual Blocks

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Two conv+batchnorm stages with an identity skip connection.

    The 3x3 convolutions use padding=1, so input and output shapes
    match and the skip is a plain addition (no projection needed).
    """

    def __init__(self, channels):
        super().__init__()

        # Both convs preserve height/width (kernel 3, padding 1).
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Main path: conv -> bn -> relu -> conv -> bn
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Skip connection: add the untouched input, then activate.
        out = out + x
        return self.relu(out)

# Drop the residual blocks into a larger network alongside ordinary layers.
layers = [
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    ResidualBlock(64),
    ResidualBlock(64),
    nn.AdaptiveAvgPool2d((1, 1)),  # collapse spatial dims to 1x1
    nn.Flatten(),                  # (N, 64, 1, 1) -> (N, 64)
    nn.Linear(64, 10),
]
model = nn.Sequential(*layers)

x = torch.randn(32, 3, 224, 224)
output = model(x)  # Shape: (32, 10)

Normalization Layers

Normalization layers help stabilize training by normalizing activations. BatchNorm is most common for CNNs, LayerNorm for Transformers, and GroupNorm when batch sizes are small.
import torch.nn as nn

# BatchNorm2d normalizes each channel of (N, C, H, W) feature maps.
bn2d = nn.BatchNorm2d(64)

# BatchNorm1d handles (N, C) or (N, C, L) inputs from linear/sequence layers.
bn1d = nn.BatchNorm1d(128)

# Typical CNN pattern: Conv -> BatchNorm -> ReLU, repeated.
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=3),
    nn.BatchNorm2d(128),
    nn.ReLU()
)

Regularization

Dropout

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in forward()

class NetWithDropout(nn.Module):
    """MLP classifier with dropout after each hidden layer.

    Dropout randomly zeroes activations during training only; in eval
    mode it is a no-op, so switch modes with model.train()/model.eval().
    """

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(784, 256)
        self.dropout1 = nn.Dropout(p=0.5)  # Drop 50% of neurons
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(p=0.3)  # Drop 30% of neurons
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)  # Applied only during training
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)  # raw logits, no activation
        return x

model = NetWithDropout()

# Training mode (dropout active)
model.train()
x = torch.randn(32, 784)
output = model(x)

# Eval mode (dropout inactive)
model.eval()
output = model(x)

Loss Functions

PyTorch provides many built-in loss functions:
import torch
import torch.nn as nn

# Multi-class: CrossEntropyLoss fuses log-softmax and NLL,
# so the model should emit raw logits.
criterion = nn.CrossEntropyLoss()

# Raw scores for 32 samples over 10 classes
logits = torch.randn(32, 10)  # 10 classes
# Ground-truth labels as integer class indices, one per sample
targets = torch.randint(0, 10, (32,))

loss = criterion(logits, targets)
print(f"Loss: {loss.item():.4f}")

# Binary classification: BCEWithLogitsLoss folds the sigmoid into the loss.
bce_loss = nn.BCEWithLogitsLoss()
logits = torch.randn(32, 1)
targets = torch.randint(0, 2, (32, 1)).to(torch.float32)
loss = bce_loss(logits, targets)

Complete Training Example

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define model
class CNN(nn.Module):
    """Small convnet for 28x28 single-channel images (e.g. MNIST).

    Two conv/pool stages shrink 28x28 -> 14x14 -> 7x7, then an MLP
    head maps the flattened 64*7*7 features to 10 class logits.
    """

    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28 -> 14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14 -> 7
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        # Feature extraction followed by the classification head.
        return self.fc_layers(self.conv_layers(x))

# Setup: pick the GPU when available, then build model, loss, optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Dummy data standing in for a real dataset: 1000 fake 28x28 images
# with random labels in [0, 10).
train_data = TensorDataset(
    torch.randn(1000, 1, 28, 28),
    torch.randint(0, 10, (1000,))
)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # enable dropout / batchnorm training behavior
    total_loss = 0

    for data, target in train_loader:
        # Move the batch to the same device as the model
        data, target = data.to(device), target.to(device)

        # Forward pass
        loss = criterion(model(data), target)

        # Backward pass: clear stale grads, backprop, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Evaluation
model.eval()  # disable dropout for inference
with torch.no_grad():
    # Evaluate on test data
    pass

Model Utilities

Saving and Loading Models

import torch

# Persist only the learned weights (the state_dict), not the whole object —
# this is the recommended, version-robust way to save a model.
torch.save(model.state_dict(), 'model.pth')

# Restoring requires rebuilding the architecture first, then loading weights.
model = CNN()
model.load_state_dict(torch.load('model.pth'))
model.eval()  # switch to inference mode after loading

Model Information

import torch
import torch.nn as nn

model = CNN()

# Count parameters by summing element counts over every parameter tensor;
# the trainable subset excludes anything with requires_grad=False (frozen).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Named parameters use dotted paths like "conv_layers.0.weight".
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

# named_modules walks the whole module tree, including the root itself.
for name, module in model.named_modules():
    print(f"{name}: {module.__class__.__name__}")

Best Practices

  1. Always call model.train() and model.eval(): This affects layers like Dropout and BatchNorm
  2. Use nn.Sequential for simple chains: For complex architectures, define custom Module classes
  3. Initialize weights properly: Use torch.nn.init for custom initialization
  4. Move models to GPU early: Call .to(device) right after creating the model
  5. Use appropriate loss functions: CrossEntropyLoss includes softmax, BCEWithLogitsLoss includes sigmoid

Next Steps

Data Loading

Learn efficient data loading with DataLoader

Automatic Differentiation

Understand PyTorch’s autograd system