# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from collections.abc import Callable
from typing import Any

from .utils import (
_extract_arrays_from_pytree,
_std_basis,
make_traced_pytree,
make_untraced_pytree,
)
from .vjp import vjp
from .vmap import vmap


def jacrev(
func: Callable[..., Any],
argnums: int | tuple[int, ...] | list[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
) -> Callable[..., Any]:
"""Compute the Jacobian of a function using reverse-mode autodiff.
Args:
func: Function to differentiate (should take positional arguments)
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether `func` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether `func` is promised to be
holomorphic. Default False. Currently ignored.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. Currently ignored.
Returns:
A function with the same arguments as `func`, that evaluates the Jacobian of
`func` using reverse-mode automatic differentiation. If `has_aux` is True
then a pair of (jacobian, auxiliary_data) is returned.
Note:
This follows JAX's jacrev API:
- Only accepts positional arguments
- For functions requiring keyword arguments, use functools.partial or lambda
- Returns the Jacobian as a pytree structure matching the input structure
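
    Example:
        Sketch only: `x` and `y` stand for Nabla arrays created elsewhere, and
        `*` / `+` are assumed to be overloaded element-wise operations.

            def f(x, y):
                return x * y + x

            jac_x = jacrev(f)(x, y)                   # Jacobian w.r.t. `x` only
            jac_xy = jacrev(f, argnums=(0, 1))(x, y)  # Jacobians w.r.t. `x` and `y`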
"""

    def jacrev_fn(*args: Any) -> Any:
# Normalize argnums to a tuple of integers
selected_argnums = (argnums,) if isinstance(argnums, int) else tuple(argnums)
# Validate argnums
for argnum in selected_argnums:
if argnum >= len(args) or argnum < -len(args):
raise ValueError(
f"argnum {argnum} is out of bounds for function with {len(args)} arguments"
)
# Normalize negative indices
normalized_argnums = tuple(
argnum if argnum >= 0 else len(args) + argnum for argnum in selected_argnums
)
# Extract the arguments to differentiate with respect to
diff_args = tuple(args[i] for i in normalized_argnums)
# Create a function that takes only the differentiated arguments
def partial_func(*diff_args_inner):
# Reconstruct the full argument list
full_args = list(args)
for i, arg in zip(normalized_argnums, diff_args_inner, strict=False):
full_args[i] = arg
return func(*full_args)
# Compute VJP - delegate has_aux handling to vjp
vjp_result = vjp(partial_func, *diff_args, has_aux=has_aux)
if has_aux:
y, pullback, aux = vjp_result # type: ignore
else:
y, pullback = vjp_result # type: ignore
# Flatten output arrays for std_basis generation
flat_y = _extract_arrays_from_pytree(y)
if not isinstance(flat_y, list):
flat_y = [flat_y]
# Generate standard basis vectors and get sizes for split operations
sizes, std_basis_vectors = _std_basis(flat_y) # type: ignore
std_basis_flat = _extract_arrays_from_pytree(std_basis_vectors)
if not isinstance(std_basis_flat, list):
std_basis_flat = [std_basis_flat]
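        # Strategy note: each row of the standard basis is a one-hot cotangent for a
        # single scalar element of the flattened output, so vmapping the pullback over
        # these rows evaluates one row of the Jacobian per output element in one call.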
# Handle mixed scalar/tensor outputs by creating appropriate in_axes specification
if all(arr.shape == () for arr in std_basis_flat):
# All outputs are scalar - use in_axes=None to broadcast
grads = vmap(pullback, in_axes=None)(std_basis_vectors)
elif any(arr.shape == () for arr in std_basis_flat):
# Mixed scalar/tensor outputs - create in_axes specification for each element
# Note: std_basis_vectors is a list/tuple, so in_axes should match that structure
if isinstance(std_basis_vectors, list | tuple):
in_axes_spec = [
None if arr.shape == () else 0 for arr in std_basis_flat
]
grads = vmap(pullback, in_axes=in_axes_spec)(std_basis_vectors)
else:
# Single element case - shouldn't happen with mixed outputs, but handle for completeness
in_axes_spec = None if std_basis_flat[0].shape == () else 0
grads = vmap(pullback, in_axes=in_axes_spec)(std_basis_vectors)
else:
# All outputs are tensors - use in_axes=0 to vectorize along the first axis
grads = vmap(pullback)(std_basis_vectors)
# CRITICAL: Check if std_basis_vectors were traced (indicating composition with other transformations)
std_basis_arrays = _extract_arrays_from_pytree(std_basis_vectors)
any_std_basis_traced = any(
getattr(arr, "traced", False) for arr in std_basis_arrays
)
# Make grads traced to capture subsequent operations in the computation graph
if not any_std_basis_traced:
# Only make traced if original std_basis wasn't traced (avoid double tracing)
grads = make_traced_pytree(grads)
# Import split function for proper jacobian structuring
from ..ops.view import reshape, split
# Extract flat input arguments for reshaping
flat_diff_args = _extract_arrays_from_pytree(diff_args)
splits = []
for i in range(len(flat_diff_args)): # For each input argument
if isinstance(grads, list) and len(grads) > 0:
if isinstance(grads[0], tuple):
# Multiple inputs: extract i-th input's gradients from each batch
input_grads = grads[0][i] # All batched gradients for input i
else:
# Single input case
input_grads = grads[0] if len(flat_diff_args) == 1 else grads[i]
else:
# Direct case
input_grads = grads[i] if isinstance(grads, tuple) else grads
# Split this input's gradients by output components (now traced!)
splits.append(split(input_grads, sizes=sizes, axis=0)) # type: ignore
# Reshape jacobian components to proper out_shape + arg_shape format (now traced!)
cotangents = []
for j in range(len(flat_y)): # For each output component
arg_jacs = []
for i in range(len(flat_diff_args)): # For each input argument
grad = splits[i][j] # j-th output component for i-th input
batch_dims = flat_y[j].batch_dims
out_shape = flat_y[j].shape
arg_shape = flat_diff_args[i].shape
# print("out_shape:", out_shape, "in_shape:", arg_shape)
# Only remove (1,) from output shape when we have batch dimensions (from vmap)
# This handles the case where scalar functions return (1,) instead of ()
if len(batch_dims) > 0 and len(out_shape) == 1 and out_shape[0] == 1:
out_shape = ()
# Never remove (1,) from arg_shape - it represents valid jacobian structure
# Jacobian shape should be output_shape + input_shape
target_shape = out_shape + arg_shape
reshaped_grad = reshape(grad, target_shape) # Now traced!
arg_jacs.append(reshaped_grad)
if len(arg_jacs) == 1:
arg_jacs = arg_jacs[0] # Single input case, return single jacobian
cotangents.append(arg_jacs)
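        # At this point `cotangents[j]` holds the Jacobian block(s) for the j-th output
        # component: a single array when only one argument is differentiated, otherwise
        # a list with one entry per selected argument.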
final_jac = cotangents
if len(cotangents) == 1:
final_jac = cotangents[0]
# Make final jacobian untraced unless we're in a composition context
if not any_std_basis_traced:
make_untraced_pytree(final_jac)
# print("\nEND JACREV FN\n")
if not has_aux:
return final_jac
else:
return final_jac, aux
return jacrev_fn