Grad Utils
Utilities for gradient computation and pytree handling.
- nabla.utils.grad_utils.select_gradients_by_argnums(all_gradients, args, argnums)
Select gradients based on argnums, matching JAX behavior exactly.
JAX behavior:
1. Single argument: grad(func)(x) returns a gradient with the same structure as x.
2. Multiple arguments: grad(func, argnums=i)(x, y, z) returns the gradient with respect to argument i.
3. Multiple arguments with multiple argnums (a tuple such as argnums=(0, 2)): returns a tuple of gradients, one per selected argument.
Key point: argnums indexes the function's positional arguments, not elements within a single argument's (possibly nested) structure.
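The selection rule above can be illustrated with a minimal, dependency-free sketch. It is not nabla's implementation: the name `select_gradients_sketch` is hypothetical, it omits the real function's `args` parameter, and it assumes `all_gradients` holds one gradient per positional argument, in argument order.

```python
from typing import Any, Sequence, Union


def select_gradients_sketch(
    all_gradients: Sequence[Any],
    argnums: Union[int, Sequence[int]] = 0,
) -> Any:
    """Hypothetical sketch: pick per-argument gradients the way JAX's argnums does.

    all_gradients holds one gradient (of any nested structure) per positional
    argument, in the same order as the arguments themselves.
    """
    n = len(all_gradients)

    def check(i: int) -> int:
        # argnums indexes function arguments, so it must be a valid argument position.
        if not -n <= i < n:
            raise ValueError(f"argnum {i} is out of range for a function with {n} argument(s)")
        return i

    if isinstance(argnums, int):
        # A single int returns that argument's gradient unchanged (same structure).
        return all_gradients[check(argnums)]
    # A sequence of ints returns a tuple, even when it has only one element.
    return tuple(all_gradients[check(i)] for i in argnums)


# Gradients for a call like f(params, y, z), where params is a dict.
grads = ({"w": 1.0, "b": 2.0}, 3.0, 4.0)
print(select_gradients_sketch(grads))                   # {'w': 1.0, 'b': 2.0}
print(select_gradients_sketch(grads, argnums=2))        # 4.0
print(select_gradients_sketch(grads, argnums=(0, 2)))   # ({'w': 1.0, 'b': 2.0}, 4.0)
```

Note that argnums=1 on a function with a single dict argument raises an error in this sketch instead of selecting the dict's second entry, which is exactly the distinction the key point above draws.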
- nabla.utils.grad_utils.validate_scalar_output(obj)
Validate that the function output is scalar-like for gradient computation.
- Parameters:
obj (Any) – The function output to validate
- Raises:
ValueError – If the output is not scalar-like
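The following is a minimal sketch of such a check, under a stated assumption: "scalar-like" is taken to mean a plain Python real number or an array-like object whose `.shape` is the empty tuple `()`. nabla's actual criterion may differ (for example, it might also accept shape `(1,)`), and the name `validate_scalar_output_sketch` is hypothetical.

```python
import numbers
from typing import Any


def validate_scalar_output_sketch(obj: Any) -> None:
    """Hypothetical sketch of a scalar-output check.

    Assumption: "scalar-like" means a plain Python real number or an
    array-like object whose .shape is the empty tuple ().
    """
    if isinstance(obj, numbers.Real):
        return  # plain Python int or float
    shape = getattr(obj, "shape", None)
    if shape is not None and tuple(shape) == ():
        return  # 0-d array, e.g. the result of summing a tensor
    raise ValueError(
        f"Gradient computation requires a scalar output, got "
        f"{type(obj).__name__} with shape {shape}"
    )
```

Checking the `.shape` attribute by duck typing keeps the sketch independent of any particular array library; a loss computed as a sum or mean passes, while an unreduced per-example output raises ValueError.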