Source code for torch.backends.opt_einsum
from typing import Any
import warnings
import sys
from functools import lru_cache as _lru_cache
from contextlib import contextmanager
from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation
try:
import opt_einsum as _opt_einsum # type: ignore[import]
except ImportError:
_opt_einsum = None
@_lru_cache()
def is_available() -> bool:
    r"""Return a bool indicating if opt_einsum is currently available.

    "Available" means the module-level ``import opt_einsum`` succeeded, i.e.
    ``_opt_einsum`` is not ``None``. The result is memoized by ``lru_cache``,
    so it is computed once per process.
    """
    return _opt_einsum is not None
def get_opt_einsum() -> Any:
    r"""Return the opt_einsum package if opt_einsum is currently available, else None."""
    return _opt_einsum
def _set_enabled(_enabled: bool) -> None:
    """Store ``_enabled`` in the module-level ``enabled`` flag.

    Args:
        _enabled: New value for the flag.

    Raises:
        ValueError: If a truthy value is requested while opt_einsum is not
            available; the flag is left unchanged in that case.
    """
    global enabled

    unavailable = not is_available()
    if unavailable and _enabled:
        message = (
            f'opt_einsum is not available, so setting `enabled` to {_enabled} will not reap '
            'the benefits of calculating an optimal path for einsum. torch.einsum will '
            'fall back to contracting from left to right. To enable this optimal path '
            'calculation, please install opt-einsum.'
        )
        raise ValueError(message)

    enabled = _enabled
def _get_enabled() -> bool:
    """Return the current value of the module-level ``enabled`` flag."""
    return enabled
def _set_strategy(_strategy: str) -> None:
    """Validate ``_strategy`` and store it in the module-level ``strategy``.

    Args:
        _strategy: One of ``'auto'``, ``'greedy'`` or ``'optimal'``.

    Raises:
        ValueError: If opt_einsum is not available, if the ``enabled`` flag is
            falsy, or if ``_strategy`` is not one of the accepted names. The
            checks run in that order and ``strategy`` is left unchanged when
            any of them fails.
    """
    global strategy

    if not is_available():
        message = (
            f'opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. '
            'torch.einsum will bypass path calculation and simply contract from left to right. '
            'Please install opt_einsum or unset `strategy`.'
        )
        raise ValueError(message)

    if not enabled:
        message = (
            f'opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. '
            'torch.einsum will bypass path calculation and simply contract from left to right. '
            'Please set `enabled` to `True` as well or unset `strategy`.'
        )
        raise ValueError(message)

    accepted = ('auto', 'greedy', 'optimal')
    if _strategy not in accepted:
        message = f'`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}'
        raise ValueError(message)

    strategy = _strategy
def _get_strategy() -> str:
    """Return the current value of the module-level ``strategy`` setting."""
    return strategy
def set_flags(_enabled=None, _strategy=None):
    """Update ``enabled`` and/or ``strategy`` and return their previous values.

    Args:
        _enabled: New value for ``enabled``; ``None`` leaves it untouched.
        _strategy: New value for ``strategy``; ``None`` leaves it untouched.

    Returns:
        A ``(enabled, strategy)`` tuple captured before any change is applied.
        The strategy slot is ``None`` when opt_einsum is not available.

    Raises:
        ValueError: Propagated from ``_set_enabled`` / ``_set_strategy``.
    """
    previous_strategy = strategy if is_available() else None
    previous = (enabled, previous_strategy)

    # ``enabled`` is applied before ``strategy``: the strategy setter checks it.
    for value, setter in ((_enabled, _set_enabled), (_strategy, _set_strategy)):
        if value is not None:
            setter(value)

    return previous
@contextmanager
def flags(enabled=None, strategy=None):
    """Temporarily override the opt_einsum flags for the body of a ``with`` block.

    Args:
        enabled: Value for ``enabled`` inside the block; ``None`` keeps the
            current value.
        strategy: Value for ``strategy`` inside the block; ``None`` keeps the
            current value.

    Raises:
        ValueError: Propagated from ``set_flags`` when the requested values are
            rejected. Any flag that call had already changed is rolled back
            before the error propagates.
    """
    # Snapshot before touching anything (set_flags() without arguments only
    # reads), so a set_flags() call that fails half-way can still be undone.
    saved_enabled, saved_strategy = set_flags()
    try:
        with __allow_nonbracketed_mutation():
            set_flags(enabled, strategy)
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            if saved_strategy is not None:
                # _set_strategy() refuses to run while `enabled` is falsy, so
                # replaying set_flags(saved_enabled, saved_strategy) raised on
                # exit whenever the saved state was disabled. Enable first,
                # restore the strategy, then restore the real `enabled` value.
                # A non-None saved strategy means is_available() was true, so
                # _set_enabled(True) cannot be rejected here.
                _set_enabled(True)
                _set_strategy(saved_strategy)
            _set_enabled(saved_enabled)
# The magic here is to allow us to intercept code like this:
#
# torch.backends.opt_einsum.enabled = True
class OptEinsumModule(PropModule):
    """``PropModule`` subclass instantiated around this module and installed in ``sys.modules``.

    ``m`` is the module object being wrapped and ``name`` its module name; both
    are forwarded unchanged to ``PropModule.__init__``.
    """

    def __init__(self, m, name):
        super(OptEinsumModule, self).__init__(m, name)

    # NOTE(review): indentation was lost in this copy of the file; the
    # statements below are laid out at class-body level -- confirm against the
    # upstream source. Laid out this way, the `global` declarations make each
    # assignment bind the module-level name rather than a class attribute, and
    # the module-level `enabled = ...` / `strategy = ...` assignments at the end
    # of the file rebind both names afterwards.
    global enabled
    enabled = ContextProp(_get_enabled, _set_enabled)
    global strategy
    strategy = None
    # A strategy property is only created when opt_einsum could be imported.
    if is_available():
        strategy = ContextProp(_get_strategy, _set_strategy)
# Swap this module's entry in sys.modules for an OptEinsumModule instance that
# wraps it (the "sys.modules replacement trick"), see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

# Default flag values: on with the 'auto' strategy when opt_einsum is
# available, otherwise off with no strategy.
enabled = is_available()
strategy = 'auto' if enabled else None