Utils
Core transformations for automatic differentiation and tracing.
- nabla.transforms.utils.tree_flatten(tree)
Flatten a pytree into a list of Arrays and structure info.
- nabla.transforms.utils.tree_unflatten(structure, leaves)
Reconstruct a pytree from structure info and list of Arrays.
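To illustrate the flatten/unflatten round trip described above, here is a minimal pure-Python sketch. The names `toy_flatten` and `toy_unflatten` are hypothetical stand-ins, not nabla's implementation; the `(leaves, structure)` return order, the sorted dict keys, and the treatment of every non-container value as a leaf are all assumptions made for the example.

```python
# Toy pytree helpers: hypothetical stand-ins, NOT nabla's implementation.
# Assumption: any value that is not a list, tuple, or dict is a leaf.

def toy_flatten(tree):
    """Return (leaves, structure) for nested lists, tuples, and dicts."""
    if isinstance(tree, dict):
        keys = sorted(tree)  # fixed key order makes the leaf order deterministic
        leaves, children = [], []
        for key in keys:
            sub_leaves, sub_structure = toy_flatten(tree[key])
            leaves.extend(sub_leaves)
            children.append(sub_structure)
        return leaves, (dict, keys, children)
    if isinstance(tree, (list, tuple)):
        leaves, children = [], []
        for item in tree:
            sub_leaves, sub_structure = toy_flatten(item)
            leaves.extend(sub_leaves)
            children.append(sub_structure)
        return leaves, (type(tree), children)
    return [tree], None  # None marks a leaf position


def toy_unflatten(structure, leaves):
    """Rebuild the container described by `structure` from a flat leaf list."""
    remaining = iter(leaves)

    def build(node):
        if node is None:
            return next(remaining)
        if node[0] is dict:
            _, keys, children = node
            return {key: build(child) for key, child in zip(keys, children)}
        kind, children = node
        return kind(build(child) for child in children)

    return build(structure)


tree = {"w": [1.0, 2.0], "b": (3.0,)}
leaves, structure = toy_flatten(tree)
restored = toy_unflatten(structure, leaves)
```

The argument order of `toy_unflatten` mirrors the documented `tree_unflatten(structure, leaves)` signature; everything else about the structure encoding is illustrative only.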
- nabla.transforms.utils.tree_map(func, tree)
Apply a function to all Array leaves in a pytree.
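As a rough picture of what mapping over leaves means, the following self-contained sketch applies a function to every leaf of a nested container while preserving the container shape. `toy_tree_map` is a hypothetical helper, not nabla's `tree_map`; in particular, it treats any non-container value as a leaf, whereas the documented function targets Array leaves.

```python
# Hypothetical helper for illustration; NOT nabla's tree_map.
# Assumption: any value that is not a list, tuple, or dict is a leaf.

def toy_tree_map(func, tree):
    """Apply `func` to each leaf and keep the surrounding structure."""
    if isinstance(tree, dict):
        return {key: toy_tree_map(func, value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_tree_map(func, item) for item in tree)
    return func(tree)


params = {"w": [1.0, 2.0], "b": (3.0,)}
doubled = toy_tree_map(lambda x: 2 * x, params)
# doubled == {"w": [2.0, 4.0], "b": (6.0,)}
```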
- nabla.transforms.utils.make_traced_pytree(tree)
Create shallow copies of arrays in a pytree and mark them as traced.
- nabla.transforms.utils.make_untraced_pytree(tree)
Disable tracing for arrays in a pytree by clearing their traced flag.
- Parameters:
tree (Any) – Pytree containing Arrays to disable tracing for
- nabla.transforms.utils.make_staged_pytree(args)
Enable staged execution for arrays to optimize performance.
- class nabla.transforms.utils.Trace(inputs, outputs=None)
Bases: object
A simple trace container that holds the computation graph.
- nabla.transforms.utils.pullback(inputs, outputs, cotangents)
Compute vector-Jacobian product (reverse-mode autodiff).
Returns gradients in the exact same structure as inputs.
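A small hand-worked example may help make the vector-Jacobian product concrete. The sketch below does not call nabla; `f` and `toy_pullback` are hypothetical names, and the Jacobian is written out by hand for this one function rather than derived from a traced graph.

```python
# Hand-derived VJP for one fixed function; NOT nabla's pullback.
# f(x, y) = (x * y, x + y) has Jacobian J = [[y, x], [1, 1]].

def f(x, y):
    return (x * y, x + y)


def toy_pullback(x, y, cotangents):
    """Return J^T @ cotangents, shaped like the inputs (x, y)."""
    u, v = cotangents            # one cotangent per output of f
    grad_x = u * y + v * 1.0     # first column of J dotted with (u, v)
    grad_y = u * x + v * 1.0     # second column of J dotted with (u, v)
    return (grad_x, grad_y)


grads = toy_pullback(2.0, 3.0, (1.0, 0.5))
# grads == (3.5, 2.5): one gradient per input, in input order
```

The result has one entry per input, which is the property the description above calls returning gradients in the same structure as the inputs.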
- nabla.transforms.utils.pushfwd(inputs, outputs, tangents)
Compute Jacobian-vector product (forward-mode autodiff).
Returns output tangents in the same structure as outputs.
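For comparison, a Jacobian-vector product pushes input tangents forward to the outputs. The sketch below is a hand-worked instance for a single function and does not use nabla; `f` and `toy_pushfwd` are hypothetical names chosen for the example.

```python
# Hand-derived JVP for one fixed function; NOT nabla's pushfwd.
# f(x, y) = (x * y, x + y) has Jacobian J = [[y, x], [1, 1]].

def f(x, y):
    return (x * y, x + y)


def toy_pushfwd(x, y, tangents):
    """Return J @ tangents, shaped like the outputs of f."""
    dx, dy = tangents            # one tangent per input of f
    d_prod = y * dx + x * dy     # first row of J dotted with (dx, dy)
    d_sum = dx + dy              # second row of J dotted with (dx, dy)
    return (d_prod, d_sum)


out_tangents = toy_pushfwd(2.0, 3.0, (1.0, 0.0))
# out_tangents == (3.0, 1.0): one tangent per output, in output order
```

Each output tangent approximates how that output changes when the inputs move a small step along the given tangent direction, which can be checked against a finite difference of `f`.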
- nabla.transforms.utils.xpr(fn, *primals)
Get a JAX-like string representation of the function’s computation graph.
- Parameters:
- Returns:
JAX-like string representation of the computation graph
- Return type:
Note
This follows the same flexible API as vjp, jvp, and vmap:
- Accepts functions with any number of positional arguments
- For functions requiring keyword arguments, use functools.partial or a lambda
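The keyword-argument advice in the note can be sketched with the standard library alone. `scaled_sum` is a hypothetical function invented for this example; the snippet only shows how to bind a keyword argument so that the resulting callable takes positional arguments only, and it does not call nabla.

```python
from functools import partial

# Hypothetical function with a keyword-only argument.
def scaled_sum(x, y, *, scale=1.0):
    return scale * (x + y)


# Option 1: bind the keyword with functools.partial.
scaled_by_partial = partial(scaled_sum, scale=2.0)

# Option 2: bind the keyword with a lambda.
scaled_by_lambda = lambda x, y: scaled_sum(x, y, scale=2.0)

# Both wrappers are now called with positional arguments only.
result_partial = scaled_by_partial(1.0, 2.0)
result_lambda = scaled_by_lambda(1.0, 2.0)
```

Either wrapper could then be passed as the `fn` argument, with the primals supplied positionally, per the `xpr(fn, *primals)` signature shown above.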