Reduce model size and accelerate inference with quantization techniques
Quantization reduces the precision of model weights and activations from 32-bit floating point to lower bit-widths (typically 8-bit integers), shrinking model size roughly 4x and speeding up inference with minimal accuracy loss.
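As a rough illustration of the mapping behind int8 quantization, the sketch below quantizes a tensor with a scale and zero-point computed from its observed range. This is a standalone example using plain tensor ops, not the `torch.quantization` machinery; the tensor and ranges are illustrative.

```python
import torch

# Minimal sketch of affine (asymmetric) int8 quantization.
x = torch.randn(4, 4)            # FP32 tensor
qmin, qmax = -128, 127           # int8 range

scale = (x.max() - x.min()) / (qmax - qmin)
zero_point = torch.round(qmin - x.min() / scale)

x_q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
x_dq = (x_q.float() - zero_point) * scale   # dequantize: approximates the original x

print((x - x_dq).abs().max())    # quantization error stays small
```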
Dynamic quantization quantizes weights ahead of time and activations on the fly at runtime:
```python
import torch
from torch.quantization import quantize_dynamic

# Original model
model = MyModel()

# Dynamically quantize
quantized_model = quantize_dynamic(
    model,
    {torch.nn.Linear, torch.nn.LSTM},  # Layers to quantize
    dtype=torch.qint8
)

# Use quantized model
output = quantized_model(input)
```
Best for: NLP models (LSTM, Transformer), where activations vary significantly between inputs.
Pros: Easy to use, no calibration needed.
Cons: Activations are computed in FP32 and quantized on the fly at runtime.
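To verify the size reduction, you can compare the serialized models. A rough sketch; the two-layer example model and file names are assumptions for illustration, not from the original:

```python
import os
import torch
from torch.quantization import quantize_dynamic

# Hypothetical example model: a couple of large Linear layers
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
)
quantized_model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# Compare on-disk size of the two state dicts
torch.save(model.state_dict(), "fp32.pt")
torch.save(quantized_model.state_dict(), "int8.pt")
print("FP32:", os.path.getsize("fp32.pt") / 1e6, "MB")
print("INT8:", os.path.getsize("int8.pt") / 1e6, "MB")  # roughly 4x smaller
```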
Static quantization quantizes both weights and activations ahead of time, using calibration data to determine activation ranges:
1. Prepare Model for Quantization
```python
from torch.quantization import get_default_qconfig, prepare

# Set quantization config
model.qconfig = get_default_qconfig('x86')  # or 'qnnpack' for mobile

# Prepare model (insert observers)
prepared_model = prepare(model, inplace=False)
```
2. Calibrate with Representative Data
```python
# Run calibration data through the model
with torch.no_grad():
    for batch in calibration_data:
        prepared_model(batch)
```
3. Convert to Quantized Model
```python
from torch.quantization import convert

# Convert to quantized version
quantized_model = convert(prepared_model, inplace=False)

# Now use quantized model for inference
output = quantized_model(input)
```
Best for: CNN models (ResNet, MobileNet) for computer vision.
Pros: Fastest inference, both weights and activations quantized.
Cons: Requires calibration data.
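Note that eager-mode static quantization also expects the model to mark where tensors enter and leave the quantized region with QuantStub and DeQuantStub. A minimal end-to-end sketch of the three steps above; the toy TinyCNN model and random calibration data are assumptions for illustration:

```python
import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert

class TinyCNN(nn.Module):
    """Toy model for illustration; any eager-mode model needs the stubs."""
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()       # FP32 -> int8 at the model input
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()   # int8 -> FP32 at the model output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)

model = TinyCNN().eval()
model.qconfig = get_default_qconfig('x86')

prepared = prepare(model, inplace=False)
with torch.no_grad():                          # calibration pass
    for _ in range(8):
        prepared(torch.randn(1, 3, 32, 32))

quantized = convert(prepared, inplace=False)
print(quantized(torch.randn(1, 3, 32, 32)).shape)
```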
Quantization-aware training (QAT) simulates quantization during training so the model learns to compensate, yielding the best accuracy:
```python
from torch.quantization import get_default_qat_qconfig, prepare_qat, convert

# Start with pretrained model
model = PretrainedModel()

# Set QAT config
model.qconfig = get_default_qat_qconfig('x86')

# Prepare for QAT (inserts fake-quantization modules)
model_qat = prepare_qat(model.train())

# Train with fake quantization
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        output = model_qat(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

# Convert to fully quantized model
model_qat.eval()
quantized_model = convert(model_qat, inplace=False)
```
Best for: Models where post-training quantization causes significant accuracy loss.
Pros: Best accuracy retention.
Cons: Requires retraining, longer development time.
For finer control over how tensors are observed and quantized, define a custom QConfig:

```python
from torch.quantization import QConfig, default_observer, default_weight_observer

# Create custom QConfig
my_qconfig = QConfig(
    activation=default_observer,      # How to quantize activations
    weight=default_weight_observer    # How to quantize weights
)

model.qconfig = my_qconfig
```
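The observers themselves can be customized with with_args. A sketch using a moving-average activation observer and a per-channel weight observer; the specific settings and the placeholder model are illustrative choices, not from the original:

```python
import torch
from torch.quantization import QConfig
from torch.ao.quantization.observer import (
    MovingAverageMinMaxObserver,
    PerChannelMinMaxObserver,
)

# Illustrative observer settings, chosen as an example rather than a recommendation
my_qconfig = QConfig(
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine
    ),
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)

model = torch.nn.Sequential(torch.nn.Linear(16, 16))  # placeholder model for illustration
model.qconfig = my_qconfig
```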
For per-layer control in FX graph-mode quantization, use a QConfigMapping:

```python
from torch.ao.quantization import QConfigMapping

qconfig_mapping = QConfigMapping()

# Different configs for different layer types
# (default_qconfig, linear_qconfig, conv_qconfig, special_qconfig are QConfig objects you define)
qconfig_mapping.set_global(default_qconfig)
qconfig_mapping.set_object_type(torch.nn.Linear, linear_qconfig)
qconfig_mapping.set_object_type(torch.nn.Conv2d, conv_qconfig)

# Per-module config
qconfig_mapping.set_module_name("features.0", special_qconfig)
```
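The mapping is applied through the FX graph-mode prepare/convert functions. A minimal sketch, where the placeholder model, example input, and use of the ready-made default mapping are assumptions; in practice you would pass the custom qconfig_mapping built above:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Placeholder model and example input for illustration
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Ready-made mapping; swap in the custom qconfig_mapping from the previous snippet
qconfig_mapping = get_default_qconfig_mapping('x86')

prepared = prepare_fx(model, qconfig_mapping, example_inputs)
with torch.no_grad():                  # calibration pass
    prepared(*example_inputs)

quantized = convert_fx(prepared)
print(quantized(*example_inputs).shape)
```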