Overview

The torch.nn.init module provides various methods for initializing neural network parameters. Proper initialization is crucial for:
  • Preventing vanishing/exploding gradients
  • Faster convergence during training
  • Better final model performance
All initialization functions modify tensors in-place and return the modified tensor.
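Because the return value is the same tensor object that was passed in, each call can be used statement-style or inline. A quick check of this in-place behavior:

```python
import torch
import torch.nn as nn

t = torch.empty(2, 3)
out = nn.init.zeros_(t)  # fills t in-place with zeros

# The returned tensor is the very same object, not a copy
assert out is t
assert t.sum().item() == 0.0
```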

Uniform Distributions

uniform_

torch.nn.init.uniform_(tensor, a=0.0, b=1.0, generator=None)
Fills the input Tensor with values drawn from the uniform distribution U(a, b).
tensor
Tensor
required
An n-dimensional torch.Tensor
a
float
default:"0.0"
Lower bound of the uniform distribution
b
float
default:"1.0"
Upper bound of the uniform distribution
generator
torch.Generator
default:"None"
PyTorch Generator to sample from
Example:
import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.uniform_(w, a=-0.1, b=0.1)

Normal Distributions

normal_

torch.nn.init.normal_(tensor, mean=0.0, std=1.0, generator=None)
Fills the input Tensor with values drawn from the normal distribution N(mean, std²).
tensor
Tensor
required
An n-dimensional torch.Tensor
mean
float
default:"0.0"
Mean of the normal distribution
std
float
default:"1.0"
Standard deviation of the normal distribution
generator
torch.Generator
default:"None"
PyTorch Generator to sample from
Example:
w = torch.empty(3, 5)
nn.init.normal_(w, mean=0.0, std=0.01)

trunc_normal_

torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=None)
Fills the input Tensor with values drawn from a truncated normal distribution.
a
float
default:"-2.0"
Minimum cutoff value
b
float
default:"2.0"
Maximum cutoff value
Values are drawn from N(mean, std²) with values outside [a, b] redrawn until they are within bounds. Example:
w = torch.empty(3, 5)
nn.init.trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)
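Since out-of-range draws are redrawn, every element is guaranteed to land in [a, b]. A quick sanity check:

```python
import torch
import torch.nn as nn

w = torch.empty(1000)
nn.init.trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)

# All values fall within the cutoff interval [a, b]
assert w.min().item() >= -0.04
assert w.max().item() <= 0.04
```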

Constant Initialization

constant_

torch.nn.init.constant_(tensor, val)
Fills the input Tensor with the value val.
tensor
Tensor
required
An n-dimensional torch.Tensor
val
float
required
Value to fill the tensor with
Example:
w = torch.empty(3, 5)
nn.init.constant_(w, 0.3)

ones_

torch.nn.init.ones_(tensor)
Fills the input Tensor with the scalar value 1. Example:
w = torch.empty(3, 5)
nn.init.ones_(w)

zeros_

torch.nn.init.zeros_(tensor)
Fills the input Tensor with the scalar value 0. Example:
w = torch.empty(3, 5)
nn.init.zeros_(w)

eye_

torch.nn.init.eye_(tensor)
Fills the 2-dimensional input Tensor with the identity matrix.
Tensor must be 2-dimensional.
Example:
w = torch.empty(3, 3)
nn.init.eye_(w)

Xavier Initialization

Xavier initialization (also known as Glorot initialization) helps maintain the variance of activations and gradients across layers.

xavier_uniform_

torch.nn.init.xavier_uniform_(tensor, gain=1.0, generator=None)
Fills the input Tensor with values according to the method described in “Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010).
tensor
Tensor
required
An n-dimensional torch.Tensor
gain
float
default:"1.0"
Optional scaling factor
Formula: Values are sampled from U(-a, a) where:
a = gain × sqrt(6 / (fan_in + fan_out))
Example:
linear = nn.Linear(20, 10)
nn.init.xavier_uniform_(linear.weight)
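For this layer the weight matrix has shape (10, 20), so fan_in = 20 and fan_out = 10, giving a bound of sqrt(6 / 30) ≈ 0.447. A quick check that the samples respect it:

```python
import math
import torch
import torch.nn as nn

linear = nn.Linear(20, 10)
nn.init.xavier_uniform_(linear.weight)

# a = gain * sqrt(6 / (fan_in + fan_out)) with gain = 1
bound = math.sqrt(6.0 / (20 + 10))
assert linear.weight.abs().max().item() <= bound + 1e-6
```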

xavier_normal_

torch.nn.init.xavier_normal_(tensor, gain=1.0, generator=None)
Fills the input Tensor with values according to Xavier initialization using a normal distribution. Formula: Values are sampled from N(0, std²) where:
std = gain × sqrt(2 / (fan_in + fan_out))
Example:
linear = nn.Linear(20, 10)
nn.init.xavier_normal_(linear.weight)

Kaiming Initialization

Kaiming initialization (also known as He initialization) is designed for layers with ReLU activations.

kaiming_uniform_

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)
Fills the input Tensor with values according to the method described in “Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015).
tensor
Tensor
required
An n-dimensional torch.Tensor
a
float
default:"0"
Negative slope of the rectifier used after this layer (only used with 'leaky_relu')
mode
str
default:"'fan_in'"
Either 'fan_in' (default) or 'fan_out'. 'fan_in' preserves magnitude of variance in forward pass. 'fan_out' preserves magnitudes in backward pass.
nonlinearity
str
default:"'leaky_relu'"
Name of non-linear function: 'relu', 'leaky_relu', 'tanh', etc.
Formula: Values are sampled from U(-bound, bound) where:
bound = gain × sqrt(3 / fan_mode)
Example:
conv = nn.Conv2d(3, 64, kernel_size=3)
nn.init.kaiming_uniform_(conv.weight, mode='fan_in', nonlinearity='relu')

kaiming_normal_

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)
Fills the input Tensor with values according to Kaiming initialization using a normal distribution. Formula: Values are sampled from N(0, std²) where:
std = gain / sqrt(fan_mode)
Example:
conv = nn.Conv2d(3, 64, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
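On a large enough tensor the empirical standard deviation should sit close to gain / sqrt(fan_mode). For a 2D tensor of shape (256, 512), fan_in is 512, so with nonlinearity='relu' (gain √2) the target std is sqrt(2/512) ≈ 0.0625. A rough empirical check:

```python
import math
import torch
import torch.nn as nn

w = torch.empty(256, 512)
nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')

expected_std = math.sqrt(2.0 / 512)  # gain / sqrt(fan_in)
# With ~131k samples the estimate should be within a few percent
assert abs(w.std().item() - expected_std) < 0.005
```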

Special Initializations

orthogonal_

torch.nn.init.orthogonal_(tensor, gain=1, generator=None)
Fills the input Tensor with a (semi) orthogonal matrix, as described in “Exact solutions to the nonlinear dynamics of learning in deep linear neural networks” - Saxe, A. et al. (2013).
gain
float
default:"1"
Optional scaling factor
Example:
w = torch.empty(3, 5)
nn.init.orthogonal_(w)
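For a 3×5 tensor, the rows come out orthonormal (W Wᵀ ≈ I). A quick check:

```python
import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.orthogonal_(w)

# Rows are orthonormal: W @ W.T is (close to) the 3x3 identity
assert torch.allclose(w @ w.T, torch.eye(3), atol=1e-5)
```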

sparse_

torch.nn.init.sparse_(tensor, sparsity, std=0.01, generator=None)
Fills the 2D input Tensor as a sparse matrix, where non-zero elements are drawn from a normal distribution.
tensor
Tensor
required
A 2-dimensional torch.Tensor
sparsity
float
required
Fraction of elements in each column to be set to zero
std
float
default:"0.01"
Standard deviation of the normal distribution for non-zero values
Example:
w = torch.empty(3, 5)
nn.init.sparse_(w, sparsity=0.1)
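The sparsity argument controls the fraction of each column that is zeroed; the rest is drawn from N(0, std²), so non-zero entries are (almost surely) never exactly zero. With sparsity=0.25 on a 100×10 tensor, roughly a quarter of the entries end up zero:

```python
import torch
import torch.nn as nn

w = torch.empty(100, 10)
nn.init.sparse_(w, sparsity=0.25, std=0.01)

# About 25% of the entries in each column are exactly zero
zero_frac = (w == 0).float().mean().item()
assert abs(zero_frac - 0.25) < 0.02
```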

dirac_

torch.nn.init.dirac_(tensor, groups=1)
Fills the {3, 4, 5}-dimensional input Tensor with the Dirac delta function.
groups
int
default:"1"
Number of groups for grouped convolutions
Preserves the identity of inputs in Convolutional layers, where possible. Example:
conv = nn.Conv2d(3, 3, 3)
nn.init.dirac_(conv.weight)
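With matching channel counts, appropriate padding, and no bias, a Dirac-initialized convolution acts as the identity map:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False)
nn.init.dirac_(conv.weight)

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    y = conv(x)

# Each output channel reproduces the corresponding input channel
assert torch.allclose(y, x, atol=1e-5)
```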

Calculate Gain

calculate_gain

torch.nn.init.calculate_gain(nonlinearity, param=None)
Returns the recommended gain value for the given nonlinearity function.
nonlinearity
str
required
Name of the non-linear function: 'linear', 'conv1d', 'conv2d', 'conv3d', 'sigmoid', 'tanh', 'relu', 'leaky_relu', 'selu'
param
float
default:"None"
Optional parameter for the non-linear function (e.g., negative_slope for leaky_relu)
Gain Values:
Nonlinearity         Gain
Linear / Identity    1
ConvND               1
Sigmoid              1
Tanh                 5/3
ReLU                 √2
Leaky ReLU           √(2 / (1 + negative_slope²))
SELU                 3/4
Example:
gain = nn.init.calculate_gain('relu')
print(gain)  # 1.4142135623730951 (√2)

gain = nn.init.calculate_gain('leaky_relu', 0.2)
print(gain)  # 1.3867504905630728

Usage Patterns

Custom Module Initialization

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.conv2 = nn.Conv2d(64, 128, 3)
        self.fc = nn.Linear(128, 10)
        
        # Initialize weights
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

Layer-specific Initialization

# For ReLU activation layers
conv = nn.Conv2d(3, 64, 3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

# For Tanh activation layers
linear = nn.Linear(100, 50)
gain = nn.init.calculate_gain('tanh')
nn.init.xavier_normal_(linear.weight, gain=gain)

# Bias initialization
nn.init.constant_(conv.bias, 0)

Conditional Initialization

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

model = MyModel()
model.apply(init_weights)

Best Practices

Weight initialization: use Kaiming initialization for ReLU-based networks and Xavier for tanh/sigmoid:
# For ReLU networks
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')

# For Tanh networks
nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('tanh'))

Bias initialization:
# Zero initialization (common)
nn.init.constant_(layer.bias, 0)

# Small positive value for ReLU to avoid dead neurons
nn.init.constant_(layer.bias, 0.01)

Mode selection:
# 'fan_out' preserves variance in the backward pass; preferred
# for conv layers in modern architectures
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
Different initialization strategies work better for different architectures:
  • ResNets: Kaiming initialization with fan_out
  • Transformers: Xavier uniform or normal
  • GANs: Normal with small std (0.02) or orthogonal
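As one concrete instance of the GAN convention above, a DCGAN-style initializer (normal with std 0.02 for convolutions) might look like the following sketch; the helper name and layer choices are illustrative:

```python
import torch
import torch.nn as nn

def init_gan(m):
    # DCGAN convention: conv weights ~ N(0, 0.02^2),
    # batch-norm scales ~ N(1, 0.02^2), shifts zeroed
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

gen = nn.Sequential(
    nn.ConvTranspose2d(100, 64, 4),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)
gen.apply(init_gan)
```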

Common Recipes

def init_resnet(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

model.apply(init_resnet)