Vjp

nabla.transforms.vjp.vjp(func: Callable[..., Any], *primals, has_aux: Literal[False] = False) -> tuple[Any, Callable]
nabla.transforms.vjp.vjp(func: Callable[..., Any], *primals, has_aux: Literal[True]) -> tuple[Any, Callable, Any]

Compute a vector-Jacobian product (VJP) using reverse-mode automatic differentiation.

Parameters:
  • func – Function to differentiate. It must take its inputs as positional arguments.

  • *primals – Positional arguments at which func is evaluated and differentiated. Each may be an arbitrary pytree.

  • has_aux – Optional bool, default False. If True, func must return a pair (output, aux): output is the result of the mathematical function being differentiated, and aux is auxiliary data that is returned alongside it but not differentiated.

Returns:

If has_aux is False: a tuple (outputs, vjp_function), where vjp_function maps a cotangent to the gradients with respect to the primals.

If has_aux is True: a tuple (outputs, vjp_function, aux), where aux is the auxiliary data.

The vjp_function always returns gradients as a tuple (matching JAX behavior):

  • Single argument: vjp_fn(cotangent) -> (gradient,)

  • Multiple arguments: vjp_fn(cotangent) -> (grad1, grad2, …)
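
To make the return convention concrete without requiring nabla, here is a minimal sketch of a hypothetical `toy_vjp` that mimics the signature above for scalar float inputs and a scalar output. It estimates gradients with central finite differences rather than reverse-mode autodiff, so it illustrates only the shape of what vjp returns (a 2-tuple, or a 3-tuple with has_aux=True, with gradients always wrapped in a tuple), not how nabla computes them.

```python
import math
from typing import Any, Callable


def toy_vjp(func: Callable[..., Any], *primals: float,
            has_aux: bool = False, eps: float = 1e-6):
    """Hypothetical stand-in for vjp on scalar floats.

    Mimics the return convention only; gradients come from central
    finite differences, not reverse-mode autodiff.
    """
    result = func(*primals)
    # With has_aux=True, func returns (output, aux); aux is passed through.
    out, aux = result if has_aux else (result, None)

    def primal_out(*args: float) -> float:
        # Re-evaluate func, keeping only the differentiated output.
        r = func(*args)
        return r[0] if has_aux else r

    def vjp_fn(cotangent: float) -> tuple:
        grads = []
        for i in range(len(primals)):
            up = list(primals)
            up[i] += eps
            down = list(primals)
            down[i] -= eps
            slope = (primal_out(*up) - primal_out(*down)) / (2 * eps)
            grads.append(cotangent * slope)
        # Always a tuple, even when there is a single primal.
        return tuple(grads)

    return (out, vjp_fn, aux) if has_aux else (out, vjp_fn)


# Single primal: the gradient comes back as a 1-tuple.
out, sin_vjp = toy_vjp(math.sin, 0.0)
(g,) = sin_vjp(1.0)  # g ≈ cos(0.0) = 1.0

# Multiple primals: one gradient per primal, in order.
prod, prod_vjp = toy_vjp(lambda x, y: x * y, 2.0, 3.0)
gx, gy = prod_vjp(1.0)  # gx ≈ 3.0 (= y), gy ≈ 2.0 (= x)

# has_aux=True: a 3-tuple (output, vjp_fn, aux).
sq, sq_vjp, aux = toy_vjp(lambda x: (x ** 2, "aux data"), 3.0, has_aux=True)
(gsq,) = sq_vjp(2.0)  # 2.0 * d(x**2)/dx at x=3.0, so gsq ≈ 12.0
```

The same unpacking patterns, such as `(g,) = vjp_fn(cotangent)` for a single primal, apply to the values returned by nabla.transforms.vjp.vjp itself.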

Note: This follows JAX’s vjp API exactly:

  • Accepts only positional arguments

  • Always returns gradients as a tuple

  • For functions that require keyword arguments, use functools.partial or a lambda
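
The keyword-argument workaround from the note can be sketched as follows. `toy_vjp` and `scaled_exp` are hypothetical helpers: `toy_vjp` stands in for vjp using central finite differences on scalar floats (not reverse-mode autodiff) and, like the real function, accepts primals positionally only. Binding `scale` with `functools.partial`, or closing over it in a lambda, leaves a function of `x` alone, which can then be passed as func.

```python
import functools
import math


def toy_vjp(func, *primals, eps=1e-6):
    """Hypothetical stand-in for vjp: positional scalar primals only,
    gradients from central finite differences (not reverse-mode autodiff)."""
    out = func(*primals)

    def vjp_fn(cotangent):
        grads = []
        for i in range(len(primals)):
            up = list(primals)
            up[i] += eps
            down = list(primals)
            down[i] -= eps
            grads.append(cotangent * (func(*up) - func(*down)) / (2 * eps))
        return tuple(grads)

    return out, vjp_fn


def scaled_exp(x, *, scale):
    # `scale` is keyword-only, so it cannot be supplied as a primal.
    return scale * math.exp(x)


# Bind the keyword argument up front; only `x` remains positional.
f = functools.partial(scaled_exp, scale=2.0)
out, vjp_fn = toy_vjp(f, 0.0)
(gx,) = vjp_fn(1.0)  # d/dx [2.0 * exp(x)] at 0.0 is 2.0, so gx ≈ 2.0

# An equivalent lambda closes over `scale` instead.
out_l, vjp_fn_l = toy_vjp(lambda x: scaled_exp(x, scale=2.0), 0.0)
(gx_l,) = vjp_fn_l(1.0)  # also ≈ 2.0
```

The partial object is preferable when the bound function is reused, since it keeps the original function and its bound keywords inspectable; the lambda is handier for one-off calls.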

Return type:

If has_aux is False