# ===----------------------------------------------------------------------=== #
# Nabla 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Core transformations for automatic differentiation and tracing."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from ..core.array import Array
[docs]
def tree_flatten(tree: Any) -> tuple[list[Array], Any]:
    """Flatten a pytree into a list of Arrays and structure info.

    The tree is walked depth-first. Every ``Array`` is appended to the leaf
    list and replaced by ``None`` in the returned structure. Dicts, lists and
    tuples (including namedtuples) are rebuilt with their children processed
    recursively. Any other object is copied into the structure unchanged.

    Args:
        tree: A pytree containing Arrays and other structures

    Returns:
        A tuple of (list of Array leaves, structure info for reconstruction).
        Dict entries are visited, and stored, in sorted-key order.

    Note:
        ``None`` doubles as the Array placeholder, so a literal ``None`` stored
        in ``tree`` cannot be told apart from an Array position in the
        returned structure.
    """
    leaves: list[Array] = []

    def _flatten(obj: Any) -> Any:
        if isinstance(obj, Array):
            leaves.append(obj)
            return None  # Placeholder marking an Array position
        if isinstance(obj, dict):
            # Sorted keys give a deterministic leaf order that does not
            # depend on insertion order.
            return {key: _flatten(obj[key]) for key in sorted(obj)}
        if isinstance(obj, (list, tuple)):
            children = [_flatten(item) for item in obj]
            if isinstance(obj, tuple) and hasattr(obj, "_fields"):
                # A namedtuple constructor takes one positional argument per
                # field, so it cannot be called with a single iterable.
                return type(obj)._make(children)
            return type(obj)(children)
        # Non-Array leaf (int, float, etc.)
        return obj

    return leaves, _flatten(tree)
[docs]
def tree_unflatten(structure: Any, leaves: list[Array]) -> Any:
    """Reconstruct a pytree from structure info and list of Arrays.

    ``structure`` is walked depth-first. Each ``None`` is replaced by the next
    value from ``leaves``; dicts, lists and tuples (including namedtuples) are
    rebuilt recursively; every other object is returned unchanged.

    Args:
        structure: Structure info from tree_flatten
        leaves: List of Array values to place at Array positions

    Returns:
        Reconstructed pytree with the same structure as the original

    Raises:
        ValueError: If ``leaves`` holds fewer or more values than there are
            ``None`` placeholders in ``structure``.
    """
    leaves_iter = iter(leaves)
    # Unique marker for "iterator exhausted"; unlike None it can never be
    # confused with a real leaf value.
    exhausted = object()

    def _unflatten(struct: Any) -> Any:
        if struct is None:  # Array placeholder
            leaf = next(leaves_iter, exhausted)
            if leaf is exhausted:
                raise ValueError(
                    f"Tree unflatten error: Not enough leaves. Expected structure: {structure}, Got {len(leaves)} leaves"
                )
            return leaf
        if isinstance(struct, dict):
            return {k: _unflatten(v) for k, v in struct.items()}
        if isinstance(struct, (list, tuple)):
            children = [_unflatten(item) for item in struct]
            if isinstance(struct, tuple) and hasattr(struct, "_fields"):
                # namedtuple: one positional argument per field
                return type(struct)._make(children)
            return type(struct)(children)
        # Non-Array leaf
        return struct

    result = _unflatten(structure)

    # Verify we consumed all leaves
    if next(leaves_iter, exhausted) is not exhausted:
        raise ValueError("Too many leaves provided for tree structure")
    return result
[docs]
def tree_map(func: Callable[[Array], Array], tree: Any) -> Any:
    """Apply a function to every Array leaf of a pytree.

    The tree is flattened, ``func`` is called once per Array leaf in
    flattening order, and the results are put back into the same structure.

    Args:
        func: Function to apply to each Array leaf
        tree: Pytree containing Arrays

    Returns:
        Pytree with the same structure but transformed Arrays
    """
    flat_arrays, treedef = tree_flatten(tree)
    return tree_unflatten(treedef, list(map(func, flat_arrays)))
