torch.export

Warning

This feature is a prototype under active development and there WILL BE BREAKING CHANGES in the future.

Overview

torch.export.export() takes an arbitrary Python callable (a torch.nn.Module, a function or a method) and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized.

import torch
from torch.export import export

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    f, args=example_args
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
            # code: a = torch.sin(x)
            sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);

            # code: b = torch.cos(y)
            cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);

            # code: return a + b
            add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
            return (add,)

    Graph signature: ExportGraphSignature(
        parameters=[],
        buffers=[],
        user_inputs=['arg0_1', 'arg1_1'],
        user_outputs=['add'],
        inputs_to_parameters={},
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}
    Equality constraints: []

torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.

  • Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions of the original program.

  • Normalized: There are no Python semantics within the graph. Submodules from the original programs are inlined to form one fully flattened computational graph.

  • Defined Operator Set: The graph produced contains only a small defined Core ATen IR opset and registered custom operators.

  • Graph properties: The graph is purely functional, meaning it does not contain operations with side effects such as mutations or aliasing. It does not mutate any intermediate values, parameters, or buffers.

  • Metadata: The graph contains metadata captured during tracing, such as a stacktrace from user’s code.

Under the hood, torch.export leverages the following latest technologies:

  • TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a much improved graph capturing experience, with far fewer rewrites needed in order to fully trace the PyTorch code.

  • AOT Autograd provides a functionalized PyTorch graph and ensures the graph is decomposed/lowered to the small defined Core ATen operator set.

  • Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.

Existing frameworks

torch.compile() also utilizes the same PT2 stack as torch.export, but is slightly different:

  • JIT vs. AOT: torch.compile() is a JIT compiler, which is not intended to be used to produce compiled artifacts outside of deployment, whereas torch.export captures the graph ahead of time (AOT).

  • Partial vs. Full Graph Capture: When torch.compile() runs into an untraceable part of a model, it will “graph break” and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when something untraceable is reached. Since torch.export produces a full graph disjoint from any Python features or runtime, this graph can then be saved, loaded, and run in different environments and languages.

  • Usability tradeoff: Since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is a lot more flexible. torch.export will instead require users to provide more information or rewrite their code to make it traceable.

Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of tensor metadata, so that conditionals on things like tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs, and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.

Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures, but it supports more Python language features than TorchScript (as it is easier to have comprehensive coverage over Python bytecodes). The resulting graphs are simpler and only have straight line control flow (except for explicit control flow operators).

Compared to torch.jit.trace(), torch.export is sound: it is able to trace code that performs integer computation on sizes and records all of the side-conditions necessary to show that a particular trace is valid for other inputs.

Exporting a PyTorch Model

An Example

The main entrypoint is torch.export.export(), which takes a callable (torch.nn.Module, function, or method) and sample inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:

import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):

            # code: a = self.conv(x)
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
                arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
            );

            # code: a.add_(constant)
            add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);

            # code: return self.maxpool(self.relu(a))
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
                relu, [3, 3], [3, 3]
            );
            getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
            return (getitem,)

    Graph signature: ExportGraphSignature(
        parameters=['L__self___conv.weight', 'L__self___conv.bias'],
        buffers=[],
        user_inputs=['arg2_1', 'arg3_1'],
        user_outputs=['getitem'],
        inputs_to_parameters={
            'arg0_1': 'L__self___conv.weight',
            'arg1_1': 'L__self___conv.bias',
        },
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}
    Equality constraints: []

Inspecting the ExportedProgram, we can note the following:

  • The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.

  • The graph contains only torch.ops.aten operators found in the Core ATen IR opset and custom operators, and is fully functional, without any inplace operators such as torch.add_.

  • The parameters (weight and bias to conv) are lifted as inputs to the graph, resulting in no get_attr nodes in the graph, which previously existed in the result of torch.fx.symbolic_trace().

  • The torch.export.ExportGraphSignature models the input and output signature, along with specifying which inputs are parameters.

  • The resulting shape and dtype of tensors produced by each node in the graph is noted. For example, the convolution node will result in a tensor of dtype torch.float32 and shape (1, 16, 256, 256).

Expressing Dynamism

By default torch.export will trace the program assuming all input shapes are static, and specializing the exported program to those dimensions. However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be specified by using the torch.export.Dim() API to create them and by passing them into torch.export.export() through the dynamic_shapes argument. An example:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):

            # code: out1 = self.branch1(x1)
            permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
            addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
            relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);

            # code: out2 = self.branch2(x2)
            permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
            addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
            relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1);  addmm_1 = None

            # code: return (out1 + self.buffer, out2)
            add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
            return (add, relu_1)

    Graph signature: ExportGraphSignature(
        parameters=[
            'branch1.0.weight',
            'branch1.0.bias',
            'branch2.0.weight',
            'branch2.0.bias',
        ],
        buffers=['L__self___buffer'],
        user_inputs=['arg5_1', 'arg6_1'],
        user_outputs=['add', 'relu_1'],
        inputs_to_parameters={
            'arg0_1': 'branch1.0.weight',
            'arg1_1': 'branch1.0.bias',
            'arg2_1': 'branch2.0.weight',
            'arg3_1': 'branch2.0.bias',
        },
        inputs_to_buffers={'arg4_1': 'L__self___buffer'},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
    Equality constraints: [(InputDim(input_name='arg5_1', dim=0), InputDim(input_name='arg6_1', dim=0))]

Some additional things to note:

  • Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs arg5_1 and arg6_1, they have a symbolic shape of (s0, 64) and (s0, 128), instead of the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. s0 is a symbol representing that this dimension can be a range of values.

  • exported_program.range_constraints describes the ranges of each symbol appearing in the graph. In this case, we see that s0 has the range [2, inf]. For technical reasons that are difficult to explain here, they are assumed to be not 0 or 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.

  • exported_program.equality_constraints describes which dimensions are required to be equal. Since we used the same Dim (batch) for the first dimension of each argument in dynamic_shapes, those two dimensions are required to be equal, and we see in the equality constraints the tuple specifying that arg5_1 dimension 0 and arg6_1 dimension 0 are equal.

(A legacy mechanism for specifying dynamic shapes involves marking and constraining dynamic dimensions with the torch.export.dynamic_dim() API and passing them into torch.export.export() through the constraints argument. That mechanism is now deprecated and will not be supported in the future.)

Serialization

To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. A convention is to save the ExportedProgram using a .pt2 file extension.

An example:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), (torch.randn(5),))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

Specialization

Input shapes

As mentioned before, by default, torch.export will trace the program specializing on the input tensors’ shapes, unless a dimension is specified as dynamic via the torch.export.Dim() API and the dynamic_shapes argument (or the deprecated torch.export.dynamic_dim() API). This means that if there exists shape-dependent control flow, torch.export will specialize on the branch that is being taken with the given sample inputs. For example:

import torch
from torch.export import export

def fn(x):
    if x.shape[0] > 5:
        return x + 1
    else:
        return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(fn, example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 2]):
            add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            return (add,)

The conditional (x.shape[0] > 5) does not appear in the ExportedProgram because the example inputs have the static shape of (10, 2). Since torch.export specializes on the inputs’ static shapes, the else branch (x - 1) will never be reached. To preserve the dynamic branching behavior based on the shape of a tensor in the traced graph, the dimension of the input tensor (x.shape[0]) will need to be specified as dynamic (via torch.export.Dim() and the dynamic_shapes argument, or the deprecated torch.export.dynamic_dim()), and the source code will need to be rewritten.

Non-tensor inputs

torch.export also specializes the traced graph based on the values of inputs that are not torch.Tensor, such as int, float, bool, and str. However, we will likely change this in the near future to not specialize on inputs of primitive types.

For example:

import torch
from torch.export import export

