torch.func.vjp
- torch.func.vjp(func, *primals, has_aux=False)
Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.
- Parameters
func (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.
primals (Tensors) – Positional arguments to func that must all be Tensors. The returned function computes the derivative with respect to these arguments.
has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns
Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function returns a tuple containing one VJP per primal.
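The has_aux=True return shape is not illustrated elsewhere on this page, so here is a minimal sketch of it. The function f and the choice of the intermediate x.sin() as the auxiliary value are illustrative assumptions, not part of the original documentation.

```python
import torch
from torch.func import vjp

def f(x):
    y = x.sin()
    # Return an (output, aux) pair: only the first element is differentiated.
    return y.sum(), y

x = torch.randn(5)

# With has_aux=True, vjp returns three values instead of two.
output, vjp_fn, aux = vjp(f, x, has_aux=True)

# The differentiated output is a scalar, so the cotangent is a scalar tensor.
(grad_x,) = vjp_fn(torch.tensor(1.0))

assert torch.allclose(grad_x, x.cos())
assert torch.allclose(aux, x.sin())
```

The auxiliary value is passed through unchanged and plays no part in the derivative.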
When used in simple cases, vjp() behaves the same as grad():
>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))
However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs:
>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp() can even support outputs being Python structs:
>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by vjp() will compute the partials with respect to each of the primals:
>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
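To make the "Jacobian times cotangents" description concrete, the sketch below compares the result of vjp() against the cotangent contracted with the full Jacobian, which is materialized here with torch.func.jacrev. The function f, the fixed matrix W, and the shapes are illustrative assumptions chosen only for this check.

```python
import torch
from torch.func import jacrev, vjp

torch.manual_seed(0)
W = torch.randn(2, 3)  # hypothetical fixed weights, closed over by f

def f(x):
    # Maps a 3-vector to a 2-vector, so the Jacobian has shape (2, 3).
    return W @ x.tanh()

x = torch.randn(3)
v = torch.randn(2)  # cotangent: same shape as the output of f

# Reverse-mode product, computed without building the Jacobian.
_, vjp_fn = vjp(f, x)
(vjp_x,) = vjp_fn(v)

# Reference: build the full Jacobian, then contract it with v.
J = jacrev(f)(x)
expected = v @ J

assert torch.allclose(vjp_x, expected, atol=1e-6)
```

The cotangent has the shape of the output and the result has the shape of the primal, which is why vjp() fits functions with many inputs and few outputs.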
primals are the positional arguments for f. All kwargs use their default value:
>>> x = torch.randn([5])
>>> def f(x, scale=4.):
...     return x * scale
...
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
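Because vjp() only forwards positional primals, one way to evaluate f with a non-default keyword argument is to bind it before the call. The sketch below uses functools.partial for this. This is an ordinary Python pattern, assumed here for illustration, not a feature documented for vjp() itself.

```python
import functools

import torch
from torch.func import vjp

def f(x, scale=4.0):
    return x * scale

x = torch.randn(5)

# Bind the keyword argument up front; vjp then sees a function of x only.
f_scaled = functools.partial(f, scale=2.0)

_, vjp_fn = vjp(f_scaled, x)
(grad_x,) = vjp_fn(torch.ones_like(x))

# The derivative of x * 2.0 with respect to x is 2.0 everywhere.
assert torch.allclose(grad_x, torch.full(x.shape, 2.0))
```

The bound value is treated as a constant, so no derivative is computed with respect to scale.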
Note
Using PyTorch torch.no_grad together with vjp.
Case 1: Using torch.no_grad inside a function:
>>> def f(x):
...     with torch.no_grad():
...         c = x ** 2
...     return x - c
In this case, vjp(f, x) will respect the inner torch.no_grad.
Case 2: Using vjp inside torch.no_grad context manager:
>>> with torch.no_grad():
...     vjp(f, x)
In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a “function transform”: its result should not depend on the result of a context manager outside of f.
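The two cases in this note can be checked with a short sketch. It reuses the f from Case 1. The input size and the all-ones cotangent are illustrative assumptions.

```python
import torch
from torch.func import vjp

def f(x):
    with torch.no_grad():
        # c is computed without gradient tracking, so it acts as a constant.
        c = x ** 2
    return x - c

x = torch.randn(5)
cotangent = torch.ones_like(x)

# Case 1: the inner no_grad is respected, so only the "x" term contributes.
_, vjp_fn = vjp(f, x)
(grad_case1,) = vjp_fn(cotangent)
assert torch.allclose(grad_case1, torch.ones_like(x))

# Case 2: an outer no_grad does not disable differentiation inside vjp.
with torch.no_grad():
    _, vjp_fn = vjp(f, x)
    (grad_case2,) = vjp_fn(cotangent)
assert torch.allclose(grad_case2, grad_case1)
```

If the inner no_grad were ignored, the result would be 1 - 2 * x rather than all ones.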