def _extract_arrays_from_pytree(tree: Any) -> list[Array]:
    """Collect the Array leaves of a pytree.

    Args:
        tree: Pytree that may contain Arrays, ints, floats, etc.

    Returns:
        The Arrays found in the tree, in ``tree_flatten`` order; the structure
        information is discarded.
    """
    return tree_flatten(tree)[0]
def _validate_length_match(list1, list2, name1, name2):
"""Check if two lists have the same length."""
if len(list1) != len(list2):
raise ValueError(f"{name1} length {len(list1)} != {name2} length {len(list2)}")
def _std_basis(args: list[Array]) -> tuple[list[int], list[Array]]:
    """Build standard-basis tangents for a list of Arrays.

    Let ``total`` be the number of elements summed over all ``args``. Each
    argument owns a run of consecutive rows in a basis of ``total`` rows, one
    row per element, with the runs laid out in argument order.

    Args:
        args: Arrays to build basis tangents for.

    Returns:
        A tuple ``(sizes, tangents)`` with one entry per argument:

        * For an argument whose shape is ``()``, the size is ``1`` and the
          tangent is ``ones_like(arg)``.
        * Otherwise the size is the argument's element count and the tangent
          has shape ``(total,) + arg.shape``. It is zero everywhere except
          for a single one per owned row, located at that row's own element
          (in flattened order). The tangent is then passed through
          ``broadcast_batch_dims`` with ``arg.batch_dims``.
    """
    from math import prod

    # Element count of every argument, computed once and reused below.
    element_counts = [prod(arg.shape) for arg in args]
    num_total_arg_elements = sum(element_counts)

    # Index of the first basis row owned by the argument being processed.
    batch_ctr = 0
    sizes: list[int] = []
    tangents: list[Array] = []

    for arg, num_elements in zip(args, element_counts):
        if arg.shape == ():
            from ..ops.creation import ones_like

            tangents.append(ones_like(arg))
            sizes.append(1)
            batch_ctr += 1
            continue

        from numpy import zeros as np_zeros

        from ..ops.view import broadcast_batch_dims

        batched_shape = (num_total_arg_elements,) + arg.shape
        # Work on a (row, flattened element) view; it is reshaped to
        # ``batched_shape`` once the ones are in place.
        np_tangent = np_zeros(
            (num_total_arg_elements, num_elements), dtype=arg.dtype.to_numpy()
        )
        for j in range(num_elements):
            np_tangent[batch_ctr + j, j] = 1.0
        batch_ctr += num_elements

        tangent = Array.from_numpy(np_tangent.reshape(batched_shape))
        tangent = broadcast_batch_dims(tangent, arg.batch_dims)
        tangents.append(tangent)
        sizes.append(num_elements)

    return sizes, tangents
[docs]
def make_traced_pytree(tree: Any) -> Any:
    """Return a pytree of traced shallow copies of the Arrays in ``tree``.

    Every Array leaf is replaced by ``shallow_copy(leaf)`` with its ``traced``
    attribute set to ``True``; the Arrays passed in are not modified here.

    Args:
        tree: Pytree containing Arrays to copy and mark as traced

    Returns:
        Pytree with the same structure but traced Arrays
    """

    def _traced_clone(source: Array) -> Array:
        from ..ops.view import shallow_copy

        clone = shallow_copy(source)
        clone.traced = True
        return clone

    return tree_map(_traced_clone, tree)
[docs]
def make_untraced_pytree(tree: Any) -> None:
    """Disable tracing for arrays in a pytree by clearing their traced flag.

    The Arrays are modified in place (``traced = False``). Only the leaves are
    visited; the tree itself is not rebuilt, so non-Array entries of any kind
    are simply ignored.

    Args:
        tree: Pytree containing Arrays to disable tracing for
    """
    leaves, _ = tree_flatten(tree)
    for array in leaves:
        array.traced = False
[docs]
def make_staged_pytree(args: list[Array]) -> None:
    """Enable staged execution for arrays.

    Sets ``stage_realization = True`` in place on every Array leaf of
    ``args``. Only the leaves are visited; the container is not rebuilt, so
    non-Array entries are ignored.

    Args:
        args: Arrays (a list or any pytree of Arrays) to enable staged
            execution for
    """
    leaves, _ = tree_flatten(args)
    for array in leaves:
        array.stage_realization = True
[docs]
def make_unstaged_pytree(args: list[Array]) -> None:
    """Disable staged execution for arrays.

    Sets ``stage_realization = False`` in place on every Array leaf of
    ``args``. Only the leaves are visited; the container is not rebuilt, so
    non-Array entries are ignored.

    Args:
        args: Arrays (a list or any pytree of Arrays) to disable staged
            execution for
    """
    leaves, _ = tree_flatten(args)
    for array in leaves:
        array.stage_realization = False
def _handle_args_consistently(args):
"""Handle both fn([x,y,z]) and fn(x,y,z) calling styles."""
if len(args) == 1 and isinstance(args[0], list):
return args[0], True
return args, False
def _prepare_traced_inputs(
    actual_args, is_list_style, apply_staging=False, with_conversion=False
):
    """Prepare traced inputs for list-style or pytree-style arguments.

    Args:
        actual_args: The arguments to trace (a list for list style, otherwise
            the positional-argument container).
        is_list_style: Whether the list-style calling convention is in use.
        apply_staging: If true, also enable staged execution on the traced
            copies.
        with_conversion: If true, ``int``/``float``/``bool`` values nested in
            lists, tuples and dicts are first turned into Arrays with
            ``nabla.array``; all other non-Array objects are kept as they are.

    Returns:
        ``(traced_args, traced_inputs_pytree)``. For list style the second
        item is ``None``. Otherwise, when there is exactly one argument,
        ``traced_args`` is a 1-tuple wrapping the traced copy of that
        argument; with several arguments both items are the same traced
        pytree.
    """

    def _scalars_to_arrays(value):
        if isinstance(value, Array):
            return value
        if isinstance(value, list | tuple):
            return type(value)(_scalars_to_arrays(element) for element in value)
        if isinstance(value, dict):
            return {key: _scalars_to_arrays(entry) for key, entry in value.items()}
        if isinstance(value, (int, float, bool)):
            # Only convert basic scalar types to Nabla Arrays
            import nabla as nb

            return nb.array(value)
        # Keep everything else (functions, numpy arrays, etc.) unchanged
        return value

    if with_conversion:
        actual_args = _scalars_to_arrays(actual_args)

    if is_list_style:
        traced_list = make_traced_pytree(actual_args)
        if apply_staging:
            make_staged_pytree(traced_list)
        return traced_list, None

    # Objects without __len__ are treated like a single argument.
    arg_count = len(actual_args) if hasattr(actual_args, "__len__") else 1  # type: ignore
    single_argument = arg_count == 1

    traced_inputs_pytree = make_traced_pytree(
        actual_args[0] if single_argument else actual_args
    )
    traced_args = (traced_inputs_pytree,) if single_argument else traced_inputs_pytree

    if apply_staging:
        # Apply staging to the TRACED arrays, not the original args
        make_staged_pytree(_extract_arrays_from_pytree(traced_args))

    return traced_args, traced_inputs_pytree
def _clean_traced_outputs(outputs, is_list_style, remove_staging=False):
    """Clear the traced flag (and optionally the staging flag) on outputs.

    Args:
        outputs: Pytree of output Arrays; its Arrays are modified in place.
        is_list_style: Calling convention of the caller. Both conventions are
            cleaned up the same way; the parameter is kept so existing call
            sites keep working.
        remove_staging: If true, staged execution is also disabled on every
            output Array.

    Returns:
        The ``outputs`` object that was passed in.
    """
    make_untraced_pytree(outputs)
    if remove_staging:
        make_unstaged_pytree(_extract_arrays_from_pytree(outputs))
    return outputs