def fn(x: torch.Tensor, const: int, times: int):
    for i in range(times):
        x = x + const
    return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(fn, example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
            add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
            add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
            return (add_2,)

Because integers are specialized, the torch.ops.aten.add.Tensor operations are all computed with the inlined constant 1, rather than arg1_1. Additionally, the times iterator used in the for loop is also “inlined” in the graph through the 3 repeated torch.ops.aten.add.Tensor calls, and the input arg2_1 is never used.

Limitations of torch.export

Graph Breaks

As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it might ultimately run into untraceable parts of programs as it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation will cause a “graph break” and the unsupported operation will be run with default Python evaluation. In contrast, torch.export will require users to provide additional information or rewrite parts of their code to make it traceable. As the tracing is based on TorchDynamo, which evaluates at the Python bytecode level, there will be significantly fewer rewrites required compared to previous tracing frameworks.

When a graph break is encountered, ExportDB is a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.

Data/Shape-Dependent Control Flow

Graph breaks can also be encountered on data-dependent control flow (if x.shape[0] > 2) when shapes are not being specialized, as a tracing compiler cannot possibly deal with it without generating code for a combinatorially exploding number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else like control flow (more coming soon!).

Missing Meta Kernels for Operators

When tracing, a META implementation (or “meta kernel”) is required for all operators. This is used to reason about the input/output shapes for this operator.

To register a meta kernel for a C++ Custom Operator, please refer to this documentation.

The official API for registering custom meta kernels for custom ops implemented in python is currently undergoing development. While the final API is being refined, you can refer to the documentation here.

In the unfortunate case where your model uses an ATen operator that does not have a meta kernel implementation yet, please file an issue.

API Reference

torch.export.export(f, args, kwargs=None, *, constraints=None, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source]

export() takes an arbitrary Python callable (an nn.Module, a function or a method) along with example inputs, and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized. The traced graph (1) produces normalized operators in the functional ATen operator set (as well as any user-specified custom operators), (2) has eliminated all Python control flow and data structures (with certain exceptions), and (3) records the set of shape constraints needed to show that this normalization and control-flow elimination is sound for future inputs.

Soundness Guarantee

While tracing, export() takes note of shape-related assumptions made by the user program and the underlying PyTorch operator kernels. The output ExportedProgram is considered valid only when these assumptions hold true.

Tracing makes assumptions on the shapes (not values) of input tensors. Such assumptions must be validated at graph capture time for export() to succeed. Specifically:

  • Assumptions on static shapes of input tensors are automatically validated without additional effort.

  • Assumptions on dynamic shape of input tensors require explicit specification by using the Dim() API to construct dynamic dimensions and by associating them with example inputs through the dynamic_shapes argument.

If any assumption cannot be validated, a fatal error will be raised. When that happens, the error message will include suggested fixes to the specification that are needed to validate the assumptions. For example, export() might suggest the following fix to the definition of a dynamic dimension dim0_x, say appearing in the shape associated with input x, that was previously defined as Dim("dim0_x"):

dim = Dim("dim0_x", max=5)

This example means the generated code requires dimension 0 of input x to be less than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension definitions and then copy them verbatim into your code without needing to change the dynamic_shapes argument to your export() call.

Parameters
  • f (Callable) – The callable to trace.

  • args (Tuple[Any, ...]) – Example positional inputs.

  • kwargs (Optional[Dict[str, Any]]) – Optional example keyword inputs.

  • constraints (Optional[List[Constraint]]) – [DEPRECATED: use dynamic_shapes instead, see below] An optional list of constraints on the dynamic arguments that specify their possible range of shapes. By default, shapes of input torch.Tensors are assumed to be static. If an input torch.Tensor is expected to have dynamic shapes, please use dynamic_dim() to define Constraint objects that specify the dynamics and the possible range of shapes. See dynamic_dim() docstring for examples on how to use it.

  • dynamic_shapes (Optional[Union[Dict[str, Any], Tuple[Any]]]) –

    Should either be: 1) a dict from argument names of f to their dynamic shape specifications, 2) a tuple that specifies dynamic shape specifications for each input in original order. If you are specifying dynamism on keyword args, you will need to pass them in the order that is defined in the original function signature.

    The dynamic shape of a tensor argument can be specified as either (1) a dict from dynamic dimension indices to Dim() types, where it is not required to include static dimension indices in this dict, but when they are, they should be mapped to None; or (2) a tuple / list of Dim() types or None, where the Dim() types correspond to dynamic dimensions, and static dimensions are denoted by None. Arguments that are dicts or tuples / lists of tensors are recursively specified by using mappings or sequences of contained specifications.

  • strict (bool) – When enabled (default), the export function will trace the program through TorchDynamo, which will ensure the soundness of the resulting graph. Otherwise, the exported program will not validate the implicit assumptions baked into the graph, and its behavior may diverge from that of the original model. This is useful when users need to work around bugs in the tracer, or simply want to incrementally enable safety in their models. Note that this option does not change the resulting IR spec, and the model will be serialized in the same way regardless of what value is passed here. WARNING: This option is experimental; use it at your own risk.

