Quantization

The torch.quantization module provides tools to convert floating-point models to reduced-precision (typically INT8) versions for faster inference and smaller model size.

Quantization Modes

PyTorch supports three types of quantization:
  1. Dynamic Quantization - Weights are quantized ahead of time, activations are quantized dynamically at runtime
  2. Static Quantization - Weights and activations are quantized based on observed data distributions
  3. Quantization-Aware Training (QAT) - Quantization is simulated during training for better accuracy

Core Functions

quantize

torch.quantization.quantize(
    model,
    run_fn,
    run_args,
    mapping=None,
    inplace=False
)
Quantize a float model using static quantization.
  • model (nn.Module): Float model to be quantized.
  • run_fn (callable): Function to run the model on sample data for calibration.
  • run_args (tuple): Arguments for the calibration function.
  • inplace (bool, default: False): Whether to modify the model in-place.

Returns (nn.Module): The quantized model.
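As a minimal sketch of the one-shot static flow, assuming an eager-mode model wrapped with QuantStub/DeQuantStub and a CPU build with a quantized backend (fbgemm or qnnpack); the model and function names here are illustrative:

```python
import torch
import torch.nn as nn
import torch.quantization

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # marks where int8 begins
        self.fc = nn.Linear(4, 4)
        self.dequant = torch.quantization.DeQuantStub()  # marks where float resumes
    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def calibrate(model, data):
    # run_fn: feed representative samples so observers record value ranges
    model.eval()
    with torch.no_grad():
        for x in data:
            model(x)

model = TinyNet()
model.qconfig = torch.quantization.default_qconfig
calib_data = [torch.randn(2, 4) for _ in range(8)]

# quantize() internally runs prepare -> run_fn(model, *run_args) -> convert
qmodel = torch.quantization.quantize(model, calibrate, (calib_data,))
out = qmodel(torch.randn(2, 4))
```

Note that `run_args` is unpacked into `run_fn` after the model, which is why the calibration data is passed as a one-element tuple.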

quantize_dynamic

torch.quantization.quantize_dynamic(
    model,
    qconfig_spec=None,
    dtype=torch.qint8,
    mapping=None,
    inplace=False
)
Convert a float model to a dynamically quantized model.
  • model (nn.Module): Float model to be quantized.
  • qconfig_spec (set or dict, default: None): Either a set of module types or a dict mapping module types to QConfig.
  • dtype (torch.dtype, default: torch.qint8): Quantized data type for weights. Options: torch.qint8, torch.float16.
  • inplace (bool, default: False): Whether to modify the model in-place.

quantize_qat

torch.quantization.quantize_qat(
    model,
    run_fn,
    run_args,
    inplace=False
)
Perform quantization-aware training and output a quantized model.
  • model (nn.Module): Float model to be trained with quantization awareness.
  • run_fn (callable): Function to train the model.
  • run_args (tuple): Arguments for the training function.
  • inplace (bool, default: False): Whether to modify the model in-place.
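A hedged sketch of quantization-aware training with the one-shot API, assuming a CPU build with a quantized backend; the model, training loop, and data here are illustrative stand-ins:

```python
import torch
import torch.nn as nn
import torch.quantization

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(4, 2)
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def train_fn(model, data):
    # run_fn: a short training loop; fake-quant modules simulate int8
    # rounding error in the forward pass so weights adapt to it
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model.train()
    for x, y in data:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

model = TinyNet()
model.qconfig = torch.quantization.default_qat_qconfig
data = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(3)]

# quantize_qat internally runs prepare_qat -> run_fn(model, *run_args) -> convert
qmodel = torch.quantization.quantize_qat(model, train_fn, (data,))
out = qmodel(torch.randn(8, 4))
```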

Preparation Functions

prepare

torch.quantization.prepare(
    model,
    inplace=False,
    allow_list=None,
    observer_non_leaf_module_list=None,
    prepare_custom_config_dict=None
)
Prepare a model for static quantization by inserting observers.
  • model (nn.Module): Float model to be prepared.
  • inplace (bool, default: False): Whether to modify the model in-place.

prepare_qat

torch.quantization.prepare_qat(
    model,
    inplace=False,
    allow_list=None,
    observer_non_leaf_module_list=None,
    prepare_custom_config_dict=None
)
Prepare a model for quantization-aware training by inserting fake quantization modules.
  • model (nn.Module): Float model to be prepared for QAT.
  • inplace (bool, default: False): Whether to modify the model in-place.

convert

torch.quantization.convert(
    module,
    mapping=None,
    inplace=False,
    remove_qconfig=True,
    convert_custom_config_dict=None
)
Convert a prepared/calibrated model to a quantized model.
  • module (nn.Module): Prepared and calibrated model to be converted.
  • inplace (bool, default: False): Whether to modify the model in-place.
  • remove_qconfig (bool, default: True): Whether to remove qconfig attributes after conversion.
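The preparation functions compose into the usual eager-mode static workflow: prepare, calibrate, convert. A minimal sketch, assuming a CPU build with a quantized backend; the toy model is illustrative:

```python
import torch
import torch.nn as nn
import torch.quantization

# QuantStub/DeQuantStub mark the int8 region of the network
model = nn.Sequential(
    torch.quantization.QuantStub(),
    nn.Linear(8, 8),
    nn.ReLU(),
    torch.quantization.DeQuantStub(),
)
model.eval()
model.qconfig = torch.quantization.default_qconfig

# 1. prepare: insert observers on quantizable modules
prepared = torch.quantization.prepare(model)

# 2. calibrate: run representative data so observers see real ranges
with torch.no_grad():
    for _ in range(8):
        prepared(torch.randn(4, 8))

# 3. convert: swap observed modules for their quantized counterparts
qmodel = torch.quantization.convert(prepared)
out = qmodel(torch.randn(4, 8))
```

The DeQuantStub converts back to float at the output, so downstream code sees ordinary float tensors.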

QConfig

QConfig

torch.quantization.QConfig(
    activation,
    weight
)
Configuration for quantization, specifying how to observe activations and weights.
  • activation (Observer class or factory): Observer for activations, typically an observer class or a factory created with .with_args().
  • weight (Observer class or factory): Observer for weights, typically an observer class or a factory created with .with_args().

Pre-defined QConfigs

torch.quantization.default_qconfig
torch.quantization.default_dynamic_qconfig
torch.quantization.default_qat_qconfig
torch.quantization.float16_dynamic_qconfig
Pre-configured quantization configurations for common use cases.
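When the pre-defined QConfigs don't fit, a custom one can be built from observer classes via `.with_args()`. A sketch; the observer choices here are illustrative, not a recommendation:

```python
import torch
from torch.quantization import QConfig, MinMaxObserver

# activations: unsigned 8-bit, affine; weights: signed 8-bit, symmetric
my_qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
    ),
)

# the fields are factories; calling them instantiates fresh observers
act_obs = my_qconfig.activation()
wt_obs = my_qconfig.weight()
```

Assign the result to `model.qconfig` before calling prepare, exactly as with the pre-defined configurations.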

Observers

MinMaxObserver

torch.quantization.MinMaxObserver(
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False,
    quant_min=None,
    quant_max=None
)
Observer that records min and max values for quantization scale calculation.
  • dtype (torch.dtype, default: torch.quint8): Quantized data type.
  • qscheme (torch.qscheme, default: torch.per_tensor_affine): Quantization scheme to use.
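An observer can also be exercised standalone to see how it derives quantization parameters; a minimal sketch:

```python
import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# the forward pass only records running min/max; values pass through unchanged
for _ in range(5):
    obs(torch.randn(64))

# scale and zero_point map the observed float range onto [0, 255]
scale, zero_point = obs.calculate_qparams()
```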

HistogramObserver

torch.quantization.HistogramObserver(
    bins=2048,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False
)
Observer that records a histogram of values for optimal quantization parameters.
  • bins (int, default: 2048): Number of histogram bins.

Fake Quantization

FakeQuantize

torch.quantization.FakeQuantize(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=False
)
Simulates quantization during training for quantization-aware training.
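A small sketch of what fake quantization does to a tensor; the parameter values are illustrative:

```python
import torch
from torch.quantization import FakeQuantize, MovingAverageMinMaxObserver

fq = FakeQuantize(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
)

x = torch.randn(16, requires_grad=True)
y = fq(x)           # values snapped to the int8 grid, but dtype stays float
y.sum().backward()  # gradients pass through (straight-through estimator)
```

This is why QAT can use ordinary optimizers: the forward pass carries the quantization error while the backward pass remains differentiable.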

Quantization Stubs

QuantStub

torch.quantization.QuantStub(qconfig=None)
Stub module for quantizing inputs.

DeQuantStub

torch.quantization.DeQuantStub()
Stub module for dequantizing outputs.

Fusion

fuse_modules

torch.quantization.fuse_modules(
    model,
    modules_to_fuse,
    inplace=False,
    fuser_func=None
)
Fuses a list of modules into a single module.
  • model (nn.Module): Model containing the modules to fuse.
  • modules_to_fuse (list): List of module names to fuse, e.g. ['conv', 'bn', 'relu'], or a list of such lists to fuse several groups at once.
  • inplace (bool, default: False): Whether to modify the model in-place.
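A hedged sketch of fusing a Conv-BN-ReLU group; the module is illustrative, and note that conv-bn fusion requires the model to be in eval mode:

```python
import torch
import torch.nn as nn
from torch.quantization import fuse_modules

class ConvBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

model = ConvBlock().eval()  # eval mode so BN stats can be folded into conv
fused = fuse_modules(model, [['conv', 'bn', 'relu']])

# fusion is numerically equivalent: bn and relu become Identity,
# and conv holds the fused ConvReLU operation
x = torch.randn(1, 3, 16, 16)
```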

Example Usage

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 5)
    
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

# Create and quantize the model
model = SimpleModel()
quantized_model = quantize_dynamic(
    model,
    {nn.Linear},  # Only quantize Linear layers
    dtype=torch.qint8
)

# Use the quantized model
x = torch.randn(1, 10)
output = quantized_model(x)
print(f"Output shape: {output.shape}")

Best Practices

  • Dynamic Quantization: Use for models with large matrix multiplications (e.g., LSTMs, Transformers). Easy to apply, minimal accuracy loss.
  • Static Quantization: Best for CNNs and models where activation distributions are consistent. Requires calibration data.
  • QAT: Use when static quantization shows accuracy degradation. Provides best accuracy but requires retraining.
  • Fuse Conv-BN-ReLU sequences before quantization for better performance
  • Use per-channel quantization for weights when possible
  • Start with dynamic quantization, move to static/QAT only if needed
  • Test on target hardware as quantized operations may have different performance characteristics
  • Compare outputs between float and quantized models on sample data
  • Use torch.quantization.get_observer_dict() to inspect observer statistics
  • Gradually quantize layers to identify problematic operations
  • Consider using higher precision (e.g., fp16) for sensitive layers
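For the float-vs-quantized comparison recommended above, a quick sanity check might look like the following sketch; the model and error threshold are illustrative:

```python
import torch
import torch.nn as nn
import torch.quantization

float_model = nn.Sequential(
    nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)
).eval()

qmodel = torch.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(32, 16)
with torch.no_grad():
    # worst-case elementwise gap on sample data; a large value
    # flags a quantization-sensitive layer worth investigating
    max_err = (float_model(x) - qmodel(x)).abs().max().item()
```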