torch.utils.checkpoint
Note
Checkpointing is implemented by rerunning a forward-pass segment for
each checkpointed segment during backward. This can cause persistent
states like the RNG state to be more advanced than they would be without
checkpointing. By default, checkpointing includes logic to juggle
the RNG state such that checkpointed passes making use of RNG
(through dropout for example) have deterministic output as
compared to non-checkpointed passes. The logic to stash and restore
RNG states can incur a moderate performance hit depending on the runtime
of checkpointed operations. If deterministic output compared to
non-checkpointed passes is not required, supply preserve_rng_state=False
to checkpoint or checkpoint_sequential to omit stashing and
restoring the RNG state during each checkpoint.
The stashing logic saves and restores the RNG state for the current device
and the devices of all CUDA Tensor arguments to the run_fn.
However, the logic has no way to anticipate if the user will move
Tensors to a new device within the run_fn itself. Therefore, if you move
Tensors to a new device (“new” meaning not belonging to the set of
[current device + devices of Tensor arguments]) within run_fn, deterministic
output compared to non-checkpointed passes is never guaranteed.
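The effect of the default RNG handling can be observed with a small sketch. It assumes a hypothetical run_fn that draws a random mask, much as dropout would, and records it in a list; the names run_fn, masks and x are illustrative only and not part of the API. Because the function executes once in forward and once more during backward, two masks are recorded, and they should be identical when the RNG state is restored before recomputation.

```python
import torch
from torch.utils.checkpoint import checkpoint

masks = []  # one entry per execution of run_fn (forward + recomputation)

def run_fn(x):
    # Draw a random 0/1 mask, as dropout would, and remember it.
    mask = (torch.rand_like(x) > 0.5).to(x.dtype)
    masks.append(mask)
    return x * mask

x = torch.ones(8, requires_grad=True)
# preserve_rng_state=True is the default; it is spelled out here for clarity.
out = checkpoint(run_fn, x, use_reentrant=True, preserve_rng_state=True)
out.sum().backward()

# run_fn ran twice, and the recomputation saw the same random numbers.
same_mask = torch.equal(masks[0], masks[1])
```

With preserve_rng_state=False the second mask would be drawn from an advanced RNG state, so the gradient could correspond to a different mask than the one used for the forward output.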
torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True, **kwargs)
Checkpoint a model or part of the model.
Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for the backward computation, the checkpointed part does not save intermediate activations and instead recomputes them in the backward pass. It can be applied to any part of a model.
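The compute-for-memory trade can be made visible by counting how often a checkpointed function runs. The following is a minimal sketch under stated assumptions: segment, calls and x_ref are hypothetical names chosen for illustration, and the function is a toy computation rather than a real model part.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}  # counts how many times the segment is executed

def segment(x):
    calls["n"] += 1
    return torch.relu(x * 2.0).pow(2)

x = torch.randn(5, requires_grad=True)

# Forward: the segment runs once and its intermediates are not kept.
y = checkpoint(segment, x, use_reentrant=True)
calls_after_forward = calls["n"]

# Backward: the segment runs a second time to rebuild the activations.
y.sum().backward()
calls_after_backward = calls["n"]

# Reference gradient from a plain, non-checkpointed call on a copy of x.
x_ref = x.detach().clone().requires_grad_(True)
segment(x_ref).sum().backward()
```

The checkpointed gradient in x.grad should match x_ref.grad; the only difference is the extra forward execution during backward.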
Specifically, in the forward pass, function will run in torch.no_grad() manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the function parameter. In the backward pass, the saved inputs and function are retrieved, and the forward pass is computed on function again, now tracking the intermediate activations, and then the gradients are calculated using these activation values.
The output of function can contain non-Tensor values, and gradient recording is only performed for the Tensor values. Note that if the output consists of nested structures (e.g., custom objects, lists, dicts) consisting of Tensors, these Tensors nested in custom structures will not be considered as part of autograd.
Warning
If function invocation during backward does anything different than the one during forward, e.g., due to some global variable, the checkpointed version won’t be equivalent, and unfortunately it can’t be detected.
Warning
If use_reentrant=True is specified, then if the checkpointed segment contains tensors detached from the computational graph by detach() or torch.no_grad(), the backward pass will raise an error. This is because checkpoint makes all the outputs require gradients, which causes issues when a tensor is defined to have no gradient in the model. To circumvent this, detach the tensors outside of the checkpoint function. Note that the checkpointed segment can contain tensors detached from the computational graph if use_reentrant=False is specified.
Warning
If use_reentrant=True is specified, at least one of the inputs needs to have requires_grad=True if grads are needed for model inputs, otherwise the checkpointed part of the model won’t have gradients. At least one of the outputs needs to have requires_grad=True as well. Note that this does not apply if use_reentrant=False is specified.
Warning
If use_reentrant=True is specified, checkpointing currently only supports torch.autograd.backward() and only if its inputs argument is not passed. torch.autograd.grad() is not supported. If use_reentrant=False is specified, checkpointing will work with torch.autograd.grad().
- Parameters
function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in an LSTM, if the user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden
preserve_rng_state (bool, optional, default=True) – If False, omit stashing and restoring the RNG state during each checkpoint.
use_reentrant (bool, optional, default=True) – Use checkpointing implementation that requires re-entrant autograd. If use_reentrant=False is specified, checkpoint will use an implementation that does not require re-entrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad. Note that future versions of PyTorch will default to use_reentrant=False.
args – tuple containing inputs to the function
- Returns
Output of running function on *args
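The differences between the two implementations described in the warnings above can be sketched as follows. The example is illustrative only: run_fn, lin, activation and hidden are hypothetical names, and the function is a toy stand-in for an LSTM-style step that receives its inputs positionally. None of the inputs requires gradients, so only the parameters captured by run_fn can receive them.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
lin = torch.nn.Linear(4, 2)

def run_fn(activation, hidden):
    # Positional inputs arrive in the order they were passed to checkpoint.
    return torch.tanh(lin(activation) + hidden)

activation = torch.randn(3, 4)  # requires_grad is False
hidden = torch.zeros(3, 2)      # requires_grad is False
params = list(lin.parameters())

# Non-reentrant mode: the output stays connected to lin's parameters,
# and torch.autograd.grad() can be used.
out = checkpoint(run_fn, activation, hidden, use_reentrant=False)
grads = torch.autograd.grad(out.sum(), params)

# Reference gradients from a plain, non-checkpointed call.
ref_grads = torch.autograd.grad(run_fn(activation, hidden).sum(), params)

# Reentrant mode with no input requiring grad: the output is cut off
# from autograd (PyTorch also emits a warning about this).
out_reentrant = checkpoint(run_fn, activation, hidden, use_reentrant=True)
```

In this sketch grads should match ref_grads, while out_reentrant.requires_grad is False, which is the situation the requires_grad warning describes.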
torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)
A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model into various segments and checkpoint each segment. All segments except the last will run in torch.no_grad() manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.
See checkpoint() on how checkpointing works.
Warning
Checkpointing currently only supports torch.autograd.backward() and only if its inputs argument is not passed. torch.autograd.grad() is not supported.
- Parameters
functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.
segments – Number of chunks to create in the model
input – A Tensor that is input to functions
preserve_rng_state (bool, optional, default=True) – If False, omit stashing and restoring the RNG state during each checkpoint.
- Returns
Output of running functions sequentially on input
Example
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
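The example above elides the model definition. A runnable sketch might look like the following; the layer sizes, the value of chunks and the reference run are assumptions made for illustration. The signature shown in this document takes only **kwargs, while newer PyTorch releases also accept a use_reentrant argument on checkpoint_sequential, so the sketch passes it only when the installed version exposes it. The input requires gradients so that the reentrant implementation can propagate them.

```python
import inspect

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(8, 8), nn.ReLU(),
    nn.Linear(8, 8), nn.ReLU(),
    nn.Linear(8, 1),
)
chunks = 2  # number of segments to split the model into
input_var = torch.randn(4, 8, requires_grad=True)

# Assumption: pass use_reentrant only if this PyTorch version accepts it.
extra = {}
if "use_reentrant" in inspect.signature(checkpoint_sequential).parameters:
    extra["use_reentrant"] = True

output = checkpoint_sequential(model, chunks, input_var, **extra)
output.sum().backward()
ckpt_grads = [p.grad.clone() for p in model.parameters()]

# Reference: the same model and input without checkpointing.
model.zero_grad()
ref_input = input_var.detach().clone().requires_grad_(True)
model(ref_input).sum().backward()
ref_grads = [p.grad.clone() for p in model.parameters()]
```

Both runs should produce the same parameter and input gradients; checkpointing changes only when activations are computed and how long they are kept.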