Quantization reduces the precision of model weights and activations from 32-bit floating point to lower bit-widths (typically 8-bit integers), dramatically reducing model size and improving inference speed with minimal accuracy loss.

Why Quantization?

Benefits:
  • 4x smaller models: INT8 uses 1/4 the memory of FP32
  • 2-4x faster inference: Integer operations are faster than floating point
  • Lower power consumption: Critical for mobile and edge devices
  • Maintained accuracy: Modern techniques preserve 99%+ of original accuracy
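The memory and accuracy trade-off above rests on affine quantization: a float value x is mapped to an integer q via q = round(x / scale) + zero_point. A minimal sketch of this scheme using PyTorch's built-in per-tensor quantization (the scale of 0.1 is an arbitrary illustrative choice):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 4)

# Affine quantization: q = round(x / scale) + zero_point, stored as int8.
q = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)

print(q.int_repr())    # the underlying int8 values
print(q.dequantize())  # reconstruction of x, accurate to within scale/2
```

With scale 0.1, every reconstructed value lies within 0.05 of the original, which is why well-chosen quantization parameters preserve accuracy so effectively.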

Quantization Types

PyTorch supports three main quantization approaches:

Dynamic Quantization

Weights are quantized ahead of time; activations are quantized on the fly at runtime:
import torch
from torch.quantization import quantize_dynamic

# Original model
model = MyModel()

# Dynamically quantize
quantized_model = quantize_dynamic(
    model,
    {torch.nn.Linear, torch.nn.LSTM},  # Layers to quantize
    dtype=torch.qint8
)

# Use quantized model
output = quantized_model(input)
Best for: NLP models (LSTM, Transformer), where activations vary significantly between inputs.
Pros: Easy to use, no calibration needed
Cons: Activations are computed in FP32, then quantized at runtime
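As a quick sanity check, dynamic quantization on a hypothetical toy model (not from the text above) swaps only the listed layer types for their dynamically quantized counterparts:

```python
import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

# Toy model: only the nn.Linear layers in the set are replaced by
# dynamically quantized versions; nn.ReLU is left untouched.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

print(qmodel[0]._get_name())
```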

Static Quantization (Post-Training)

Quantize both weights and activations using calibration data:

Step 1: Prepare Model for Quantization

from torch.quantization import get_default_qconfig, prepare

# Set quantization config
model.qconfig = get_default_qconfig('x86')  # or 'qnnpack' for mobile

# Prepare model (insert observers)
prepared_model = prepare(model, inplace=False)

Step 2: Calibrate with Representative Data

# Run calibration data through model
with torch.no_grad():
    for batch in calibration_data:
        prepared_model(batch)

Step 3: Convert to Quantized Model

from torch.quantization import convert

# Convert to quantized version
quantized_model = convert(prepared_model, inplace=False)

# Now use quantized model for inference
output = quantized_model(input)
Best for: CNN models (ResNet, MobileNet) for computer vision.
Pros: Fastest inference; both weights and activations are quantized
Cons: Requires calibration data
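One detail the steps above gloss over: eager-mode static quantization needs explicit markers telling PyTorch where tensors enter and leave the quantized region. A sketch of the usual placement, using a hypothetical model (the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

# QuantStub marks where floats are converted to int8; DeQuantStub marks
# where int8 results are converted back to float for the caller.
class QuantReadyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8
        self.conv = nn.Conv2d(3, 16, 3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)
```

Before conversion the stubs are no-ops, so the model still runs as ordinary FP32; they only take effect after prepare/convert.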

Quantization Aware Training (QAT)

Train with quantization in mind for best accuracy:
from torch.quantization import get_default_qat_qconfig, prepare_qat, convert

# Start with pretrained model
model = PretrainedModel()

# Set QAT config
model.qconfig = get_default_qat_qconfig('x86')

# Prepare for QAT
model_qat = prepare_qat(model.train())

# Train with fake quantization
for epoch in range(num_epochs):
    for batch, target in dataloader:
        optimizer.zero_grad()
        output = model_qat(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Convert to fully quantized model
model_qat.eval()
quantized_model = convert(model_qat, inplace=False)
Best for: Models where post-training quantization causes significant accuracy loss.
Pros: Best accuracy retention
Cons: Requires retraining, longer development time

Quantization Configuration

QConfig

Control quantization behavior with QConfig:
from torch.quantization import QConfig, default_observer, default_weight_observer

# Create custom QConfig
my_qconfig = QConfig(
    activation=default_observer,      # How to quantize activations
    weight=default_weight_observer    # How to quantize weights
)

model.qconfig = my_qconfig

Common Observers

from torch.quantization import MinMaxObserver

# Simple min/max tracking
qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8)
)
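Observers can also be exercised directly, which makes their role concrete: feeding a tensor through one records its range, from which the (scale, zero_point) pair is derived.

```python
import torch
from torch.quantization import MinMaxObserver

# The observer tracks the running min/max of everything it sees,
# then computes quantization parameters from that range.
obs = MinMaxObserver(dtype=torch.quint8)
obs(torch.tensor([-1.0, 0.0, 2.0]))
scale, zero_point = obs.calculate_qparams()
```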

Backend-Specific Quantization

Choose backend based on deployment target:

x86 CPU (FBGEMM)

import torch
torch.backends.quantized.engine = 'fbgemm'

model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
Use for: Intel x86 CPUs, servers

ARM CPU (QNNPACK)

import torch
torch.backends.quantized.engine = 'qnnpack'

model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
Use for: Mobile devices (iOS, Android), Raspberry Pi, edge devices

ONNX Runtime

from torch.quantization import quantize_dynamic

quantized_model = quantize_dynamic(model, dtype=torch.qint8)
torch.onnx.export(quantized_model, dummy_input, "model.onnx")
Use for: Cross-platform deployment

FX Graph Mode Quantization

Modern API using FX graph tracing (recommended):
import torch
from torch.ao.quantization import quantize_fx
from torchvision import models

# Example model
model = models.resnet18(pretrained=True).eval()

# Prepare QConfig mapping
qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping('qnnpack')

# Example inputs (prepare_fx expects a tuple of example inputs)
example_inputs = (torch.randn(1, 3, 224, 224),)

# Prepare and calibrate
prepared_model = quantize_fx.prepare_fx(
    model,
    qconfig_mapping,
    example_inputs
)

# Calibrate
with torch.no_grad():
    for batch in calibration_data:
        prepared_model(batch)

# Convert to quantized model
quantized_model = quantize_fx.convert_fx(prepared_model)
FX Graph Mode is the recommended approach for new projects. It provides better automation and support for modern architectures.

Layer-Specific Quantization

Quantize specific layers differently:
from torch.ao.quantization import QConfigMapping, get_default_qconfig

default_qconfig = get_default_qconfig('x86')
qconfig_mapping = QConfigMapping()

# Different configs for different layer types
# (linear_qconfig, conv_qconfig, special_qconfig are user-defined QConfigs)
qconfig_mapping.set_global(default_qconfig)
qconfig_mapping.set_object_type(torch.nn.Linear, linear_qconfig)
qconfig_mapping.set_object_type(torch.nn.Conv2d, conv_qconfig)

# Per-module config
qconfig_mapping.set_module_name("features.0", special_qconfig)

Fusing Operations

Fuse operations before quantization for better performance:
from torch.quantization import fuse_modules

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

model = MyModel()

# Fuse Conv+BN+ReLU into single operation
fused_model = fuse_modules(
    model,
    [['conv', 'bn', 'relu']],
    inplace=False
)
Common fusion patterns supported by fuse_modules:
  • Conv + BatchNorm + ReLU
  • Conv + BatchNorm
  • Conv + ReLU
  • Linear + ReLU
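Fusion folds BatchNorm statistics into the convolution weights, so it must be done in eval mode, and the fused model should produce numerically matching outputs. A self-contained check using the same Conv+BN+ReLU structure as the MyModel example above:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

# eval() freezes BN statistics so they can be folded into the conv.
model = MyModel().eval()
fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])

# Fused-away modules are replaced with nn.Identity; outputs still match.
x = torch.randn(1, 3, 32, 32)
same = torch.allclose(model(x), fused(x), atol=1e-5)
```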

Quantization Debugging

Compare Accuracy

def compare_models(float_model, quantized_model, test_data):
    float_model.eval()
    quantized_model.eval()
    
    with torch.no_grad():
        for input, target in test_data:
            # Float output
            float_output = float_model(input)
            
            # Quantized output  
            quant_output = quantized_model(input)
            
            # Compare
            diff = (float_output - quant_output).abs().mean()
            print(f"Mean difference: {diff.item():.6f}")

Visualize Quantization

from torch.quantization import get_observer_dict

# After calibration: get_observer_dict fills in a dict that is passed to it
observer_dict = {}
get_observer_dict(prepared_model, observer_dict)

for name, observer in observer_dict.items():
    print(f"{name}:")
    print(f"  Min: {observer.min_val}")
    print(f"  Max: {observer.max_val}")
    print(f"  Scale: {observer.calculate_qparams()[0]}")

Advanced Techniques

Mixed Precision Quantization

Quantize most layers while keeping sensitive layers at higher precision. PyTorch's quantized dtypes are torch.qint8, torch.quint8, and torch.qint32; there is no 16-bit quantized dtype, so sensitive layers are usually left in floating point by assigning them no qconfig:
from torch.ao.quantization import QConfigMapping, get_default_qconfig

# 8-bit for most layers
default_qconfig = get_default_qconfig('qnnpack')

qconfig_mapping = QConfigMapping()
qconfig_mapping.set_global(default_qconfig)

# Keep sensitive layers (e.g. the classifier head) in full precision
qconfig_mapping.set_module_name("classifier", None)

Knowledge Distillation + Quantization

Combine knowledge distillation with QAT:
# Teacher: Full precision model
teacher_model = PretrainedModel().eval()

# Student: QAT-prepared model (fake-quantized during training)
student_model = prepare_qat(model.train())

for batch, target in dataloader:
    optimizer.zero_grad()

    # Teacher predictions
    with torch.no_grad():
        teacher_output = teacher_model(batch)

    # Student predictions (fake-quantized)
    student_output = student_model(batch)

    # Distillation loss
    loss = distillation_loss(student_output, teacher_output, target)
    loss.backward()
    optimizer.step()
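The distillation_loss used above is left undefined in the loop; a minimal sketch of one common formulation, combining a temperature-softened KL term against the teacher with the standard cross-entropy against the labels (temperature and alpha are assumed hyperparameters, not values from the text):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, target,
                      temperature=4.0, alpha=0.5):
    # Soft targets: KL divergence between temperature-softened distributions.
    # kl_div expects log-probabilities for the input and probabilities for
    # the target; the T^2 factor keeps gradient magnitudes comparable.
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean',
    ) * (temperature ** 2)

    # Hard targets: ordinary cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, target)

    return alpha * soft + (1 - alpha) * hard
```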

Deployment

Save Quantized Model

# Save quantized model
torch.save(quantized_model.state_dict(), 'quantized_model.pth')

# Save entire model (includes quantization config)
torch.save(quantized_model, 'quantized_model_full.pth')

# Load quantized model (recent PyTorch defaults torch.load to
# weights_only=True, so full-model loading needs weights_only=False)
loaded_model = torch.load('quantized_model_full.pth', weights_only=False)
loaded_model.eval()

Export to Mobile

from torch.utils.mobile_optimizer import optimize_for_mobile

# Trace quantized model
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(quantized_model, example_input)

# Optimize for mobile
optimized_model = optimize_for_mobile(traced_model)

# Save for mobile
optimized_model._save_for_lite_interpreter('model_quantized.ptl')

Performance Tips

Quantization Checklist:
  • ✓ Fuse operations before quantizing (Conv+BN+ReLU)
  • ✓ Use calibration data representative of real inputs
  • ✓ Try per-channel quantization for better accuracy
  • ✓ Use QAT if post-training quantization loses >1% accuracy
  • ✓ Choose correct backend (fbgemm for x86, qnnpack for ARM)

Benchmark Quantized Models

import time

# Measure latency (inference only, so disable autograd)
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    start = time.time()
    for _ in range(100):
        _ = quantized_model(x)
    quant_time = time.time() - start

    start = time.time()
    for _ in range(100):
        _ = float_model(x)
    float_time = time.time() - start

print(f"Speedup: {float_time / quant_time:.2f}x")

# Measure model size
import os
torch.save(quantized_model.state_dict(), 'quant.pth')
torch.save(float_model.state_dict(), 'float.pth')

quant_size = os.path.getsize('quant.pth') / 1024 / 1024
float_size = os.path.getsize('float.pth') / 1024 / 1024

print(f"Model size reduction: {float_size / quant_size:.2f}x")

Troubleshooting

Accuracy Drop Issues

# 1. Try per-channel quantization
from torch.quantization import (
    QConfig, default_observer, default_per_channel_weight_observer
)

qconfig = QConfig(
    activation=default_observer,
    weight=default_per_channel_weight_observer
)

# 2. Use more calibration data
for batch in more_calibration_data:  # Use more batches
    prepared_model(batch)

# 3. Try histogram observer
from torch.quantization import HistogramObserver
qconfig = QConfig(
    activation=HistogramObserver.with_args(dtype=torch.quint8),
    weight=default_weight_observer
)

# 4. Switch to QAT
model_qat = prepare_qat(model.train())
# Retrain with fake quantization