Returns

An ExportedProgram containing the traced callable.

Return type

ExportedProgram

Acceptable input/output types

Acceptable types of inputs (for args and kwargs) and outputs include:

  • Primitive types, i.e. torch.Tensor, int, float, bool and str.

  • Dataclasses, but they must be registered by calling register_dataclass() first.

  • (Nested) Data structures consisting of dict, list, tuple, namedtuple and OrderedDict containing all of the above types.

torch.export.dynamic_dim(t, index)[source]

Warning

(This feature is DEPRECATED. See Dim() instead.)

dynamic_dim() constructs a Constraint object that describes the dynamism of a dimension index of tensor t. Constraint objects should be passed to constraints argument of export().

Parameters
  • t (torch.Tensor) – Example input tensor that has dynamic dimension size(s)

  • index (int) – Index of dynamic dimension

Returns

A Constraint object that describes shape dynamism. It can be passed to export() so that export() does not assume static size of specified tensor, i.e. keeping it dynamic as a symbolic size rather than specializing according to size of example tracing input.

Specifically, dynamic_dim() can be used to express the following types of dynamism.

  • Size of a dimension is dynamic and unbounded:

    t0 = torch.rand(2, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size rather than always being static size 2
    constraints = [dynamic_dim(t0, 0)]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • Size of a dimension is dynamic with a lower bound:

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
    # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
    constraints = [
        dynamic_dim(t0, 0) >= 5,
        dynamic_dim(t1, 1) > 2,
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • Size of a dimension is dynamic with an upper bound:

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size with an upper bound of 16 (inclusive)
    # Second dimension of t1 can be dynamic size with an upper bound of 8 (exclusive)
    constraints = [
        dynamic_dim(t0, 0) <= 16,
        dynamic_dim(t1, 1) < 8,
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • Size of a dimension is dynamic and it is always equal to size of another dynamic dimension:

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # Sizes of second dimension of t0 and first dimension of t1 are always equal
    constraints = [
        dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • Mix and match all types above as long as they do not express conflicting requirements

torch.export.save(ep, f, *, extra_files=None, opset_version=None)[source]

Warning

Under active development, saved files may not be usable in newer versions of PyTorch.

Saves an ExportedProgram to a file-like object. It can then be loaded using the Python API torch.export.load.

Parameters
  • ep (ExportedProgram) – The exported program to save.

  • f (Union[str, pathlib.Path, io.BytesIO]) – A file-like object (has to implement write and flush) or a string containing a file name.

  • extra_files (Optional[Dict[str, Any]]) – Map from filename to contents which will be stored as part of f.

  • opset_version (Optional[Dict[str, int]]) – A map of opset names to the version of this opset

Example:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, 'exported_program.pt2')

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)

torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]

Warning

Under active development, saved files may not be usable in newer versions of PyTorch.

Loads an ExportedProgram previously saved with torch.export.save.

Parameters
  • f (Union[str, pathlib.Path, io.BytesIO]) – A file-like object to read from, or a string containing a file name.

  • extra_files (Optional[Dict[str, Any]]) – The extra filenames given in this map would be loaded and their content would be stored in the provided map.

  • expected_opset_version (Optional[Dict[str, int]]) – A map of opset names to expected opset versions

Returns

An ExportedProgram object

Return type

ExportedProgram

Example:

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
print(ep(torch.randn(5)))

torch.export.register_dataclass(cls)[source]

Registers a dataclass as a valid input/output type for torch.export.export().

Parameters

cls (Type[Any]) – the dataclass type to register

Example:

from dataclasses import dataclass

import torch

@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int

# The output type must also be a dataclass to be registered
@dataclass
class OutputDataClass:
    res: torch.Tensor

torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)

