Skip to main content

Overview

The PyTorch C++ Frontend provides a high-level API for building and training neural networks, closely mirroring the Python torch.nn.Module interface. It offers:
  • Modular Architecture: Hierarchical composition of network layers
  • Automatic Parameter Management: Automatic registration and tracking of parameters
  • State Management: Training/evaluation mode switching
  • Serialization: Save and load model checkpoints
  • Built-in Layers: Extensive standard library of common layers

Creating Modules

Basic Module Structure

Modules are created by inheriting from torch::nn::Module:
#include <torch/torch.h>

struct Net : torch::nn::Module {
  // Build and register the three fully connected layers; registration lets
  // the parent module track their parameters for optimization, serialization,
  // and device movement.
  Net() {
    fc1 = register_module("fc1", torch::nn::Linear(784, 128));
    fc2 = register_module("fc2", torch::nn::Linear(128, 64));
    fc3 = register_module("fc3", torch::nn::Linear(64, 10));
  }

  // MLP forward pass: two ReLU-activated hidden layers followed by a linear
  // output layer producing raw class scores.
  torch::Tensor forward(torch::Tensor x) {
    auto hidden = torch::relu(fc1->forward(x));
    hidden = torch::relu(fc2->forward(hidden));
    return fc3->forward(hidden);
  }

  // Holders start empty (nullptr) and are filled in the constructor.
  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};
Always use register_module() to register submodules. This ensures they’re tracked for parameter iteration, serialization, and device movement.

Using Module Holders

The recommended pattern uses module holders for automatic pointer management:
struct NetImpl : torch::nn::Module {
  NetImpl() {
    // register_module() wires each child into this module's hierarchy.
    fc1 = register_module("fc1", torch::nn::Linear(784, 128));
    fc2 = register_module("fc2", torch::nn::Linear(128, 10));
  }

  // One ReLU-activated hidden layer, then a linear projection to 10 scores.
  torch::Tensor forward(torch::Tensor x) {
    auto hidden = torch::relu(fc1->forward(x));
    return fc2->forward(hidden);
  }

  torch::nn::Linear fc1{nullptr}, fc2{nullptr};
};

// Generates the `Net` holder type, a smart-pointer wrapper around NetImpl.
TORCH_MODULE(Net);

// Usage
auto net = Net();
auto output = net->forward(input);
The TORCH_MODULE macro creates a ModuleHolder that behaves like a smart pointer.

Registering Parameters

Manual Parameter Registration

struct CustomModule : torch::nn::Module {
  CustomModule() {
    // register_parameter() marks the tensor as a learnable parameter and
    // enables requires_grad by default, so the explicit
    // torch::requires_grad() option the original passed was redundant.
    weight = register_parameter("weight", torch::randn({10, 10}));
    bias = register_parameter("bias", torch::zeros({10}));
  }

  // Computes bias + x @ weight in a single fused call.
  torch::Tensor forward(torch::Tensor x) {
    return torch::addmm(bias, x, weight);
  }

  torch::Tensor weight, bias;
};

Registering Buffers (Non-trainable)

struct ModuleWithBuffer : torch::nn::Module {
  ModuleWithBuffer() {
    // Buffers are saved/loaded with the module but receive no gradients.
    // Keep the tensors returned by register_buffer() in the members:
    // the original discarded the return values, leaving running_mean and
    // running_var as undefined tensors.
    running_mean = register_buffer("running_mean", torch::zeros({10}));
    running_var = register_buffer("running_var", torch::ones({10}));
  }

  torch::Tensor running_mean, running_var;
};
Use register_buffer() for tensors that should be saved/loaded with the module but don’t require gradients (e.g., running statistics in batch normalization).

Built-in Layers

Linear Layers

// Fully connected (dense) layer: 128 input features -> 64 outputs
auto fc = torch::nn::Linear(128, 64);

// With options (here: no bias term)
auto fc_no_bias = torch::nn::Linear(
    torch::nn::LinearOptions(128, 64).bias(false)
);

// Usage
torch::Tensor x = torch::randn({32, 128});  // batch of 32 samples
torch::Tensor y = fc->forward(x);  // Shape: {32, 64}

Convolutional Layers