[docs]
class Trace:
    """A simple trace container that holds the computation graph.

    A trace is described by its input and output Arrays. The nodes in between
    are discovered lazily by following each output's ``args`` and are cached in
    ``self.trace`` in topological order (arguments before the nodes that use
    them).
    """

    def __init__(self, inputs: list[Array], outputs: list[Array] | None = None) -> None:
        """Create a trace and flag every input as traced.

        Args:
            inputs: Arrays acting as inputs of the traced computation.
            outputs: Arrays produced by the computation; an empty list is used
                when omitted.
        """
        self.inputs = inputs
        self.outputs = outputs if outputs is not None else []
        self.trace: list[Array] = []
        self._computed = False
        # Mark all inputs as traced for autodiff so the computation graph gets captured
        for inp in inputs:
            inp.traced = True

    @classmethod
    def trace_function(
        cls, fn: Callable[[list[Array]], list[Array]], inputs: list[Array]
    ) -> Trace:
        """Create a trace by executing a function with tracing enabled.

        This is the recommended way to create traces as it ensures proper
        tracing setup before function execution.

        Args:
            fn: Function to run; it receives traced copies of ``inputs``.
            inputs: Arrays (or pytree of Arrays) to trace from.

        Returns:
            A trace whose outputs are the Arrays found in ``fn``'s result.
            The traced input copies and the outputs are untraced again before
            the trace is returned.
        """
        inputs = make_traced_pytree(inputs)
        # Create trace instance (this marks inputs as traced)
        trace = cls(inputs)
        # Execute function with tracing enabled
        outputs = fn(inputs)
        # Extract Arrays from outputs and store as list
        output_arrays = _extract_arrays_from_pytree(outputs)
        trace.outputs = output_arrays
        make_untraced_pytree(inputs)  # Detach inputs from the trace
        # Handle outputs properly - make them untraced
        make_untraced_pytree(output_arrays)
        return trace

    def get_traced_nodes(self) -> list[Array]:
        """Get all nodes that belong to this trace in topological order.

        The ordering is computed on first use and cached afterwards.
        """
        if not self._computed:
            self._compute_trace()
        return self.trace

    def _compute_trace(self) -> None:
        """Compute the topological ordering of traced nodes."""
        visited: set[Array] = set()
        self.trace = []
        for output in self.outputs:
            self._dfs_visit(output, visited)
        self._computed = True

    def _dfs_visit(self, node: Array, visited: set[Array]) -> None:
        """Append ``node`` and its not-yet-visited ancestors in post-order.

        The traversal uses an explicit stack instead of recursion so that the
        depth of the graph is not limited by the interpreter's recursion
        limit.

        Args:
            node: Node to start from.
            visited: Nodes already emitted; updated in place.

        Raises:
            RuntimeError: If following ``args`` leads back to a node that is
                still being processed (a cycle).
        """
        if node in visited:
            return
        in_progress = {node}
        stack = [(node, iter(node.args))]
        while stack:
            current, pending_args = stack[-1]
            for arg in pending_args:
                if arg in visited:
                    continue
                if arg in in_progress:
                    raise RuntimeError(
                        "Cycle detected in computation graph while tracing"
                    )
                # Descend into this argument first; the partially consumed
                # iterator on the stack remembers where to resume.
                in_progress.add(arg)
                stack.append((arg, iter(arg.args)))
                break
            else:
                # All arguments are done: emit the node (post-order).
                stack.pop()
                in_progress.discard(current)
                visited.add(current)
                self.trace.append(current)

    @staticmethod
    def _var_name(index: int) -> str:
        """Return the ``index``-th variable name: a..z, aa..zz, aaa, ..."""
        alphabet = "abcdefghijklmnopqrstuvwxyz"
        letters = []
        remaining = index + 1
        while remaining > 0:
            remaining, offset = divmod(remaining - 1, len(alphabet))
            letters.append(alphabet[offset])
        return "".join(reversed(letters))

    def __str__(self) -> str:
        """Return a JAX-like string representation of the trace."""
        if not self._computed:
            self._compute_trace()
        from ..utils.formatting import format_shape_and_dtype

        # Names are keyed by object identity and handed out in order of first
        # appearance: inputs first, then traced nodes.
        var_names: dict[int, str] = {}
        names_issued = 0

        def _assign_name(array: Array) -> str:
            nonlocal names_issued
            name = self._var_name(names_issued)
            names_issued += 1
            var_names[id(array)] = name
            return name

        # Assign names to inputs first
        input_vars = [
            f"{_assign_name(inp)}:{format_shape_and_dtype(inp)}"
            for inp in self.inputs
        ]

        # Single pass through trace: assign names and build equations
        equations = []
        for node in self.trace:
            # Skip if this is an input (already processed)
            if id(node) in var_names:
                continue
            var_name = _assign_name(node)

            if node.name:
                op_name = node.name
            elif (
                isinstance(node, Array)
                and node.shape == ()
                and not node.batch_dims
                and node.impl
            ):
                # This is a constant scalar, show the raw value
                op_name = str(node.to_numpy().item())
            else:
                op_name = "external_const"

            equation = f" {var_name}:{format_shape_and_dtype(node)} = {op_name}"
            if node.args:
                # Arguments without a name come from outside the trace.
                arg_vars = [
                    var_names.get(id(arg), "external_const") for arg in node.args
                ]
                equation = f"{equation} {' '.join(arg_vars)}"
            equations.append(equation)

        output_vars = [var_names.get(id(out), "?") for out in self.outputs]

        # A single output is shown bare; zero or several are shown as a tuple.
        input_sig = f"({', '.join(input_vars)})"
        if len(output_vars) == 1:
            output_sig = output_vars[0]
        else:
            output_sig = f"({', '.join(output_vars)})"

        lines = [
            f"{{ lambda {input_sig} ;",
            " let",
            *equations,
            f" in {output_sig} }}",
        ]
        return "\n".join(lines)