def fn(o: InputDataClass) -> OutputDataClass:
    res = o.feature + o.bias
    return OutputDataClass(res=res)

ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)

torch.export.Dim(name, *, min=None, max=None)[source]

Dim() constructs a type analogous to a named symbolic integer with a range. It can be used to describe multiple possible values of a dynamic tensor dimension. Note that different dynamic dimensions of the same tensor, or of different tensors, can be described by the same type.

Parameters
  • name (str) – Human-readable name for debugging.

  • min (Optional[int]) – Minimum possible value of given symbol (inclusive)

  • max (Optional[int]) – Maximum possible value of given symbol (inclusive)

Returns

A type that can be used in dynamic shape specifications for tensors.

torch.export.dims(*names, min=None, max=None)[source]

Util to create multiple Dim() types.

class torch.export.Constraint(*args, **kwargs)[source]

Warning

Do not construct Constraint directly, use dynamic_dim() instead.

This represents constraints on input tensor dimensions, e.g., requiring them to be fully polymorphic or within some range.

class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, equality_constraints, module_call_graph, example_inputs=None, verifier=None, tensor_constants=None)[source]

Package of a program from export(). It contains a torch.fx.Graph that represents Tensor computation, a state_dict containing tensor values of all lifted parameters and buffers, and various metadata.

You can call an ExportedProgram like the original callable traced by export() with the same calling convention.

To perform transformations on the graph, use module() to access a torch.fx.GraphModule. You can then use FX transformations to rewrite the graph. Afterwards, you can simply use export() again to construct a correct ExportedProgram.

module(*, flat=True)[source]

Returns a self-contained GraphModule with all the parameters/buffers inlined.

Return type

Module

buffers()[source]

Returns an iterator over original module buffers.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[Tensor]

named_buffers()[source]

Returns an iterator over original module buffers, yielding both the name of the buffer as well as the buffer itself.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[Tuple[str, Tensor]]

parameters()[source]

Returns an iterator over original module’s parameters.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[Parameter]

named_parameters()[source]

Returns an iterator over original module parameters, yielding both the name of the parameter as well as the parameter itself.

Warning

This API is experimental and is NOT backward-compatible.

Return type

Iterator[Tuple[str, Parameter]]

class torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[source]

class torch.export.ExportGraphSignature(input_specs, output_specs)[source]

ExportGraphSignature models the input/output signature of Export Graph, which is a fx.Graph with stronger invariant guarantees.

Export Graph is functional and does not access “states” like parameters or buffers within the graph via getattr nodes. Instead, export() guarantees that parameters, buffers, and constant tensors are lifted out of the graph as inputs. Similarly, mutations to buffers are not included in the graph; instead, the updated values of mutated buffers are modeled as additional outputs of Export Graph.

The ordering of all inputs and outputs is:

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

For example, if the following module is exported:

import torch
from torch import nn

class CustomModule(nn.Module):
    def __init__(self):
        super().__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

Resulting Graph would be:

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

Resulting ExportGraphSignature would be:

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)

class torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec)[source]

class torch.export.ModuleCallEntry(fqn: str, signature: Union[torch.export.exported_program.ModuleCallSignature, NoneType] = None)[source]

class torch.export.graph_signature.InputKind(value)[source]

An enumeration.

class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument], target: Union[str, NoneType])[source]

class torch.export.graph_signature.OutputKind(value)[source]

An enumeration.

class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument], target: Union[str, NoneType])[source]

class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source]

ExportGraphSignature models the input/output signature of Export Graph, which is a fx.Graph with stronger invariant guarantees.

Export Graph is functional and does not access “states” like parameters or buffers within the graph via getattr nodes. Instead, export() guarantees that parameters, buffers, and constant tensors are lifted out of the graph as inputs. Similarly, mutations to buffers are not included in the graph; instead, the updated values of mutated buffers are modeled as additional outputs of Export Graph.

The ordering of all inputs and outputs is:

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

For example, if the following module is exported:

import torch
from torch import nn

class CustomModule(nn.Module):
    def __init__(self):
        super().__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

Resulting Graph would be:

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

Resulting ExportGraphSignature would be:

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)

replace_all_uses(old, new)[source]

Replace all uses of the old name with new name in the signature.
