Grad#
- nabla.transforms.grad.value_and_grad(fun=None, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
Creates a function that evaluates both the value and gradient of fun.
This function uses VJP (Vector-Jacobian Product) directly with a cotangent of ones_like(output) to compute gradients for scalar-valued functions. This is simpler and more efficient than using jacrev/jacfwd for scalar outputs.
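To make the mechanics concrete, here is a minimal standalone sketch of the same VJP-with-ones-cotangent pattern, written against JAX's public vjp API for illustration only (nabla's internal implementation may differ):

import jax
import jax.numpy as jnp

def value_and_grad_via_vjp(fun):
    # Sketch of the pattern described above, not nabla's actual code.
    def wrapped(x):
        value, vjp_fn = jax.vjp(fun, x)          # forward pass, builds the VJP
        (grads,) = vjp_fn(jnp.ones_like(value))  # cotangent = ones_like(output)
        return value, grads
    return wrapped

value, grads = value_and_grad_via_vjp(lambda x: jnp.sum(x ** 2))(jnp.arange(3.0))
# value == 5.0, grads == [0., 2., 4.]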
- Parameters:
fun (Callable | None) – Function to be differentiated. Should return a scalar.
argnums (int | Sequence[int]) – Which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Whether fun returns (output, aux) pair (default False).
holomorphic (bool) – Whether fun is holomorphic; currently ignored (default False).
allow_int (bool) – Whether to allow integer inputs; currently ignored (default False).
reduce_axes (Sequence) – Axes to reduce over; currently ignored (default ()).
- Returns:
A function that computes both the value and gradient of fun.
- Return type:
Callable
Examples
Basic usage as a function call:
value_and_grad_fn = value_and_grad(my_loss)
value, grads = value_and_grad_fn(x)
Usage as a decorator:
@value_and_grad
def my_loss(x):
    return x**2

value, grads = my_loss(3.0)
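The remaining arguments compose as in JAX. A hedged sketch of argnums and has_aux together, assuming nabla follows the JAX conventions its signature mirrors (a tuple of gradients for a sequence of argnums, and an ((output, aux), grads) return structure when has_aux=True):

def loss(w, b):
    pred = w * 2.0 + b
    return pred ** 2, {"pred": pred}  # (scalar loss, auxiliary data)

vag = value_and_grad(loss, argnums=(0, 1), has_aux=True)
(value, aux), (dw, db) = vag(1.0, 0.5)
# value == 6.25; dw == 10.0, db == 5.0 (assuming the JAX-style return structure)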
- nabla.transforms.grad.grad(fun=None, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=(), mode='reverse')[source]#
Creates a function that evaluates the gradient of fun.
This is implemented as a special case of value_and_grad that only returns the gradient part. Uses VJP directly for efficiency with scalar outputs.
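Under that description, grad can be viewed as a thin wrapper that drops the value. A hypothetical sketch of the relationship (the (value, aux) grouping assumes the JAX convention):

def grad_from_value_and_grad(fun, argnums=0, has_aux=False):
    # Hypothetical helper illustrating the docstring's claim; not nabla's code.
    vag = value_and_grad(fun, argnums=argnums, has_aux=has_aux)
    def grad_fn(*args, **kwargs):
        if has_aux:
            (_, aux), grads = vag(*args, **kwargs)
            return grads, aux
        _, grads = vag(*args, **kwargs)
        return grads
    return grad_fn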
- Parameters:
fun (Callable | None) – Function to be differentiated. Should return a scalar.
argnums (int | Sequence[int]) – Which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Whether fun returns (output, aux) pair (default False).
holomorphic (bool) – Whether fun is holomorphic; currently ignored (default False).
allow_int (bool) – Whether to allow integer inputs; currently ignored (default False).
reduce_axes (Sequence) – Axes to reduce over; currently ignored (default ()).
mode (str) – Kept for API compatibility but ignored (always uses reverse-mode VJP).
- Returns:
A function that computes the gradient of fun.
- Return type:
Callable
Examples
Basic usage as a function call:
grad_fn = grad(my_loss)
grads = grad_fn(x)
Usage as a decorator:
@grad
def my_loss(x):
    return x**2

grads = my_loss(3.0)  # Returns the gradient, not the function value
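In the decorator example above, my_loss(3.0) evaluates to 6.0, the derivative of x**2 at x = 3.0; if the function value (9.0) is also needed, wrap with value_and_grad instead.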