Source code for nabla.utils.grad_utils
# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Utilities for gradient computation and pytree handling."""
from collections.abc import Sequence
from typing import Any, Union
[docs]
def select_gradients_by_argnums(
all_gradients: Any, args: tuple, argnums: Union[int, Sequence[int]]
) -> Any:
"""
Select gradients based on argnums, matching JAX behavior exactly.
JAX behavior:
1. Single argument: grad(func)(x) -> returns gradient with same structure as x
2. Multiple arguments: grad(func, argnums=i)(x, y, z) -> returns gradient w.r.t. arg i
3. Multiple arguments with multiple argnums: returns tuple of gradients
The key insight: argnums refers to the FUNCTION ARGUMENTS, not elements within
a single argument structure.
Parameters:
all_gradients: The gradients returned by VJP (preserves input structure)
args: The original function arguments
argnums: Which function arguments to compute gradients for
Returns:
Selected gradients matching JAX behavior exactly
"""
num_inputs = len(args)
# Normalize argnums to sequence for uniform handling
if isinstance(argnums, int):
argnums_seq = [argnums]
return_single = True
else:
argnums_seq = list(argnums)
return_single = False
# Validate argnums are within bounds
for idx in argnums_seq:
if idx < 0 or idx >= num_inputs:
raise ValueError(
f"argnum {idx} is out of bounds for function with {num_inputs} arguments"
)
if num_inputs == 1:
# Single input case - return full gradient structure for argnum=0
# This matches JAX: grad(func)(pytree) -> returns gradient with same structure
if argnums_seq == [0]:
return all_gradients if return_single else (all_gradients,)
else:
# Invalid argnum for single input
invalid = [idx for idx in argnums_seq if idx != 0]
raise ValueError(
f"argnums {invalid} are out of bounds for function with 1 argument"
)
else:
# Multiple input case - select gradients by argument position
# all_gradients is a tuple with one gradient per argument
selected = [all_gradients[i] for i in argnums_seq]
return selected[0] if return_single else tuple(selected)
[docs]
def validate_scalar_output(obj: Any) -> None:
"""
Validate that the function output is scalar-like for gradient computation.
Parameters:
obj: The function output to validate
Raises:
ValueError: If the output is not scalar-like
"""
from ..core.array import Array
if isinstance(obj, Array):
# JAX behavior: allow both () and (1,) shapes as "scalar-like"
if obj.shape != () and obj.shape != (1,):
raise ValueError(
f"Gradient only defined for scalar-output functions. "
f"Found array with shape: {obj.shape}"
)
elif isinstance(obj, list | tuple):
for item in obj:
validate_scalar_output(item)
elif isinstance(obj, dict):
for value in obj.values():
validate_scalar_output(value)
else:
# Handle non-Array outputs (like numpy arrays, Python scalars)
import numpy as np
test_array = np.asarray(obj)
if test_array.shape != () and test_array.shape != (1,):
raise ValueError(
f"Gradient only defined for scalar-output functions. "
f"Found non-scalar with shape: {test_array.shape}"
)
[docs]
def create_ones_like_cotangent(obj: Any) -> Any:
"""
Create a cotangent with ones_like for each Array leaf in the structure.
Parameters:
obj: The object to create cotangent for
Returns:
Cotangent with same structure but ones_like for Array leaves
"""
from ..core.array import Array
from ..ops.creation import ones_like
if isinstance(obj, Array):
return ones_like(obj)
elif isinstance(obj, list | tuple):
return type(obj)(create_ones_like_cotangent(item) for item in obj)
elif isinstance(obj, dict):
return {k: create_ones_like_cotangent(v) for k, v in obj.items()}
else:
# For non-Array leaves, we don't need cotangents
return obj