Grad

nabla.transforms.grad.value_and_grad(fun=None, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

Creates a function that evaluates both the value and gradient of fun.

This function uses VJP (vector-Jacobian product) directly, with a cotangent of ones_like(output), to compute gradients of scalar-valued functions. For scalar outputs this is simpler and more efficient than using jacrev/jacfwd.
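To make the mechanism concrete, here is a minimal sketch of how such a transform can be built on a VJP primitive. It assumes nabla exposes a JAX-style vjp(fun, *primals) returning an (output, pullback) pair and a ones_like helper; these names are illustrative assumptions, not a statement of nabla's exact API.

import nabla as nb  # assumed import alias

def value_and_grad_sketch(fun):
    """Illustrative only: scalar value-and-gradient in one reverse pass."""
    def wrapped(x):
        # Assumption: nb.vjp(fun, x) -> (output, pullback), JAX-style.
        value, pullback = nb.vjp(fun, x)
        # Seeding the pullback with ones_like(output) makes the VJP
        # return the gradient of a scalar-valued function directly.
        (grad_x,) = pullback(nb.ones_like(value))
        return value, grad_x
    return wrapped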

Parameters:
  • fun (Callable | None) – Function to be differentiated. Should return a scalar.

  • argnums (int | Sequence[int]) – Which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Whether fun returns (output, aux) pair (default False).

  • holomorphic (bool) – Whether fun is holomorphic - currently ignored (default False).

  • allow_int (bool) – Whether to allow integer inputs - currently ignored (default False).

  • reduce_axes (Sequence) – Axes to reduce over - currently ignored (default ()).

Returns:

A function that computes both the value and gradient of fun.

Return type:

Callable[..., Any]

Examples

Basic usage as a function call:

value_and_grad_fn = value_and_grad(my_loss)
value, grads = value_and_grad_fn(x)
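A complete, runnable variant of the snippet above (hedged: it assumes nabla is importable as nb and provides array and sum with NumPy-like behavior):

import nabla as nb  # assumed import alias and API names

def my_loss(x):
    return nb.sum(x * x)  # scalar output, as value_and_grad expects

x = nb.array([1.0, 2.0, 3.0])
value, grads = value_and_grad(my_loss)(x)
# value is 14.0; grads is [2.0, 4.0, 6.0], since d/dx sum(x**2) = 2x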

Usage as a decorator:

@value_and_grad
def my_loss(x):
    return x**2

value, grads = my_loss(3.0)  # Returns the (value, gradient) pair
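The remaining options compose as expected. The sketch below assumes JAX-compatible semantics for argnums and has_aux: the gradient becomes a tuple with one entry per selected argument, and the value becomes the (output, aux) pair.

def loss_with_aux(w, b):
    loss = (w * 3.0 + b) ** 2
    return loss, {"pre_square": w * 3.0 + b}  # (scalar output, aux) pair

vg = value_and_grad(loss_with_aux, argnums=(0, 1), has_aux=True)
(value, aux), (dw, db) = vg(2.0, 1.0)
# dw and db are the gradients with respect to w and b respectively.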

nabla.transforms.grad.grad(fun=None, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=(), mode='reverse')

Creates a function that evaluates the gradient of fun.

This is implemented as a special case of value_and_grad that returns only the gradient part, using VJP directly for efficiency with scalar outputs.
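In other words, grad can be viewed as a thin wrapper that discards the value. A minimal sketch, assuming value_and_grad behaves as documented above (has_aux plumbing elided for brevity):

def grad_sketch(fun, **kwargs):
    """Illustrative only: grad as value_and_grad minus the value."""
    vg = value_and_grad(fun, **kwargs)
    def wrapped(*args):
        _, grads = vg(*args)
        return grads  # drop the value, keep the gradient
    return wrapped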

Parameters:
  • fun (Callable | None) – Function to be differentiated. Should return a scalar.

  • argnums (int | Sequence[int]) – Which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Whether fun returns (output, aux) pair (default False).

  • holomorphic (bool) – Whether fun is holomorphic - currently ignored (default False).

  • allow_int (bool) – Whether to allow integer inputs - currently ignored (default False).

  • reduce_axes (Sequence) – Axes to reduce over - currently ignored (default ()).

  • mode (str) – Kept for API compatibility but ignored (always uses reverse-mode VJP).

Returns:

A function that computes the gradient of fun.

Examples

Basic usage as a function call:

grad_fn = grad(my_loss)
grads = grad_fn(x)

Usage as a decorator:

@grad
def my_loss(x):
    return x**2

grads = my_loss(3.0)  # Returns gradient, not function value
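As a closing usage sketch, grad slots directly into a simple training loop; this assumes plain Python floats are accepted as inputs, as in the decorator example above.

g = grad(lambda x: x**2)

x = 3.0
for _ in range(10):
    x = x - 0.1 * g(x)  # gradient step: x <- x - lr * 2x
# x decays geometrically toward the minimizer at 0.0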