LnStructured
- class torch.nn.utils.prune.LnStructured(amount, n, dim=-1)
Prune entire (currently unpruned) channels in a tensor based on their Ln-norm.
- Parameters
amount (int or float) – quantity of channels to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
n (int, float, inf, -inf, 'fro', 'nuc') – See documentation of valid entries for argument p in torch.norm().
dim (int, optional) – index of the dim along which we define channels to prune. Default: -1.
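To make the roles of n and dim concrete, the following sketch computes per-channel norms by hand with torch.norm() for a small 2-D tensor. The tensor values and variable names are illustrative only, and the library's internal computation may be organized differently.

```python
import torch

# Illustrative 4x2 tensor: with dim=0, each row is one "channel".
t = torch.tensor([[1.0, 1.0],
                  [10.0, 10.0],
                  [0.1, 0.1],
                  [5.0, 5.0]])

# L2-norm (n=2) of each channel along dim=0: reduce over the other dim (1).
channel_norms = torch.norm(t, p=2, dim=1)

# Channels ordered from lowest to highest norm. With amount=2, structured
# pruning would target the two lowest-norm channels.
order = torch.argsort(channel_norms)
lowest_two = sorted(order[:2].tolist())

print(channel_norms)
print(lowest_two)
```

Here rows 0 and 2 have the smallest norms, so they are the candidates for removal when amount=2 and dim=0.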
- classmethod apply(module, name, amount, n, dim, importance_scores=None)
Add pruning on the fly and reparametrization of a tensor.
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- Parameters
module (nn.Module) – module containing the tensor to prune
name (str) – parameter name within module on which pruning will act.
amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
n (int, float, inf, -inf, 'fro', 'nuc') – See documentation of valid entries for argument p in torch.norm().
dim (int) – index of the dim along which we define channels to prune.
importance_scores (torch.Tensor) – tensor of importance scores (of same shape as module parameter) used to compute mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the module parameter will be used in its place.
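A minimal sketch of calling apply on a small linear layer follows. The layer sizes, the seed, and the variable names are arbitrary choices for illustration; the sketch only inspects the parameter and buffer names that result from the reparametrization.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(in_features=3, out_features=4)  # weight shape: (4, 3)

# Prune half of the output channels (rows of `weight`) by their L2-norm.
prune.LnStructured.apply(layer, name="weight", amount=0.5, n=2, dim=0)

param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]
print(param_names)   # contains "weight_orig" instead of "weight"
print(buffer_names)  # contains "weight_mask"

# Rows of the mask that are entirely zero correspond to pruned channels.
pruned_rows = int((layer.weight_mask.sum(dim=1) == 0).sum())
print(pruned_rows)
```

With four output channels and amount=0.5, two rows of the mask end up zeroed.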
- apply_mask(module)
Simply handles the multiplication between the parameter being pruned and the generated mask.
Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- Parameters
module (nn.Module) – module containing the tensor to prune
- Returns
pruned version of the input tensor
- Return type
pruned_tensor (torch.Tensor)
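The multiplication described above can be checked by hand, as in the sketch below. It assumes that apply returns the pruning-method instance it attaches to the module, which is not stated in this section; the layer and names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(3, 4)

# Assumption: `apply` returns the pruning-method instance it registers.
method = prune.LnStructured.apply(layer, name="weight", amount=1, n=1, dim=0)

# apply_mask fetches `weight_orig` and `weight_mask` from the module.
pruned = method.apply_mask(layer)

# The same result computed manually from the two stored tensors.
manual = layer.weight_orig * layer.weight_mask

same = torch.equal(pruned, manual)
print(same)
```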
- compute_mask(t, default_mask)
Compute and return a mask for the input tensor t.
Starting from a base default_mask (which should be a mask of ones if the tensor has not been pruned yet), generate a mask to apply on top of the default_mask by zeroing out the channels along the specified dim with the lowest Ln-norm.
- Parameters
t (torch.Tensor) – tensor representing the parameter to prune
default_mask (torch.Tensor) – Base mask from previous pruning iterations, which needs to be respected after the new mask is applied. Same dims as t.
- Returns
mask to apply to t, of same dims as t
- Return type
mask (torch.Tensor)
- Raises
IndexError – if self.dim >= len(t.shape)
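As a sketch of the behavior described for compute_mask, the example below builds a mask for a small hand-written tensor and then triggers the documented IndexError with an out-of-range dim. The tensor values and variable names are illustrative.

```python
import torch
import torch.nn.utils.prune as prune

t = torch.tensor([[1.0, 1.0],
                  [10.0, 10.0],
                  [0.1, 0.1],
                  [5.0, 5.0]])

# Zero out the 2 rows (dim=0) with the lowest L2-norm.
method = prune.LnStructured(amount=2, n=2, dim=0)
mask = method.compute_mask(t, default_mask=torch.ones_like(t))
print(mask)

# dim=2 does not exist for a 2-D tensor, so an IndexError is expected.
bad_method = prune.LnStructured(amount=1, n=2, dim=2)
try:
    bad_method.compute_mask(t, default_mask=torch.ones_like(t))
    raised = False
except IndexError:
    raised = True
print(raised)
```

Rows 0 and 2 have the smallest norms, so the mask keeps only rows 1 and 3.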
- prune(t, default_mask=None, importance_scores=None)
Compute and return a pruned version of input tensor t, according to the pruning rule specified in compute_mask().
- Parameters
t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).
importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in the t that is being pruned. If unspecified or None, the tensor t will be used in its place.
default_mask (torch.Tensor, optional) – mask from previous pruning iteration, if any. To be considered when determining what portion of the tensor pruning should act on. If None, defaults to a mask of ones.
- Returns
pruned version of tensor t.
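The following sketch calls prune directly on a tensor, first using the tensor's own values and then using a separate importance_scores tensor. All values and names are made up for illustration.

```python
import torch
import torch.nn.utils.prune as prune

t = torch.tensor([[1.0, 4.0, 0.5],
                  [2.0, 3.0, 0.5]])

# Prune one column (dim=1) by L1-norm; column norms are 3, 7 and 1.
method = prune.LnStructured(amount=1, n=1, dim=1)
pruned = method.prune(t)
print(pruned)

# With importance_scores, the mask is computed from the scores instead of t.
# Here column 0 is scored lowest, so it is the one zeroed out.
scores = torch.tensor([[0.1, 1.0, 1.0],
                       [0.1, 1.0, 1.0]])
pruned_by_scores = method.prune(t, importance_scores=scores)
print(pruned_by_scores)
```

Because default_mask is omitted in both calls, a mask of ones is used as the starting point.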
- remove(module)
Remove the pruning reparameterization from a module.
The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
Note
Pruning itself is NOT undone or reversed!
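A short sketch of making pruning permanent is shown below. It uses the module-level helper torch.nn.utils.prune.remove(module, name), which is not described in this section; it is used here on the understanding that it invokes the method above and also detaches the forward pre-hook. The layer and names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(3, 4)
prune.LnStructured.apply(layer, name="weight", amount=2, n=2, dim=0)

# Make the pruning permanent and drop the reparametrization.
prune.remove(layer, "weight")

param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]
print(param_names)   # "weight" is a plain parameter again
print(buffer_names)  # "weight_mask" is gone

# The zeroed channels stay zeroed: pruning is not undone.
zero_rows = int((layer.weight.abs().sum(dim=1) == 0).sum())
print(zero_rows)
```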