// 2D Convolution
auto conv = torch::nn::Conv2d(
    torch::nn::Conv2dOptions(3, 64, 3)  // in_channels, out_channels, kernel_size
        .stride(1)
        .padding(1)  // padding 1 with a 3x3 kernel preserves spatial size
);

torch::Tensor x = torch::randn({1, 3, 32, 32});  // {N, C, H, W}
torch::Tensor y = conv->forward(x);  // Shape: {1, 64, 32, 32}

// 1D Convolution (for sequences)
auto conv1d = torch::nn::Conv1d(16, 32, 3);

// 3D Convolution (for videos)
auto conv3d = torch::nn::Conv3d(3, 64, 3);

Pooling Layers

// Max pooling: 2x2 window, stride 2 halves each spatial dimension
auto maxpool = torch::nn::MaxPool2d(
    torch::nn::MaxPool2dOptions(2).stride(2)
);

// Average pooling (2x2 window)
auto avgpool = torch::nn::AvgPool2d(2);

// Adaptive pooling (output size independent of input size)
auto adaptive_pool = torch::nn::AdaptiveAvgPool2d(
    torch::nn::AdaptiveAvgPool2dOptions({7, 7})  // always yields 7x7 output
);

Normalization Layers

// Batch normalization (argument = number of feature channels)
auto bn = torch::nn::BatchNorm2d(64);

// Layer normalization (normalized_shape = trailing dimensions to normalize)
auto ln = torch::nn::LayerNorm(
    torch::nn::LayerNormOptions({128})
);

// Instance normalization
auto in = torch::nn::InstanceNorm2d(64);

// Group normalization
auto gn = torch::nn::GroupNorm(
    torch::nn::GroupNormOptions(8, 64)  // num_groups, num_channels
);

Activation Functions

// ReLU
auto relu = torch::nn::ReLU();

// LeakyReLU (small negative slope instead of zero for x < 0)
auto leaky_relu = torch::nn::LeakyReLU(
    torch::nn::LeakyReLUOptions().negative_slope(0.01)
);

// GELU
auto gelu = torch::nn::GELU();

// Sigmoid
auto sigmoid = torch::nn::Sigmoid();

// Tanh
auto tanh = torch::nn::Tanh();

// Softmax (dim selects the axis along which values sum to 1)
auto softmax = torch::nn::Softmax(
    torch::nn::SoftmaxOptions(/*dim=*/1)
);

Dropout

// Dropout (drops elements during training; identity in eval mode)
auto dropout = torch::nn::Dropout(
    torch::nn::DropoutOptions().p(0.5)  // drop probability
);

// Dropout2d (drops entire channels)
auto dropout2d = torch::nn::Dropout2d(0.2);

Recurrent Layers

// LSTM
auto lstm = torch::nn::LSTM(
    torch::nn::LSTMOptions(128, 256)  // input_size, hidden_size
        .num_layers(2)
        .batch_first(true)  // inputs shaped {batch, seq, feature}
        .dropout(0.2)       // applied between stacked layers (needs num_layers > 1)
);

// GRU
auto gru = torch::nn::GRU(
    torch::nn::GRUOptions(128, 256)
        .num_layers(2)
);

// RNN
auto rnn = torch::nn::RNN(128, 256);

Container Modules

Sequential

Chain modules in sequential order:
// Sequential runs each child module in declaration order.
torch::nn::Sequential net(
    torch::nn::Linear(784, 128),
    torch::nn::ReLU(),
    torch::nn::Dropout(0.2),
    torch::nn::Linear(128, 64),
    torch::nn::ReLU(),
    torch::nn::Linear(64, 10)
);

torch::Tensor x = torch::randn({32, 784});
torch::Tensor output = net->forward(x);

ModuleList

