torch.export IR Specification
Export IR is an intermediate representation (IR) for compilers, which bears similarities to MLIR and TorchScript. It is specifically designed to express the semantics of PyTorch programs. Export IR primarily represents computation in a streamlined list of operations, with limited support for dynamism such as control flow.
To create an Export IR graph, a frontend can be used that soundly captures a PyTorch program via a trace-specializing mechanism. The resulting Export IR can then be optimized and executed by a backend. This can be done today through torch.export.export().
The key concepts that will be covered in this document include:
ExportedProgram: the data structure containing the Export IR program
Graph: consists of a list of nodes.
Nodes: represent operations, control flow, and the metadata stored on each node.
Values are produced and consumed by nodes.
Types are associated with values and nodes.
The size and memory layout of values are also defined.
Assumptions
This doc assumes that the audience is sufficiently familiar with PyTorch, specifically with torch.fx and its related tooling. Thus it will not repeat content already covered in the torch.fx documentation and paper.
What is Export IR
Export IR is a graph-based intermediate representation (IR) of PyTorch programs. Export IR is realized on top of torch.fx.Graph. In other words, all Export IR graphs are also valid FX graphs, and if interpreted using standard FX semantics, Export IR can be interpreted soundly. One implication is that an exported graph can be converted to a valid Python program via standard FX codegen.
This documentation will primarily focus on highlighting areas where Export IR differs from FX in terms of its strictness, while skipping parts where it shares similarities with FX.
ExportedProgram
The top-level Export IR construct is a torch.export.ExportedProgram class. It bundles the computational graph of a PyTorch model (which is usually a torch.nn.Module) with the parameters or weights that this model consumes.
Some notable attributes of the torch.export.ExportedProgram class are:
graph_module (torch.fx.GraphModule): Data structure containing the flattened computational graph of the PyTorch model. The graph can be directly accessed through ExportedProgram.graph.
graph_signature (torch.export.ExportGraphSignature): The graph signature, which specifies the parameter and buffer names used and mutated within the graph. Instead of storing parameters and buffers as attributes of the graph, they are lifted as inputs to the graph. The graph_signature is utilized to keep track of additional information on these parameters and buffers.
state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): Data structure containing the parameters and buffers.
range_constraints (Dict[sympy.Symbol, RangeConstraint]): For programs that are exported with dynamic shapes or data-dependent behavior, the metadata on each node will contain symbolic shapes (which look like s0, i0). This attribute maps the symbolic shapes to their lower/upper ranges.
equality_constraints (List[Tuple[InputDim, InputDim]]): A list of pairs of (input node, dimension) that are constrained to have the same size.
Graph
An Export IR Graph is a PyTorch program represented in the form of a DAG (directed acyclic graph). Each node in this graph represents a particular computation or operation, and edges of this graph consist of references between nodes.
We can view a Graph as having this schema:
class Graph:
    nodes: List[Node]
In practice, Export IR’s graph is realized as the torch.fx.Graph Python class.
An Export IR graph contains the following nodes (nodes are described in more detail in the next section):
0 or more nodes of op type placeholder
0 or more nodes of op type get_attr
0 or more nodes of op type call_function
exactly 1 node of op type output
Corollary: The smallest valid Graph consists of a single node (the output node), i.e. nodes is never empty.
Definition:
The set of placeholder nodes of a Graph represents the inputs of the Graph of the GraphModule. The output node of a Graph represents the outputs of the Graph of the GraphModule.
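Example:
import torch
from torch import nn

class MyModule(nn.Module):
    def forward(self, x, y):
        return x + y

# torch.export.export takes the module and a tuple of example inputs
ep = torch.export.export(MyModule(), (torch.randn(3), torch.randn(3)))
print(ep.graph)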
The printed output is the textual representation of a Graph, with each line being a node.
Node
A Node represents a particular computation or operation and is represented in Python using the torch.fx.Node class. Edges between nodes are represented as direct references to other nodes via the args property of the Node class. Using the same FX machinery, we can represent the following operations that a computational graph typically needs, such as operator calls, placeholders (aka inputs), conditionals, and loops.
The Node has the following schema:
class Node:
    name: str  # name of node
    op_name: str  # type of operation (exposed as node.op on torch.fx.Node)

    # interpretation of the fields below depends on op_name
    target: Union[str, Callable]
    args: List[object]
    kwargs: Dict[str, object]
    meta: Dict[str, object]
FX Text Format
As in the example above, notice that each line has this format:
%<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …), kwargs = {"keyword": arg5})
This format captures everything present in the Node class, with the exception of meta, in a compact format.
Concretely:
<name> is the name of the node as it would appear in node.name.
<op_name> is the node.op field, which must be one of these: <call_function>, <placeholder>, <get_attr>, or <output>.
<target> is the target of the node as node.target. The meaning of this field depends on op_name.
arg1, … arg4… are what is listed in the node.args tuple. If a value in the list is a torch.fx.Node, then it will be especially indicated with a leading %.
For example, a call to the add operator would appear as:
%add = call_function[target = torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
Where %x and %y are two other Nodes that have names x and y. Worth noting that the string torch.ops.aten.add.Tensor represents the callable object that is actually stored in the target field, not merely its string name.
The final line of this text format is:
return [add]
which is a Node with op_name = output, indicating that we are returning this one element.
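To make the field mapping of the text format concrete, the sketch below splits one call_function line into its fields with a regular expression. parse_node_line and NODE_LINE are hypothetical helpers written for illustration. They are not part of torch.fx, they handle only lines that carry args and kwargs, and nested arguments such as lists are kept as raw text.

```python
import re

# Hypothetical pattern for lines of the form
#   %<name>[: [...]] = <op>[target=<target>](args = (...), kwargs = {...})
NODE_LINE = re.compile(
    r"^%(?P<name>\w+)(?:\s*:\s*\[[^\]]*\])?\s*=\s*"
    r"(?P<op>\w+)\[target\s*=\s*(?P<target>[^\]]+?)\s*\]"
    r"\(args = \((?P<args>.*)\), kwargs = \{(?P<kwargs>.*)\}\)$"
)


def parse_node_line(line: str) -> dict:
    match = NODE_LINE.match(line.strip())
    if match is None:
        raise ValueError(f"not a call-style FX node line: {line!r}")
    fields = match.groupdict()
    # Arguments that reference other nodes carry a leading '%'.
    fields["node_args"] = [
        a.strip()[1:] for a in fields["args"].split(",") if a.strip().startswith("%")
    ]
    return fields


line = "%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})"
print(parse_node_line(line))
```

The same pattern also accepts the spaced form used above, `target = torch.ops.aten.add.Tensor`, because whitespace around `=` is optional.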
call_function
A call_function node represents a call to an operator.
Definitions
Functional: We say a callable is “functional” if it satisfies all the following requirements:
Non-mutating: The operator does not mutate the value of its input (for tensors, this includes both metadata and data).
No side effects: The operator does not mutate states that are visible from outside, like changing values of module parameters.
Operator: a functional callable with a predefined schema. Examples of such operators include functional ATen operators.
Representation in FX
%name = call_function[target = operator](args = (%x, %y, …), kwargs = {})
Differences from vanilla FX call_function
In an FX graph, a call_function can refer to any callable; in Export IR, we restrict it to only a select subset of ATen operators, custom operators, and control flow operators.
In Export IR, constant arguments will be embedded within the graph.
In an FX graph, a get_attr node can represent reading any attribute stored in the graph module. However, in Export IR this is restricted to reading only submodules, as all parameters/buffers will be passed in as inputs to the graph module.
Metadata
Node.meta is a dict attached to every FX node. However, the FX spec does not specify what metadata can or will be there. Export IR provides a stronger contract, specifically all call_function nodes will guarantee having and only having the following metadata fields:
node.meta["stack_trace"] is a string containing the Python stack trace referencing the original Python source code. An example stack trace looks like:
    File "my_module.py", line 19, in forward
        return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
        return y + 1
node.meta["val"] describes the output of running the operation. It can be of type <symint>, <FakeTensor>, a List[Union[FakeTensor, SymInt]], or None.
node.meta["nn_module_stack"] describes the "stacktrace" of the torch.nn.Module from which the node came, if it was from a torch.nn.Module call. For example, if a node containing the addmm op was called from a torch.nn.Linear module inside of a torch.nn.Sequential module, the nn_module_stack would look something like:
    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
node.meta["source_fn_stack"] contains the torch function or the leaf torch.nn.Module class this node was called from before decomposition. For example, a node containing the addmm op from a torch.nn.Linear module call would contain torch.nn.Linear in its source_fn_stack, and a node containing the addmm op from a torch.nn.functional.linear call would contain torch.nn.functional.linear in its source_fn_stack.
placeholder
Placeholder represents an input to a graph. Its semantics are exactly the same as in FX. Placeholder nodes must be the first N nodes in the nodes list of a graph. N can be zero.
Representation in FX
%name = placeholder[target = name](args = ())
The target field is a string which is the name of the input.
args, if non-empty, should be of size 1, representing the default value of this input.
Metadata
Placeholder nodes also have meta['val'], like call_function nodes. The val field in this case represents the input shape/dtype that the graph is expected to receive for this input parameter.
output
An output call represents a return statement in a function; it thus terminates the current graph. There is one and only one output node, and it will always be the last node of the graph.
Representation in FX
output[](args = (%something, …))
This has the exact same semantics as in torch.fx. args represents the node(s) to be returned.
Metadata
The output node has the same metadata as call_function nodes.
get_attr
get_attr nodes represent reading a submodule from the encapsulating torch.fx.GraphModule. Unlike a vanilla FX graph from torch.fx.symbolic_trace(), in which get_attr nodes are used to read attributes such as parameters and buffers from the top-level torch.fx.GraphModule, parameters and buffers are passed in as inputs to the graph module and stored in the top-level torch.export.ExportedProgram.
Representation in FX
%name = get_attr[target = name](args = ())
Example
Consider the following model:
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

def f(x, y):
    return cond(y, true_fn, false_fn, [x])
Graph:
graph():
    %x_1 : [num_users=1] = placeholder[target=x_1]
    %y_1 : [num_users=1] = placeholder[target=y_1]
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
    return conditional
The line, %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0], reads the submodule true_graph_0, which contains the sin operator.
References
SymInt
A SymInt is an object that can either be a literal integer or a symbol that represents an integer (represented in Python by the sympy.Symbol class). When a SymInt is a symbol, it describes a variable of type integer that is unknown to the graph at compile time; that is, its value is only known at runtime.
FakeTensor
A FakeTensor is an object that contains the metadata of a tensor. It can be viewed as having the following metadata.
class FakeTensor:
    size: List[SymInt]
    dtype: torch.dtype
    device: torch.device
    dim_order: List[int]  # This doesn't exist yet
The size field of FakeTensor is a list of integers or SymInts. If SymInts are present, this means this tensor has a dynamic shape. If integers are present, it is assumed that the tensor will have that exact static shape. The rank of a FakeTensor is never dynamic. The dtype field represents the dtype of the output of that node. There are no implicit type promotions in Edge IR. There are no strides in FakeTensor.
In other words:
If the operator in node.target returns a Tensor, then node.meta['val'] is a FakeTensor describing that tensor.
If the operator in node.target returns an n-tuple of Tensors, then node.meta['val'] is an n-tuple of FakeTensors describing each tensor.
If the operator in node.target returns an int/float/scalar that is known at compile time, then node.meta['val'] is None.
If the operator in node.target returns an int/float/scalar that is not known at compile time, then node.meta['val'] is of type SymInt.
For example:
aten::add returns a Tensor; so its spec will be a FakeTensor with the dtype and size of the tensor returned by this operator.
aten::sym_size returns an integer; so its val will be a SymInt because its value is only available at runtime.
max_pool2d_with_indices returns a tuple of (Tensor, Tensor); so the spec will also be a 2-tuple of FakeTensor objects, where the first FakeTensor describes the first element of the return value, etc.
Python code:
import torch

def add_one(x):
    return torch.ops.aten.add.Tensor(x, 1)
Graph:
graph():
    %ph_0 : [#users=1] = placeholder[target=ph_0]
    %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
    return [add_tensor]
FakeTensor:
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
Pytree-able Types
We define a type as "Pytree-able" if it is either a leaf type or a container type that contains other Pytree-able types.
Note:
The concept of pytree is the same as the one documented for JAX.
The following types are defined as leaf types:
| Type | Definition |
|---|---|
| Tensor | torch.Tensor |
| Scalar | Any numerical types from Python, including integral types, floating point types, and zero dimensional tensors. |
| int | Python int (bound as int64_t in C++) |
| float | Python float (bound as double in C++) |
| bool | Python bool |
| str | Python string |
| ScalarType | torch.dtype |
| Layout | torch.layout |
| MemoryFormat | torch.memory_format |
| Device | torch.device |
The following types are defined as container types:
| Type | Definition |
|---|---|
| Tuple | Python tuple |
| List | Python list |
| Dict | Python dict with Scalar keys |
| NamedTuple | Python namedtuple |
| Dataclass | Must be registered through register_dataclass |
| Custom class | Any custom class defined with _register_pytree_node |