# Source code for torch._dynamo
from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, register_backend
from .convert_frame import replay
from .eval_frame import (
assume_constant_result,
disable,
explain,
export,
optimize,
optimize_assert,
OptimizedModule,
reset_code,
run,
skip,
)
from .external_utils import is_compiling
from .utils import compilation_metrics, guard_failures, orig_code_map, reset_frame_count
# Public API of ``torch._dynamo``: the names bound by
# ``from torch._dynamo import *``. Most are re-exported from the submodule
# imports above; ``reset``, ``allow_in_graph`` and ``disallow_in_graph`` are
# defined in this module. Kept as a literal list so static tools can read it.
# NOTE(review): ``graph_break`` is neither imported nor defined in the visible
# part of this module -- presumably defined further down the file; confirm.
__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "graph_break",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "skip",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
]
def reset():
    """Clear all compile caches and restore initial state.

    Every code object still reachable through the references recorded in
    ``convert_frame.input_codes.seen`` and ``convert_frame.output_codes.seen``
    is passed to ``reset_code``. Afterwards the two trackers, ``orig_code_map``,
    ``guard_failures`` and ``resume_execution.ContinueExecutionCache.cache``
    are cleared, ``eval_frame.most_recent_backend`` is set to ``None``,
    ``compilation_metrics`` is cleared and the frame counter is reset.
    """
    tracked_refs = convert_frame.input_codes.seen + convert_frame.output_codes.seen
    for ref in tracked_refs:
        # The entries are called to obtain the code object; presumably they are
        # weak references, so a dead entry yields None and is skipped.
        live_code = ref()
        if not live_code:
            continue
        reset_code(live_code)

    for cache in (
        convert_frame.input_codes,
        convert_frame.output_codes,
        orig_code_map,
        guard_failures,
        resume_execution.ContinueExecutionCache.cache,
    ):
        cache.clear()
    eval_frame.most_recent_backend = None
    compilation_metrics.clear()
    reset_frame_count()
def allow_in_graph(fn):
    """
    Customize which functions TorchDynamo will include in the generated
    graph. Similar to `torch.fx.wrap()`.

    ::

        torch._dynamo.allow_in_graph(my_custom_function)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will capture a single graph containing `my_custom_function()`.

    Args:
        fn: A callable, or a list/tuple of callables, to allow in the graph.

    Returns:
        ``fn`` itself, so this can also be used as a decorator. When ``fn``
        is a list or tuple, a list holding the result for each element.

    Raises:
        AssertionError: if ``fn`` (or any element of a list/tuple) is not
            callable. This is an ``assert``, so the check is skipped under
            ``python -O``.
    """
    if isinstance(fn, (list, tuple)):
        # Apply element-wise; the result is always a list, even for a tuple.
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    # Registration is keyed on id(fn): add it to the allow set and drop it
    # from the disallow set, undoing an earlier disallow_in_graph(fn).
    # NOTE(review): the documented usage passes functions that were never
    # disallowed, so `remove` is presumably tolerant of absent ids -- confirm
    # in allowed_functions.
    allowed_functions._allowed_function_ids.add(id(fn))
    allowed_functions._disallowed_function_ids.remove(id(fn))
    return fn
def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude from the generated
    graph and force a graph break on.

    ::

        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.

    Args:
        fn: A callable, or a list/tuple of callables, to exclude from the
            graph.

    Returns:
        ``fn`` itself, so this can also be used as a decorator. When ``fn``
        is a list or tuple, a list holding the result for each element.

    Raises:
        AssertionError: if ``fn`` (or any element of a list/tuple) is not
            callable. This is an ``assert``, so the check is skipped under
            ``python -O``.
    """
    if isinstance(fn, (list, tuple)):
        # Apply element-wise; the result is always a list, even for a tuple.
        return [disallow_in_graph(x) for x in fn]
    assert callable(fn), "disallow_in_graph expects a callable"
    # Mirror image of allow_in_graph: drop id(fn) from the allow set and add
    # it to the disallow set.
    # NOTE(review): `remove` is presumably tolerant of ids that were never
    # allowed (e.g. user functions) -- confirm in allowed_functions.
    allowed_functions._allowed_function_ids.remove(id(fn))
    allowed_functions._disallowed_function_ids.add(id(fn))
    return fn