Store modules in a list:
struct NetImpl : torch::nn::Module {
  NetImpl() {
    layers = register_module("layers", torch::nn::ModuleList());

    layers->push_back(torch::nn::Linear(784, 256));
    layers->push_back(torch::nn::ReLU());
    layers->push_back(torch::nn::Linear(256, 128));
    layers->push_back(torch::nn::ReLU());
    layers->push_back(torch::nn::Linear(128, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    // ModuleList stores type-erased std::shared_ptr<Module>, and the C++
    // base Module has no virtual forward(); the original's
    // `layer->forward(x)` does not compile. Dispatch on the concrete type
    // with Module::as<T>() (returns nullptr on mismatch) instead.
    for (const auto& layer : *layers) {
      if (auto* linear = layer->as<torch::nn::Linear>()) {
        x = linear->forward(x);
      } else if (auto* relu = layer->as<torch::nn::ReLU>()) {
        x = relu->forward(x);
      }
    }
    return x;
  }

  torch::nn::ModuleList layers{nullptr};
};

TORCH_MODULE(Net);

ModuleDict

Store modules in a dictionary:
struct NetImpl : torch::nn::Module {
  NetImpl() {
    layers = register_module("layers", torch::nn::ModuleDict());

    layers->update({
      {"conv1", torch::nn::Conv2d(3, 64, 3)},
      {"conv2", torch::nn::Conv2d(64, 128, 3)},
      {"fc", torch::nn::Linear(128, 10)}
    });
  }

  torch::Tensor forward(torch::Tensor x) {
    // `layers` is a ModuleHolder (smart-pointer wrapper), so operator[]
    // must be applied to the contained ModuleDict via (*layers)[...];
    // the original indexed the holder directly, which does not compile.
    // operator[] yields a shared_ptr<Module>; as<T>() downcasts it.
    x = (*layers)["conv1"]->as<torch::nn::Conv2d>()->forward(x);
    x = (*layers)["conv2"]->as<torch::nn::Conv2d>()->forward(x);
    x = x.view({x.size(0), -1});  // flatten to {N, features} before the FC layer
    x = (*layers)["fc"]->as<torch::nn::Linear>()->forward(x);
    return x;
  }

  torch::nn::ModuleDict layers{nullptr};
};

TORCH_MODULE(Net);

Training and Evaluation Modes

Switching Modes

auto model = Net();

// Training mode (enables dropout, batch norm training)
model->train();
assert(model->is_training());

// Evaluation mode (disables dropout, uses running stats for BN)
model->eval();
assert(!model->is_training());

Example with Mode Switching

struct NetImpl : torch::nn::Module {
  NetImpl() {
    fc = register_module("fc", torch::nn::Linear(128, 10));
    dropout = register_module("dropout", torch::nn::Dropout(0.5));
    bn = register_module("bn", torch::nn::BatchNorm1d(128));
  }

  // Normalize, regularize, then project. Dropout is a no-op in eval mode;
  // BatchNorm switches between batch and running statistics with the mode.
  torch::Tensor forward(torch::Tensor x) {
    auto normalized = bn->forward(x);
    auto regularized = dropout->forward(normalized);  // Only active in training mode
    return fc->forward(regularized);
  }

  torch::nn::Linear fc{nullptr};
  torch::nn::Dropout dropout{nullptr};
  torch::nn::BatchNorm1d bn{nullptr};
};

TORCH_MODULE(Net);

Parameter Access and Manipulation

Iterating Over Parameters

auto model = Net();

// All parameters (collected recursively from registered submodules)
for (const auto& param : model->parameters()) {
  std::cout << "Parameter shape: " << param.sizes() << std::endl;
}

// Named parameters (keys are hierarchical names, e.g. "fc1.weight")
for (const auto& pair : model->named_parameters()) {
  std::cout << "Name: " << pair.key() 
            << ", Shape: " << pair.value().sizes() << std::endl;
}

Zero Gradients

// Zero all gradients
model->zero_grad();

// Zero specific parameter gradients
for (auto& param : model->parameters()) {
  if (param.grad().defined()) {  // grad is undefined before the first backward()
    param.mutable_grad().zero_();  // mutable_grad() allows in-place mutation
  }
}

Moving to Device

auto model = Net();

// Move to GPU
model->to(torch::kCUDA);

// Move to specific device
model->to(torch::Device(torch::kCUDA, 0));  // CUDA device index 0

// Move to CPU
model->to(torch::kCPU);

// Change dtype (parameters and buffers are converted in place)
model->to(torch::kFloat64);

Complete Training Example

#include <torch/torch.h>

struct NetImpl : torch::nn::Module {
  // Three-layer MLP classifier emitting log-probabilities over 10 classes.
  NetImpl() {
    fc1 = register_module("fc1", torch::nn::Linear(784, 128));
    fc2 = register_module("fc2", torch::nn::Linear(128, 64));
    fc3 = register_module("fc3", torch::nn::Linear(64, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    auto hidden = torch::relu(fc1->forward(x));
    hidden = torch::relu(fc2->forward(hidden));
    // log_softmax output pairs with torch::nll_loss during training.
    return torch::log_softmax(fc3->forward(hidden), /*dim=*/1);
  }

  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

TORCH_MODULE(Net);

int main() {
  // Create model
  auto model = Net();
  model->to(torch::kCUDA);

  // Create optimizer
  torch::optim::Adam optimizer(
      model->parameters(), 
      torch::optim::AdamOptions(0.001)
  );

  // Training loop
  model->train();
  for (int epoch = 0; epoch < 10; epoch++) {
    // Forward pass
    auto input = torch::randn({32, 784}).to(torch::kCUDA);
    auto target = torch::randint(0, 10, {32}).to(torch::kCUDA);
    auto output = model->forward(input);

    // Compute loss
    auto loss = torch::nll_loss(output, target);

    // Backward pass
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    if (epoch % 1 == 0) {
      std::cout << "Epoch " << epoch 
                << ", Loss: " << loss.item<float>() << std::endl;
    }
  }

  // Evaluation
  model->eval();
  {
    torch::NoGradGuard no_grad;
    auto test_input = torch::randn({10, 784}).to(torch::kCUDA);
    auto test_output = model->forward(test_input);
    std::cout << "Test output:\n" << test_output << std::endl;
  }

  return 0;
}

Serialization

Saving Models

auto model = Net();

// Save entire model (torch::save works with module holders like Net)
torch::save(model, "model.pt");

// Save only parameters, via the lower-level archive API
torch::serialize::OutputArchive output_archive;
model->save(output_archive);
output_archive.save_to("model_params.pt");

Loading Models

// Load entire model (the architecture must be constructed first)
auto model = Net();
torch::load(model, "model.pt");

// Load only parameters (separate snippet — reuses the name `model`)
auto model = Net();
torch::serialize::InputArchive input_archive;
input_archive.load_from("model_params.pt");
model->load(input_archive);
When loading models, ensure the model architecture definition matches the saved model exactly.

Advanced Patterns

Residual Connections

struct ResidualBlockImpl : torch::nn::Module {
  ResidualBlockImpl(int channels) {
    conv1 = register_module("conv1", 
        torch::nn::Conv2d(channels, channels, 3).padding(1));
    bn1 = register_module("bn1", 
        torch::nn::BatchNorm2d(channels));
    conv2 = register_module("conv2", 
        torch::nn::Conv2d(channels, channels, 3).padding(1));
    bn2 = register_module("bn2", 
        torch::nn::BatchNorm2d(channels));
  }

  torch::Tensor forward(torch::Tensor x) {
    auto residual = x;
    auto out = torch::relu(bn1->forward(conv1->forward(x)));
    out = bn2->forward(conv2->forward(out));
    out += residual;
    return torch::relu(out);
  }

  torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
  torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr};
};

TORCH_MODULE(ResidualBlock);

Custom Initialization

struct NetImpl : torch::nn::Module {
  NetImpl() {
    fc = register_module("fc", torch::nn::Linear(128, 64));
    
    // Custom weight initialization: Xavier/Glorot uniform weights, zero bias.
    // The in-place init functions (trailing underscore) overwrite the
    // layer's default initialization right after registration.
    torch::nn::init::xavier_uniform_(fc->weight);
    torch::nn::init::constant_(fc->bias, 0.0);
  }

  torch::Tensor forward(torch::Tensor x) {
    return fc->forward(x);
  }

  torch::nn::Linear fc{nullptr};
};

TORCH_MODULE(Net);

Best Practices

1. Always Register Submodules — Use register_module() for all nested modules to ensure proper parameter tracking.

2. Use the TORCH_MODULE Macro — Prefer the TORCH_MODULE pattern for cleaner, more maintainable code.

3. Switch Modes Appropriately — Call train() before training and eval() before evaluation.

4. Move Model and Data to the Same Device — Ensure the model and input tensors are on the same device (CPU/GPU).

5. Use NoGradGuard During Inference — Wrap inference code in torch::NoGradGuard or torch::InferenceMode for better performance.

Next Steps