def _cleanup_cotangents(traced_nodes: list[Array]) -> None:
"""Clean up cotangent values from traced nodes.
Args:
traced_nodes: List of traced nodes to clean up
"""
for node in traced_nodes:
node.cotangent = None
def _compute_pullback(
    input_arrays: list[Array],
    output_arrays: list[Array],
    cotangent_arrays: list[Array],
) -> list[Array]:
    """Core reverse-mode gradient computation.

    The traced nodes are walked in reverse topological order. For every node
    that carries a cotangent, has arguments and has a ``vjp_rule``, the rule's
    results are stored on, or added to, the arguments' ``cotangent``
    attributes. All cotangents written to traced nodes are reset to ``None``
    before returning, also when an error occurs.

    Args:
        input_arrays: Input arrays to compute gradients for
        output_arrays: Output arrays from the computation
        cotangent_arrays: Cotangent vectors for outputs

    Returns:
        List of gradient arrays corresponding to inputs; inputs that received
        no cotangent get ``zeros_like(input)``.

    Raises:
        RuntimeError: If a node's VJP rule, or the accumulation of its
            results, fails. The original exception is attached as the cause.
    """
    from ..ops.binary import add
    from ..ops.creation import zeros_like

    # Build computation trace
    trace = Trace(input_arrays, output_arrays)
    traced_nodes = trace.get_traced_nodes()

    # Identity-based lookup: constant time per node, and independent of how
    # Array implements ``==``.
    input_ids = {id(inp) for inp in input_arrays}

    try:
        # Initialize output cotangents
        for output, cotangent in zip(output_arrays, cotangent_arrays, strict=False):
            output.cotangent = cotangent

        # Reverse-mode gradient computation
        for node in reversed(traced_nodes):
            if node.cotangent is None:
                continue
            if not node.args or node.vjp_rule is None:
                continue
            try:
                arg_cotangents = node.vjp_rule(node.args, node.cotangent, node)
                for arg, arg_cotangent in zip(node.args, arg_cotangents, strict=False):
                    if arg.cotangent is None:
                        arg.cotangent = arg_cotangent
                    else:
                        # The argument feeds several consumers: sum contributions.
                        arg.cotangent = add(arg.cotangent, arg_cotangent)
                # The node's own cotangent is no longer needed once it has
                # been propagated, unless it is a requested input.
                if id(node) not in input_ids:
                    node.cotangent = None
            except Exception as e:
                raise RuntimeError(
                    f"VJP rule failed for operation '{node.name}': {e}"
                ) from e

        # Collect gradients for input arrays
        return [
            inp.cotangent if inp.cotangent is not None else zeros_like(inp)
            for inp in input_arrays
        ]
    finally:
        _cleanup_cotangents(traced_nodes)
