Grad Utils

Utilities for gradient computation and pytree handling.

nabla.utils.grad_utils.select_gradients_by_argnums(all_gradients, args, argnums)

Select gradients based on argnums, matching JAX behavior exactly.

JAX behavior:

1. Single argument: grad(func)(x) returns a gradient with the same structure as x.
2. Multiple arguments with a single integer argnums: grad(func, argnums=i)(x, y, z) returns the gradient with respect to argument i.
3. Multiple arguments with a sequence of argnums: returns a tuple of gradients, one per selected argument, in the order given.

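Note that argnums indexes the function's positional arguments, not elements within a single argument's structure. For example, if func takes one tuple argument x = (a, b), argnums=1 does not select the gradient for b.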

Parameters:
  • all_gradients (Any) – The gradients returned by VJP (preserves input structure)

  • args (tuple) – The original function arguments

  • argnums (int | Sequence[int]) – Which function arguments to compute gradients for

Returns:

The selected gradients: a single gradient if argnums is an int, or a tuple of gradients if argnums is a sequence

Return type:

Any
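The selection rules above can be sketched in plain Python. This is a minimal sketch, not nabla's implementation: it assumes that all_gradients is a tuple holding one gradient per positional argument, in the same order as args, and the function name select_by_argnums_sketch is hypothetical.

```python
from typing import Any, Sequence, Tuple, Union


def select_by_argnums_sketch(
    all_gradients: Tuple[Any, ...],
    args: Tuple[Any, ...],
    argnums: Union[int, Sequence[int]],
) -> Any:
    # Assumption: all_gradients[k] is the gradient pytree for args[k].
    n = len(args)
    indices = (argnums,) if isinstance(argnums, int) else tuple(argnums)
    # argnums indexes positional arguments, so each index must be < len(args).
    # This sketch only accepts non-negative indices.
    for i in indices:
        if not 0 <= i < n:
            raise ValueError(f"argnums index {i} out of range for {n} argument(s)")
    if isinstance(argnums, int):
        # A single int: return that argument's gradient unwrapped (rules 1 and 2).
        return all_gradients[argnums]
    # A sequence of ints: return a tuple in the order requested (rule 3).
    return tuple(all_gradients[i] for i in indices)
```

In this sketch, requesting argnums=1 for a function that takes a single tuple argument raises an error instead of picking the tuple's second element, which is the distinction the note on argnums draws.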

nabla.utils.grad_utils.validate_scalar_output(obj)

Validate that the function output is scalar-like for gradient computation.

Parameters:

obj (Any) – The function output to validate

Raises:

ValueError – If the output is not scalar-like
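The documentation does not define "scalar-like". The sketch below assumes a hypothetical criterion: Python numbers pass, and any object with a shape attribute passes only if it holds exactly one element. The name validate_scalar_output_sketch is also hypothetical, and nabla's actual check may be stricter (for example, requiring shape ()) or looser.

```python
import math
from typing import Any


def validate_scalar_output_sketch(obj: Any) -> None:
    # Hypothetical rule 1: plain Python numbers are treated as scalars.
    if isinstance(obj, (int, float)):
        return
    # Hypothetical rule 2: array-like objects must expose a shape...
    shape = getattr(obj, "shape", None)
    if shape is None:
        raise ValueError(
            f"Gradient requires a scalar output, got {type(obj).__name__}"
        )
    # ...and contain exactly one element; math.prod(()) == 1, so shape () passes.
    if math.prod(shape) != 1:
        raise ValueError(
            f"Gradient requires a scalar output, got shape {tuple(shape)}"
        )
```

Checking the output before running the backward pass lets the error name the offending shape, rather than surfacing later as a cotangent shape mismatch.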

nabla.utils.grad_utils.create_ones_like_cotangent(obj)

Create a cotangent by applying ones_like to each Array leaf in the structure.

Parameters:

obj (Any) – The object to create cotangent for

Returns:

Cotangent with the same structure as obj, with each Array leaf replaced by ones_like(leaf)

Return type:

Any
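For a scalar output, seeding the VJP with a cotangent of ones yields the gradient, which is the role this helper plays. The sketch below walks dicts, lists, and tuples recursively as a stand-in for nabla's own pytree handling; the name ones_like_cotangent_sketch and its is_array and ones_like parameters are hypothetical, and leaving non-Array leaves unchanged is an assumption.

```python
from typing import Any, Callable


def ones_like_cotangent_sketch(
    obj: Any,
    is_array: Callable[[Any], bool],
    ones_like: Callable[[Any], Any],
) -> Any:
    # Array leaves are replaced by whatever ones_like returns
    # (expected: ones with the leaf's shape and dtype).
    if is_array(obj):
        return ones_like(obj)
    # Containers are rebuilt with the same keys and order, preserving structure.
    if isinstance(obj, dict):
        return {k: ones_like_cotangent_sketch(v, is_array, ones_like) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        mapped = [ones_like_cotangent_sketch(v, is_array, ones_like) for v in obj]
        # Note: namedtuples come back as plain tuples in this sketch.
        return tuple(mapped) if isinstance(obj, tuple) else mapped
    # Assumption: non-Array leaves pass through unchanged.
    return obj
```

With NumPy, for instance, one could pass is_array=lambda x: isinstance(x, np.ndarray) and ones_like=np.ones_like.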