# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
from collections.abc import Callable
from typing import Any
from .jvp import jvp
from .utils import (
_extract_arrays_from_pytree,
_std_basis,
)
from .vmap import vmap
[docs]
def jacfwd(
func: Callable[..., Any],
argnums: int | tuple[int, ...] | list[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
) -> Callable[..., Any]:
"""
Prototype implementation of jacfwd using forward-mode autodiff.
This computes the Jacobian using the pattern:
vmap(jvp(func, primals, tangents), in_axes=(primal_axes, tangent_axes))
where primal_axes are None (broadcast) and tangent_axes are 0 (vectorize).
Args:
func: Function to differentiate
argnums: Which arguments to differentiate with respect to
has_aux: Whether function returns auxiliary data
holomorphic: Ignored (for JAX compatibility)
allow_int: Ignored (for JAX compatibility)
Returns:
Function that computes the Jacobian using forward-mode autodiff
"""
def jacfwd_fn(*args: Any) -> Any:
# print(f"\n=== JACFWD PROTOTYPE ===")
# print(f"Input args shapes: {[arg.shape if hasattr(arg, 'shape') else type(arg).__name__ for arg in args]}")
# Normalize argnums to a tuple of integers (same as jacrev)
selected_argnums = (argnums,) if isinstance(argnums, int) else tuple(argnums)
# Validate argnums (same as jacrev)
for argnum in selected_argnums:
if argnum >= len(args) or argnum < -len(args):
raise ValueError(
f"argnum {argnum} is out of bounds for function with {len(args)} arguments"
)
# Normalize negative indices (same as jacrev)
normalized_argnums = tuple(
argnum if argnum >= 0 else len(args) + argnum for argnum in selected_argnums
)
# print(f"Differentiating w.r.t. arguments: {normalized_argnums}")
# Extract the arguments to differentiate with respect to (same as jacrev)
diff_args = tuple(args[i] for i in normalized_argnums)
# print(f"Diff args shapes: {[arg.shape for arg in diff_args]}")
# Create a function that takes only the differentiated arguments (same as jacrev)
def partial_func(*diff_args_inner):
# Reconstruct the full argument list
full_args = list(args)
for i, arg in zip(normalized_argnums, diff_args_inner, strict=False):
full_args[i] = arg
return func(*full_args)
# Generate standard basis vectors for the INPUT arguments (key difference from jacrev)
flat_diff_args = _extract_arrays_from_pytree(diff_args)
if not isinstance(flat_diff_args, list):
flat_diff_args = [flat_diff_args]
# print(f"Flat diff args shapes: {[arg.shape for arg in flat_diff_args]}")
# Create standard basis vectors for inputs (this is the key difference from jacrev)
sizes, std_basis_vectors = _std_basis(flat_diff_args) # type: ignore
# print(f"Standard basis sizes: {sizes}")
# print(f"Standard basis vectors shape: {std_basis_vectors[0].shape if std_basis_vectors else 'None'}") # Create the JVP function that we'll vmap over
# This function takes the individual arguments from diff_args + one tangent per input
def jvp_func(*args):
"""
JVP function that computes output tangents.
For single input: args = (primal, tangent_vector)
For multi-input: args = (primal1, primal2, ..., tangent1, tangent2, ...)
The tangent vectors come from _std_basis and are already properly shaped.
"""
num_primals = len(diff_args)
primals = args[:num_primals] # First N arguments are primals
tangent_vectors = args[num_primals:] # Last N arguments are tangents
if len(primals) == 1:
# Single input case
tangents_tuple = tangent_vectors[0]
primals_tuple = primals[0]
else:
# Multi-input case
tangents_tuple = tuple(tangent_vectors)
primals_tuple = tuple(primals)
# Compute JVP: jvp(partial_func, primals, tangents)
jvp_result = jvp(partial_func, primals_tuple, tangents_tuple)
primal_out, tangent_out = jvp_result # type: ignore
return tangent_out # Return tangent output directly
# Create in_axes: None for each primal argument, 0 for each tangent vector
primals_axes = tuple(None for _ in diff_args) # Broadcast all primal arguments
tangents_axes = tuple(
0 for _ in std_basis_vectors
) # Vectorize all tangent arguments
vmap_in_axes = primals_axes + tangents_axes
# Apply vmap to vectorize the JVP computation
# print(f"vmap in_axes: {vmap_in_axes}")
vmap_jvp = vmap(jvp_func, in_axes=vmap_in_axes)
output_tangents = vmap_jvp(*diff_args, *std_basis_vectors)
from nabla.ops.view import reshape, split
# Get output structure by running the function once
test_output = partial_func(*diff_args)
flat_output = _extract_arrays_from_pytree(test_output)
if not isinstance(flat_output, list):
flat_output = [flat_output]
# Split the output tangents by the sizes from _std_basis
split_tangents = split(output_tangents, sizes=sizes, axis=0)
# print("\n\nSPLIT TANGENTS")
# print(split_tangents)
# print("\n\n")
jacobian_components = []
for _j, (arg, tangents_for_arg) in enumerate(
zip(flat_diff_args, split_tangents, strict=False)
):
output_shape = flat_output[0].shape
arg_shape = arg.shape
# Reshape to proper Jacobian format: output_shape + input_shape
target_shape = arg_shape + output_shape
jacobian_component = reshape(tangents_for_arg, target_shape)
# reshaped_grad = grad.reshape(shape)
perm_axes = []
for k in range(len(output_shape)):
perm_axes.append(k + len(arg_shape))
for k in range(len(arg_shape)):
perm_axes.append(k)
from ..ops.view import permute
jacobian_component = permute(jacobian_component, tuple(perm_axes))
jacobian_components.append(jacobian_component)
# Return as tuple for multiple inputs
if len(jacobian_components) == 1:
jacobian_components = jacobian_components[0]
jacobian = jacobian_components
if not has_aux:
return jacobian
else:
# TODO: Handle auxiliary data properly
return jacobian, None
return jacfwd_fn