def _reconstruct_gradient_structure(
    gradient_arrays: list[Array],
    inputs: Any,
) -> Any:
    """Reconstruct gradients in the same structure as inputs.

    Args:
        gradient_arrays: Flat list of gradient arrays
        inputs: Original input structure to match

    Returns:
        Gradients with the same structure as inputs

    Raises:
        ValueError: If the number of gradients differs from the number of
            Arrays found in ``inputs``.
    """
    # Use the same flattening/unflattening logic as used for input extraction
    input_arrays, structure = tree_flatten(inputs)
    # Validate that we have the right number of gradients
    _validate_length_match(
        gradient_arrays, input_arrays, "Gradient arrays", "input arrays"
    )
    # Reconstruct the pytree structure with gradients
    return tree_unflatten(structure, gradient_arrays)
[docs]
def pullback(
    inputs: Any,
    outputs: Any,
    cotangents: Any,
) -> Any:
    """Compute vector-Jacobian product (reverse-mode autodiff).

    The three pytrees are flattened to lists of Arrays, the reverse pass runs
    on the flat lists, and the resulting gradients are arranged like
    ``inputs``.

    Args:
        inputs: Input arrays or pytree of arrays
        outputs: Output arrays or pytree of arrays
        cotangents: Cotangent vectors or pytree of cotangents

    Returns:
        Gradients with respect to inputs, in the same structure as inputs

    Raises:
        ValueError: If the number of cotangent Arrays differs from the number
            of output Arrays.
    """
    flat_inputs = _extract_arrays_from_pytree(inputs)
    flat_outputs = _extract_arrays_from_pytree(outputs)
    flat_cotangents = _extract_arrays_from_pytree(cotangents)

    _validate_length_match(
        flat_cotangents, flat_outputs, "Cotangent arrays", "output arrays"
    )

    flat_gradients = _compute_pullback(flat_inputs, flat_outputs, flat_cotangents)
    return _reconstruct_gradient_structure(flat_gradients, inputs)
def _compute_pushfwd(inputs, outputs, tangents, trace=None):
    """Compute JVP (forward-mode autodiff).

    Input tangents are stored on the inputs' ``tangent`` attributes. Every
    traced node that is not an input and has both arguments and a ``jvp_rule``
    then gets its ``tangent`` from that rule, in topological order. Arguments
    without a tangent contribute ``zeros_like(arg)``.

    Args:
        inputs: Flat list of input Arrays.
        outputs: Flat list of output Arrays.
        tangents: Flat list of tangents, one per input.
        trace: Optional pre-built ``Trace``; one is created from ``inputs``
            and ``outputs`` when omitted.

    Returns:
        List with one tangent per output; outputs without a tangent get
        ``zeros_like(output)``.

    Raises:
        ValueError: If ``tangents`` and ``inputs`` differ in length.
        RuntimeError: If a node's JVP rule fails. The original exception is
            attached as the cause.

    Note:
        The ``tangent`` attributes written here are not cleared afterwards.
    """
    from ..ops.creation import zeros_like

    _validate_length_match(tangents, inputs, "Tangents", "inputs")

    if trace is None:
        trace = Trace(inputs, outputs)
    traced_nodes = trace.get_traced_nodes()

    # Identity-based lookup: constant time per node, and independent of how
    # Array implements ``==``.
    input_ids = {id(inp) for inp in inputs}

    for input_node, tangent in zip(inputs, tangents, strict=False):
        input_node.tangent = tangent

    for node in traced_nodes:
        if id(node) in input_ids or not node.args or not node.jvp_rule:
            continue
        arg_tangents = [
            arg.tangent if arg.tangent is not None else zeros_like(arg)
            for arg in node.args
        ]
        try:
            node.tangent = node.jvp_rule(node.args, arg_tangents, node)
        except Exception as e:
            raise RuntimeError(
                f"JVP rule failed for operation '{node.name}': {e}"
            ) from e

    return [
        out.tangent if out.tangent is not None else zeros_like(out)
        for out in outputs
    ]
[docs]
def pushfwd(
    inputs: Any,
    outputs: Any,
    tangents: Any,
) -> Any:
    """Compute Jacobian-vector product (forward-mode autodiff).

    The pytrees are flattened to lists of Arrays, the forward pass runs on the
    flat lists, and the resulting tangents are arranged like ``outputs``.

    Args:
        inputs: Input arrays or pytree of arrays
        outputs: Output arrays or pytree of arrays
        tangents: Tangent vectors or pytree of tangents

    Returns:
        Tangents with respect to outputs, in the same structure as outputs

    Raises:
        ValueError: If the number of tangent Arrays differs from the number
            of input Arrays.
    """
    flat_inputs = _extract_arrays_from_pytree(inputs)
    flat_outputs, output_structure = tree_flatten(outputs)
    flat_tangents = _extract_arrays_from_pytree(tangents)

    _validate_length_match(
        flat_tangents, flat_inputs, "Tangent arrays", "input arrays"
    )

    flat_output_tangents = _compute_pushfwd(flat_inputs, flat_outputs, flat_tangents)
    return tree_unflatten(output_structure, flat_output_tangents)
[docs]
def xpr(fn: Callable[..., Any], *primals) -> str:
    """Get a JAX-like string representation of the function's computation graph.

    ``fn`` is called once with traced copies of ``primals``. A ``Trace`` is
    then built from the traced inputs and the Arrays found in the result, and
    its string form is returned.

    Args:
        fn: Function to trace (should take positional arguments)
        *primals: Positional arguments to the function (can be arbitrary pytrees)

    Returns:
        JAX-like string representation of the computation graph

    Note:
        This follows the same flexible API as vjp, jvp, and vmap:
        - Accepts functions with any number of positional arguments
        - For functions requiring keyword arguments, use functools.partial or lambda
    """
    # A single argument is its own pytree; several arguments form a tuple pytree.
    is_single_arg = len(primals) == 1
    inputs_pytree = primals[0] if is_single_arg else primals

    # Recorded before copying: decides below whether outputs get untraced.
    any_arg_traced = any(
        getattr(arg, "traced", False)
        for arg in _extract_arrays_from_pytree(inputs_pytree)
    )

    # Make traced copies of all inputs
    traced_inputs_pytree = make_traced_pytree(inputs_pytree)
    traced_args = (traced_inputs_pytree,) if is_single_arg else traced_inputs_pytree

    # Execute the function with traced inputs
    outputs = fn(*traced_args)

    # Create trace with the computation graph
    trace = Trace(
        _extract_arrays_from_pytree(traced_inputs_pytree),
        _extract_arrays_from_pytree(outputs),
    )

    # Outputs are only untraced when none of the caller's Arrays was traced
    # already (presumably an already-traced argument means an enclosing
    # transformation still needs the graph - confirm against callers).
    if not any_arg_traced:
        make_untraced_pytree(outputs)

    return str(trace)