
TorchScript JIT

The torch.jit module provides functionality for creating TorchScript code from PyTorch code. TorchScript is a way to create serializable and optimizable models from PyTorch code.
TorchScript is deprecated and in maintenance mode. Prefer torch.compile, introduced in PyTorch 2.0, for better performance and ease of use.

Core Functions

script

torch.jit.script(
    obj,
    optimize=None,
    _frames_up=0,
    _rcb=None,
    example_inputs=None
)
Scripts a function or nn.Module, compiling it to TorchScript.
obj
callable or nn.Module
The function or module to be scripted. Can be a function, method, or nn.Module.
example_inputs
list[tuple] or dict
default:"None"
Example inputs used to annotate the argument types of the function or module.
Returns
ScriptModule or ScriptFunction
The compiled TorchScript code
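A minimal sketch of scripting a standalone function (the function and tensor values below are illustrative, not part of the API):

```python
import torch

@torch.jit.script
def clamp_double(x: torch.Tensor, threshold: float) -> torch.Tensor:
    # Data-dependent control flow is preserved by scripting,
    # whereas tracing would bake in a single execution path.
    if x.sum() > threshold:
        return x * 2
    return x

print(clamp_double(torch.ones(3), 0.0))  # sum() is 3 > 0, so doubled
```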

trace

torch.jit.trace(
    func,
    example_inputs,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu
)
Trace a function and return an executable or ScriptModule that will be optimized using just-in-time compilation.
func
callable or nn.Module
A Python function or torch.nn.Module that will be run with example_inputs.
example_inputs
tuple or Tensor
A tuple of example inputs that will be passed to the function while tracing.
check_trace
bool
default:"True"
Check if the same inputs run through traced code produce the same outputs.
strict
bool
default:"True"
Run the tracer in strict mode, which enforces that the entire computation is traceable.
Returns
ScriptModule or ScriptFunction
The traced code
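A brief sketch of tracing a function (the function itself is illustrative):

```python
import torch

def scale(x):
    return x * 2.0

# Tracing runs the function with the example inputs and records the
# operations executed; the resulting graph is specialized to that path.
traced = torch.jit.trace(scale, torch.randn(3))
out = traced(torch.ones(3))
```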

Optimization Functions

freeze

torch.jit.freeze(
    mod,
    preserved_attrs=None,
    optimize_numerics=True
)
Clones a ScriptModule and attempts to inline the cloned module’s submodules, parameters, and attributes as constants.
mod
ScriptModule
The module to freeze.
preserved_attrs
list[str]
default:"None"
Attributes to preserve in the frozen module.
optimize_numerics
bool
default:"True"
Whether to run optimizations that assume floating point operations are associative.
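A small sketch (the module shape is arbitrary); note that freeze expects a ScriptModule in eval mode:

```python
import torch
import torch.nn as nn

scripted = torch.jit.script(nn.Linear(4, 2))
scripted.eval()  # freeze requires the module to be in eval mode
frozen = torch.jit.freeze(scripted)
out = frozen(torch.randn(1, 4))
```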

optimize_for_inference

torch.jit.optimize_for_inference(
    mod,
    other_methods=None
)
Performs a set of optimization passes to optimize a model for inference.
mod
ScriptModule
The module to optimize.
other_methods
list[str]
default:"None"
Other methods to optimize in addition to forward.
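A minimal sketch, again on an arbitrary module in eval mode:

```python
import torch
import torch.nn as nn

model = torch.jit.script(nn.Linear(4, 2)).eval()
# optimize_for_inference freezes the module (if not already frozen)
# and runs inference-specific passes such as operator fusion.
opt = torch.jit.optimize_for_inference(model)
out = opt(torch.randn(1, 4))
```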

Serialization

save

torch.jit.save(m, f, _extra_files=None)
Save a ScriptModule or ScriptFunction to a file.
m
ScriptModule or ScriptFunction
The module or function to save.
f
str or file-like object
A file-like object or a string containing a file name.

load

torch.jit.load(f, map_location=None, _extra_files=None)
Load a ScriptModule or ScriptFunction previously saved with torch.jit.save.
f
str or file-like object
A file-like object or a string containing a file name.
map_location
str, torch.device, or dict
default:"None"
A function, torch.device, string or a dict specifying how to remap storage locations.
Returns
ScriptModule or ScriptFunction
The loaded module or function
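A round-trip sketch using an in-memory buffer in place of a file path:

```python
import io

import torch
import torch.nn as nn

scripted = torch.jit.script(nn.Linear(3, 1))
buf = io.BytesIO()
torch.jit.save(scripted, buf)  # f may be a path string or file-like object
buf.seek(0)
loaded = torch.jit.load(buf)   # the loaded module has the same parameters
```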

Type Annotations

annotate

torch.jit.annotate(the_type, the_value)
Gives the TorchScript compiler a type hint for the_value. Most useful for annotating empty containers, whose element type cannot be inferred from the literal alone.
the_type
type
Python type that should be passed to TorchScript compiler as type hint.
the_value
Any
Value or expression to hint type for.
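A common use is typing an empty container (the function here is illustrative):

```python
from typing import List

import torch

@torch.jit.script
def collect(n: int) -> List[int]:
    # Without the annotation, [] would be inferred as List[Tensor]
    # and appending an int would fail to compile.
    xs = torch.jit.annotate(List[int], [])
    for i in range(n):
        xs.append(i)
    return xs
```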

isinstance

torch.jit.isinstance(obj, target_type)
Provides container type refinement in TorchScript. Unlike the builtin isinstance, it can check parameterized container types such as List[int] or Dict[str, Tensor].
obj
Any
Object to refine the type of.
target_type
type
Type to try to refine obj to.
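A short sketch of refining an Any-typed argument (the function is illustrative):

```python
from typing import Any, List

import torch

@torch.jit.script
def sum_ints(obj: Any) -> int:
    total = 0
    # The builtin isinstance cannot check parameterized container
    # types in TorchScript; torch.jit.isinstance can. Inside the
    # branch, obj is refined to List[int].
    if torch.jit.isinstance(obj, List[int]):
        for v in obj:
            total += v
    return total
```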

Decorators and Context Managers

ignore

@torch.jit.ignore
def fn(x):
    # This function will not be compiled
    return x.numpy()
Decorator that tells the compiler to ignore a function or method, leaving it as a call into the Python interpreter. Modules containing ignored functions cannot be exported with torch.jit.save.

unused

@torch.jit.unused
def fn(x):
    # This function can be unused in TorchScript
    return x + 1
Decorator indicating that a function or method should be ignored and replaced with code that raises an exception if called. This lets code that is never executed from TorchScript remain in a module.

export

class MyModule(nn.Module):
    @torch.jit.export
    def scale(self, x):
        return x * 2
Decorator that marks a method on an nn.Module as an entry point for compilation. By default only forward and the methods it calls are compiled; exported methods are also compiled and callable on the resulting ScriptModule, including from C++.

Async Execution

fork

torch.jit.fork(func, *args, **kwargs)
Creates an asynchronous task executing func and returns a reference to the value of the result of this execution.
func
callable
A Python function or TorchScript function to execute asynchronously.

wait

torch.jit.wait(future)
Forces completion of a torch.jit.Future[T] asynchronous task, returning the result of the task.
future
Future
The future to wait on.
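A small sketch; fork schedules the call asynchronously and returns a Future immediately, and wait blocks until it completes:

```python
import torch

def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b

# fork returns a torch.jit.Future; wait retrieves the result.
fut = torch.jit.fork(add, torch.ones(2), torch.ones(2))
result = torch.jit.wait(fut)
```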

Example Usage

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
    
    def forward(self, x):
        return self.linear(x)

# Script the module
module = MyModule()
scripted_module = torch.jit.script(module)

# Use the scripted module
x = torch.randn(1, 10)
output = scripted_module(x)
print(output)

# Save the scripted module
torch.jit.save(scripted_module, 'model.pt')

Best Practices

  • Use torch.jit.script when your code has control flow (if statements, loops)
  • Use torch.jit.trace for models with consistent execution paths
  • script analyzes Python code directly, while trace records operations during execution
  • Use torch.jit.freeze before deployment to inline constants
  • Apply optimize_for_inference for inference-only models
  • Consider using torch.compile (PyTorch 2.0+) as an alternative
  • Use check_trace=True when tracing to validate correctness
  • Add type annotations for better error messages
  • Use @torch.jit.ignore for code that doesn’t need to be compiled