Shortcuts

Source code for torch.distributions.constraint_registry

r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. Both objects take
constraints as input and return transforms, but they have different guarantees on
bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform but is needed for algorithms
    like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method either as a
function on singleton constraints::

    transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""

import numbers

from torch.distributions import constraints, transforms

__all__ = [
    "ConstraintRegistry",
    "biject_to",
    "transform_to",
]


class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Factories are stored in a dict keyed by
    :class:`~torch.distributions.constraints.Constraint` subclass. Calling the
    registry with a constraint object looks up the factory registered for the
    exact ``type(constraint)`` and returns the result of calling that factory
    on the constraint.
    """

    def __init__(self):
        # Maps a Constraint subclass to the factory that builds its transform.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object.
                If omitted, a decorator is returned that registers the
                decorated callable for ``constraint``.

        Returns:
            The ``factory`` that was registered (unchanged), so that decorated
            functions remain usable and decorators can be stacked.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass
                nor a ``Constraint`` instance.
        """
        # Support use as decorator.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances: key on the instance's class.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(
            constraint, constraints.Constraint
        ):
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
            )

        # A later registration for the same class replaces the earlier one.
        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object.

        Raises:
            `NotImplementedError` if no transform has been registered.
        """
        # Look up by Constraint subclass (exact type match, no MRO walk).
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            # ``from None`` hides the internal KeyError from the traceback.
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        return factory(constraint)
# The two global registries described in the module docstring.
biject_to = ConstraintRegistry()
transform_to = ConstraintRegistry()


################################################################################
# Registration Table
################################################################################

# NOTE(review): registration keys on ``type(constraint)`` for singleton
# instances, and a later registration for the same class replaces an earlier
# one. If two of the constraints below share a class (e.g. ``positive`` being
# an instance of the ``greater_than`` class), the later factory wins for both
# -- confirm against the constraints module.


@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """Return the shared identity transform for unconstrained reals."""
    return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """Wrap the base constraint's ``biject_to`` transform in an IndependentTransform."""
    base_transform = biject_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """Wrap the base constraint's ``transform_to`` transform in an IndependentTransform."""
    base_transform = transform_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )


@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """Return an ExpTransform for positive / nonnegative constraints."""
    return transforms.ExpTransform()


@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """Compose ExpTransform with an affine shift by ``constraint.lower_bound``."""
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.lower_bound, 1),
        ]
    )


@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """Compose ExpTransform with a negating affine map offset by ``constraint.upper_bound``."""
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.upper_bound, -1),
        ]
    )


@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """
    Return a sigmoid-based transform for interval constraints.

    A bare SigmoidTransform is returned when the bounds are the plain numbers
    0 and 1; otherwise the sigmoid is followed by an affine map with
    ``loc = lower_bound`` and ``scale = upper_bound - lower_bound``.
    """
    # Handle the special case of the unit interval.
    # The isinstance checks restrict the shortcut to plain Python numbers.
    lower_is_0 = (
        isinstance(constraint.lower_bound, numbers.Number)
        and constraint.lower_bound == 0
    )
    upper_is_1 = (
        isinstance(constraint.upper_bound, numbers.Number)
        and constraint.upper_bound == 1
    )
    if lower_is_0 and upper_is_1:
        return transforms.SigmoidTransform()

    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform(
        [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
    )


@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """Return a StickBreakingTransform for the simplex constraint."""
    return transforms.StickBreakingTransform()


@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """Return a SoftmaxTransform for the simplex constraint."""
    return transforms.SoftmaxTransform()


# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """Return a LowerCholeskyTransform (registered with ``transform_to`` only)."""
    return transforms.LowerCholeskyTransform()


@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """Return a PositiveDefiniteTransform (registered with ``transform_to`` only)."""
    return transforms.PositiveDefiniteTransform()


@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """Return a CorrCholeskyTransform for the corr_cholesky constraint."""
    return transforms.CorrCholeskyTransform()


@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """Build a CatTransform from the ``biject_to`` transform of each sub-constraint."""
    return transforms.CatTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )


@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """Build a CatTransform from the ``transform_to`` transform of each sub-constraint."""
    return transforms.CatTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
    )


@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """Build a StackTransform from the ``biject_to`` transform of each sub-constraint."""
    return transforms.StackTransform(
        [biject_to(c) for c in constraint.cseq], constraint.dim
    )


@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """Build a StackTransform from the ``transform_to`` transform of each sub-constraint."""
    return transforms.StackTransform(
        [transform_to(c) for c in constraint.cseq], constraint.dim
    )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources