Shortcuts

Source code for torch._dynamo

from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, register_backend
from .convert_frame import replay
from .eval_frame import (
    assume_constant_result,
    disable,
    explain,
    export,
    optimize,
    optimize_assert,
    OptimizedModule,
    reset_code,
    run,
    skip,
)
from .external_utils import is_compiling
from .utils import compilation_metrics, guard_failures, orig_code_map, reset_frame_count

# Public API of torch._dynamo. `reset`, `allow_in_graph`, `disallow_in_graph`
# and `graph_break` are defined in this module; the remaining names are
# re-exported from the submodule imports above.
__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "graph_break",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "skip",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
]


def reset():
    """Return TorchDynamo to a pristine state by dropping every compile cache.

    Original code objects recorded as compiler inputs/outputs are restored,
    all bookkeeping tables are emptied, the most recent backend is forgotten,
    compilation metrics are discarded and the frame counter is reset.
    """
    # Each tracked entry is called to recover its code object (presumably a
    # weak reference); entries that yield a falsy value are skipped.
    tracked_refs = convert_frame.input_codes.seen + convert_frame.output_codes.seen
    for ref in tracked_refs:
        original = ref()
        if original:
            reset_code(original)

    # Empty the bookkeeping tables, in the same order as before.
    caches_to_clear = (
        convert_frame.input_codes,
        convert_frame.output_codes,
        orig_code_map,
        guard_failures,
        resume_execution.ContinueExecutionCache.cache,
    )
    for cache in caches_to_clear:
        cache.clear()

    eval_frame.most_recent_backend = None
    compilation_metrics.clear()
    reset_frame_count()
def allow_in_graph(fn):
    """
    Customize which functions TorchDynamo will include in the generated graph.
    Similar to `torch.fx.wrap()`.
    ::

        torch._dynamo.allow_in_graph(my_custom_function)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will capture a single graph containing `my_custom_function()`.

    Args:
        fn: A callable to allow in the graph, or a list/tuple of callables.

    Returns:
        ``fn`` itself, so this can also be used as a decorator. When ``fn`` is
        a list or tuple, a list of the per-element results is returned.

    Raises:
        AssertionError: if ``fn`` (or an element of a list/tuple) is not
            callable (the check is skipped when Python runs with ``-O``).
    """
    if isinstance(fn, (list, tuple)):
        # Register each element individually; the result is always a list.
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    # Functions are tracked by id(); allowing fn also drops it from the
    # disallowed ids.
    allowed_functions._allowed_function_ids.add(id(fn))
    # NOTE(review): this relies on `_disallowed_function_ids.remove` tolerating
    # ids that were never added (a plain ``set.remove`` would raise KeyError)
    # — confirm against allowed_functions.
    allowed_functions._disallowed_function_ids.remove(id(fn))
    return fn
def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude from the generated
    graph and force a graph break on.
    ::

        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.

    Args:
        fn: A callable to exclude from the graph, or a list/tuple of callables.

    Returns:
        ``fn`` itself, so this can be used as a decorator (as ``graph_break``
        does). When ``fn`` is a list or tuple, a list of the per-element
        results is returned.

    Raises:
        AssertionError: if ``fn`` (or an element of a list/tuple) is not
            callable (the check is skipped when Python runs with ``-O``).
    """
    if isinstance(fn, (list, tuple)):
        # Register each element individually; the result is always a list.
        return [disallow_in_graph(x) for x in fn]
    assert callable(fn), "disallow_in_graph expects a callable"
    # NOTE(review): this relies on `_allowed_function_ids.remove` tolerating
    # ids that are not present (a plain ``set.remove`` would raise KeyError)
    # — confirm against allowed_functions.
    allowed_functions._allowed_function_ids.remove(id(fn))
    # Functions are tracked by id(); record fn as disallowed.
    allowed_functions._disallowed_function_ids.add(id(fn))
    return fn
@disallow_in_graph
def graph_break():
    """Request a graph break at the point where this function is called.

    The function is registered through ``disallow_in_graph``, so TorchDynamo
    will not place it in a captured graph and breaks the graph on it instead.
    Calling it directly is a no-op that returns ``None``.
    """

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources