Overview
The torch.nn.init module provides various methods for initializing neural network parameters. Proper initialization is crucial for:
- Preventing vanishing/exploding gradients
- Faster convergence during training
- Better final model performance
All initialization functions modify the given tensor in-place and return the modified tensor.
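The in-place, return-the-same-tensor contract can be seen directly (a minimal sketch):

```python
import torch
import torch.nn as nn

w = torch.empty(2, 3)
out = nn.init.zeros_(w)  # modifies w in place and returns it

# The returned tensor is the very same object, now filled with zeros
print(out is w)          # True
```

By PyTorch convention, the trailing underscore in the function name marks an in-place operation.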
uniform_
torch.nn.init.uniform_(tensor, a=0.0, b=1.0, generator=None)
Fills the input Tensor with values drawn from the uniform distribution U(a, b).
Parameters:
- tensor – an n-dimensional torch.Tensor
- a (float, default 0.0) – lower bound of the uniform distribution
- b (float, default 1.0) – upper bound of the uniform distribution
- generator (torch.Generator, default None) – PyTorch Generator to sample from
Example:
import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.uniform_(w, a=-0.1, b=0.1)
Normal Distributions
normal_
torch.nn.init.normal_(tensor, mean=0.0, std=1.0, generator=None)
Fills the input Tensor with values drawn from the normal distribution N(mean, std²).
Parameters:
- tensor – an n-dimensional torch.Tensor
- mean (float, default 0.0) – mean of the normal distribution
- std (float, default 1.0) – standard deviation of the normal distribution
- generator (torch.Generator, default None) – PyTorch Generator to sample from
Example:
w = torch.empty(3, 5)
nn.init.normal_(w, mean=0.0, std=0.01)
trunc_normal_
torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=None)
Fills the input Tensor with values drawn from a truncated normal distribution. Values are drawn from N(mean, std²), with values outside [a, b] redrawn until they fall within the bounds.
Example:
w = torch.empty(3, 5)
nn.init.trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)
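The redraw behaviour can be sanity-checked directly (a sketch; the sampled values vary run to run, but every one lands inside [a, b]):

```python
import torch
import torch.nn as nn

w = torch.empty(1000)
nn.init.trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)

# No sample escapes the truncation bounds
print(w.min().item() >= -0.04 and w.max().item() <= 0.04)  # True
```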
Constant Initialization
constant_
torch.nn.init.constant_(tensor, val)
Fills the input Tensor with the value val.
Parameters:
- tensor – an n-dimensional torch.Tensor
- val (float) – the value to fill the tensor with
Example:
w = torch.empty(3, 5)
nn.init.constant_(w, 0.3)
ones_
torch.nn.init.ones_(tensor)
Fills the input Tensor with the scalar value 1.
Example:
w = torch.empty(3, 5)
nn.init.ones_(w)
zeros_
torch.nn.init.zeros_(tensor)
Fills the input Tensor with the scalar value 0.
Example:
w = torch.empty(3, 5)
nn.init.zeros_(w)
eye_
torch.nn.init.eye_(tensor)
Fills the 2-dimensional input Tensor with the identity matrix.
Tensor must be 2-dimensional.
Example:
w = torch.empty(3, 3)
nn.init.eye_(w)
Xavier Initialization
Xavier initialization (also known as Glorot initialization) helps maintain the variance of activations and gradients across layers.
xavier_uniform_
torch.nn.init.xavier_uniform_(tensor, gain=1.0, generator=None)
Fills the input Tensor with values according to the method described in “Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010).
Parameters:
- tensor – an n-dimensional torch.Tensor
- gain (float, default 1.0) – an optional scaling factor
- generator (torch.Generator, default None) – PyTorch Generator to sample from
Formula:
Values are sampled from U(-a, a) where:
a = gain × sqrt(6 / (fan_in + fan_out))
Example:
linear = nn.Linear(20, 10)
nn.init.xavier_uniform_(linear.weight)
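The bound can be checked by hand. For nn.Linear(20, 10) the weight has fan_in = 20 and fan_out = 10, so a = sqrt(6/30) ≈ 0.447 (a sketch; the weights are random but always fall within the bound):

```python
import math
import torch
import torch.nn as nn

linear = nn.Linear(20, 10)
nn.init.xavier_uniform_(linear.weight)

fan_in, fan_out = 20, 10
bound = math.sqrt(6.0 / (fan_in + fan_out))       # ≈ 0.447

# All weights lie inside [-a, a]
print(linear.weight.abs().max().item() <= bound)  # True
```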
xavier_normal_
torch.nn.init.xavier_normal_(tensor, gain=1.0, generator=None)
Fills the input Tensor with values according to Xavier initialization using a normal distribution.
Formula:
Values are sampled from N(0, std²) where:
std = gain × sqrt(2 / (fan_in + fan_out))
Example:
linear = nn.Linear(20, 10)
nn.init.xavier_normal_(linear.weight)
Kaiming Initialization
Kaiming initialization (also known as He initialization) is designed for layers with ReLU activations.
kaiming_uniform_
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)
Fills the input Tensor with values according to the method described in “Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015).
Parameters:
- tensor – an n-dimensional torch.Tensor
- a (float, default 0) – negative slope of the rectifier used after this layer (only used with 'leaky_relu')
- mode (str, default 'fan_in') – either 'fan_in' or 'fan_out'. 'fan_in' preserves the magnitude of the variance in the forward pass; 'fan_out' preserves the magnitudes in the backward pass.
- nonlinearity (str, default 'leaky_relu') – name of the non-linear function: 'relu', 'leaky_relu', 'tanh', etc.
Formula:
Values are sampled from U(-bound, bound) where:
bound = gain × sqrt(3 / fan_mode)
Example:
conv = nn.Conv2d(3, 64, kernel_size=3)
nn.init.kaiming_uniform_(conv.weight, mode='fan_in', nonlinearity='relu')
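For a convolutional weight, the fan values are computed from the kernel shape: fan_in = in_channels × kernel area and fan_out = out_channels × kernel area. A quick sketch for the layer above:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3)
out_ch, in_ch, kh, kw = conv.weight.shape  # (64, 3, 3, 3)

fan_in = in_ch * kh * kw    # 3 * 3 * 3 = 27
fan_out = out_ch * kh * kw  # 64 * 3 * 3 = 576
print(fan_in, fan_out)      # 27 576
```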
kaiming_normal_
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)
Fills the input Tensor with values according to Kaiming initialization using a normal distribution.
Formula:
Values are sampled from N(0, std²) where:
std = gain / sqrt(fan_mode)
Example:
conv = nn.Conv2d(3, 64, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
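The resulting standard deviation can be verified empirically (a sketch; the sample estimate fluctuates slightly): with nonlinearity='relu' the gain is √2, and with mode='fan_out' here fan_mode = 64 × 3 × 3 = 576, giving std ≈ 0.059.

```python
import math
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

fan_out = 64 * 3 * 3                 # out_channels * kernel area
expected = math.sqrt(2.0 / fan_out)  # gain / sqrt(fan_mode) ≈ 0.059

# Empirical std over the 1728 weights is close to the target
print(abs(conv.weight.std().item() - expected) < 0.2 * expected)  # True
```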
Special Initializations
orthogonal_
torch.nn.init.orthogonal_(tensor, gain=1, generator=None)
Fills the input Tensor with a (semi-)orthogonal matrix, as described in “Exact solutions to the nonlinear dynamics of learning in deep linear neural networks” - Saxe, A. et al. (2013). The input tensor must have at least 2 dimensions.
Example:
w = torch.empty(3, 5)
nn.init.orthogonal_(w)
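For a wide matrix like the 3×5 above, the rows come out orthonormal, which can be verified directly (a sketch):

```python
import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.orthogonal_(w)

# Rows are orthonormal, so w @ w.T is (numerically) the 3x3 identity
print(torch.allclose(w @ w.T, torch.eye(3), atol=1e-5))  # True
```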
sparse_
torch.nn.init.sparse_(tensor, sparsity, std=0.01, generator=None)
Fills the 2D input Tensor as a sparse matrix, where the non-zero elements are drawn from a normal distribution.
Parameters:
- tensor – a 2-dimensional torch.Tensor
- sparsity (float) – fraction of elements in each column to be set to zero
- std (float, default 0.01) – standard deviation of the normal distribution used for the non-zero values
Example:
w = torch.empty(3, 5)
nn.init.sparse_(w, sparsity=0.1)
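The number of zeroed entries per column is ceil(sparsity × rows); for a 3×5 tensor with sparsity=0.1 that is ceil(0.3) = 1 zero in each column (a sketch; the non-zero draws themselves are random):

```python
import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.sparse_(w, sparsity=0.1)

# ceil(0.1 * 3) = 1 entry per column is zeroed
zeros_per_col = (w == 0).sum(dim=0)
print(zeros_per_col.tolist())  # [1, 1, 1, 1, 1]
```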
dirac_
torch.nn.init.dirac_(tensor, groups=1)
Fills the {3, 4, 5}-dimensional input Tensor with the Dirac delta function. Preserves the identity of inputs in convolutional layers, where possible.
Parameters:
- tensor – a {3, 4, 5}-dimensional torch.Tensor
- groups (int, default 1) – number of groups for grouped convolutions
Example:
conv = nn.Conv2d(3, 3, 3)
nn.init.dirac_(conv.weight)
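With same-padding and no bias, a dirac-initialized convolution is an identity map, which makes the "preserves the identity of inputs" claim concrete (a sketch):

```python
import torch
import torch.nn as nn

# padding=1 keeps the spatial size; bias=False keeps the map pure
conv = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False)
nn.init.dirac_(conv.weight)

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    y = conv(x)
print(torch.allclose(y, x))  # True
```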
Calculate Gain
calculate_gain
torch.nn.init.calculate_gain(nonlinearity, param=None)
Returns the recommended gain value for the given nonlinearity function.
Parameters:
- nonlinearity (str) – name of the non-linear function: 'linear', 'conv1d', 'conv2d', 'conv3d', 'sigmoid', 'tanh', 'relu', 'leaky_relu', 'selu'
- param (optional) – optional parameter for the non-linear function (e.g., negative_slope for 'leaky_relu')
Gain Values:
Nonlinearity            Gain
Linear / Identity / Conv   1
Sigmoid                    1
Tanh                       5/3
ReLU                       √2
Leaky ReLU                 √(2 / (1 + negative_slope²))
SELU                       3/4
Example:
gain = nn.init.calculate_gain('relu')
print(gain)  # 1.4142135623730951 (√2)
gain = nn.init.calculate_gain('leaky_relu', 0.2)
print(gain)  # 1.3867504905630728
Usage Patterns
Custom Module Initialization
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.conv2 = nn.Conv2d(64, 128, 3)
        self.fc = nn.Linear(128, 10)
        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
Layer-specific Initialization
# For ReLU activation layers
conv = nn.Conv2d(3, 64, 3)
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

# For Tanh activation layers
linear = nn.Linear(100, 50)
gain = nn.init.calculate_gain('tanh')
nn.init.xavier_normal_(linear.weight, gain=gain)

# Bias initialization
nn.init.constant_(conv.bias, 0)
Conditional Initialization
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

model = MyModel()
model.apply(init_weights)
Best Practices
Match initialization to activation function
Use Kaiming initialization for ReLU-based networks and Xavier for tanh/sigmoid networks:
# For ReLU networks
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')

# For Tanh networks
nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('tanh'))
Initialize biases to zero or small constants
# Zero initialization (common)
nn.init.constant_(layer.bias, 0)

# Small positive value for ReLU to avoid dead neurons
nn.init.constant_(layer.bias, 0.01)
Use fan_out for convolutional layers
# Often preferred for conv layers in modern architectures
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
Consider layer type and architecture
Different initialization strategies work better for different architectures:
- ResNets: Kaiming initialization with fan_out
- Transformers: Xavier uniform or normal
- GANs: Normal with small std (0.02) or orthogonal
Common Recipes
ResNet-style
def init_resnet(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

model.apply(init_resnet)
Transformer-style
def init_transformer(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0, std=0.02)

model.apply(init_transformer)
GAN-style
def init_gan(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

generator.apply(init_gan)
discriminator.apply(init_gan)