Utils

Core transformations for automatic differentiation and tracing.

nabla.transforms.utils.tree_flatten(tree)

Flatten a pytree into a list of its Array leaves plus the structure info needed to rebuild it.

Parameters:

tree (Any) – A pytree containing Arrays and other structures

Returns:

A tuple of (list of Array leaves, structure info for reconstruction)

Return type:

tuple[list[Array], Any]

nabla.transforms.utils.tree_unflatten(structure, leaves)

Reconstruct a pytree from structure info and a list of Arrays.

Parameters:
  • structure (Any) – Structure info from tree_flatten

  • leaves (list[Array]) – Array values to place at the Array positions, in the order tree_flatten returned them

Returns:

Reconstructed pytree with the same structure as the original

Return type:

Any
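The flatten/unflatten pair above can be illustrated with a small, self-contained sketch. It assumes lists, tuples, and dicts are the only container types and uses a hypothetical Array stand-in rather than nabla's Array class; the structure encoding here is invented for illustration and is not nabla's internal format. The key property it shows is that unflatten consumes leaves in exactly the order flatten produced them.

```python
from typing import Any


class Array:
    """Hypothetical stand-in for nabla's Array; only used to mark leaves."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __repr__(self) -> str:
        return f"Array({self.value})"


def tree_flatten(tree: Any) -> tuple[list[Array], Any]:
    """Collect Array leaves depth-first and record a structure skeleton."""
    leaves: list[Array] = []

    def _flatten(node: Any) -> Any:
        if isinstance(node, Array):
            leaves.append(node)
            return None  # placeholder marking an Array position
        if type(node) in (list, tuple):
            return (type(node), [_flatten(child) for child in node])
        if isinstance(node, dict):
            keys = list(node)  # insertion order fixes the leaf order
            return (dict, keys, [_flatten(node[k]) for k in keys])
        return ("const", node)  # non-Array leaf stored verbatim

    structure = _flatten(tree)
    return leaves, structure


def tree_unflatten(structure: Any, leaves: list[Array]) -> Any:
    """Rebuild the tree, filling Array positions from `leaves` in order."""
    it = iter(leaves)

    def _build(node: Any) -> Any:
        if node is None:
            return next(it)
        kind = node[0]
        if kind == "const":
            return node[1]
        if kind is dict:
            _, keys, children = node
            return {k: _build(c) for k, c in zip(keys, children)}
        container_type, children = node
        return container_type(_build(c) for c in children)

    return _build(structure)


tree = {"w": Array(1.0), "b": [Array(2.0), 3], "name": "layer"}
leaves, structure = tree_flatten(tree)
print(leaves)  # [Array(1.0), Array(2.0)]
rebuilt = tree_unflatten(structure, [Array(a.value * 10) for a in leaves])
print(rebuilt)  # {'w': Array(10.0), 'b': [Array(20.0), 3], 'name': 'layer'}
```

Storing non-Array leaves inside the structure (rather than in the leaf list) is what lets a transform touch only the Arrays while the rest of the tree comes back unchanged.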

nabla.transforms.utils.tree_map(func, tree)

Apply a function to all Array leaves in a pytree.

Parameters:
  • func (Callable[[Array], Array]) – Function to apply to each Array leaf

  • tree (Any) – Pytree containing Arrays

Returns:

Pytree with the same structure but transformed Arrays

Return type:

Any
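A recursive sketch of the behaviour described for tree_map, assuming lists, tuples, and dicts are the container types and using a hypothetical Array stand-in. nabla's own implementation may instead be built on tree_flatten and tree_unflatten; this version only illustrates the contract: same structure, func applied to each Array leaf, everything else passed through.

```python
from typing import Any, Callable


class Array:
    """Hypothetical stand-in for nabla's Array (holds one float)."""

    def __init__(self, value: float) -> None:
        self.value = value


def tree_map(func: Callable[[Array], Array], tree: Any) -> Any:
    """Apply func to every Array leaf; rebuild containers, keep other leaves."""
    if isinstance(tree, Array):
        return func(tree)
    if type(tree) in (list, tuple):
        return type(tree)(tree_map(func, child) for child in tree)
    if isinstance(tree, dict):
        return {k: tree_map(func, v) for k, v in tree.items()}
    return tree  # non-Array leaves pass through unchanged


params = {"w": Array(3.0), "b": (Array(0.5), "frozen")}
doubled = tree_map(lambda a: Array(a.value * 2), params)
print(doubled["w"].value, doubled["b"][0].value, doubled["b"][1])  # 6.0 1.0 frozen
```

The input tree is not modified; a new tree of the same shape is returned.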

nabla.transforms.utils.make_traced_pytree(tree)

Create shallow copies of arrays in a pytree and mark them as traced.

Parameters:

tree (Any) – Pytree containing Arrays to copy and mark as traced

Returns:

Pytree with the same structure but traced Arrays

Return type:

Any

nabla.transforms.utils.make_untraced_pytree(tree)

Disable tracing for arrays in a pytree by clearing their traced flag.

Parameters:

tree (Any) – Pytree containing Arrays to disable tracing for
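The two descriptions above differ in an important way: make_traced_pytree returns traced shallow copies, while make_untraced_pytree clears a flag on the arrays it is given. The sketch below contrasts the two, assuming a hypothetical Array stand-in whose flag is named `traced`, and assuming make_untraced_pytree mutates in place (its documentation lists no return value). Neither the attribute name nor the in-place behaviour is confirmed by the reference above.

```python
import copy
from typing import Any, Callable


class Array:
    """Hypothetical stand-in: a value plus a `traced` flag (name assumed)."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.traced = False


def _map_arrays(fn: Callable[[Array], Any], tree: Any) -> Any:
    # Minimal pytree walk over lists, tuples, and dicts.
    if isinstance(tree, Array):
        return fn(tree)
    if type(tree) in (list, tuple):
        return type(tree)(_map_arrays(fn, t) for t in tree)
    if isinstance(tree, dict):
        return {k: _map_arrays(fn, v) for k, v in tree.items()}
    return tree


def make_traced_pytree(tree: Any) -> Any:
    def _traced_copy(a: Array) -> Array:
        c = copy.copy(a)  # shallow copy: shares `value`, gets its own flag
        c.traced = True
        return c

    return _map_arrays(_traced_copy, tree)


def make_untraced_pytree(tree: Any) -> None:
    def _clear(a: Array) -> Array:
        a.traced = False  # assumed in-place update; nothing is returned
        return a

    _map_arrays(_clear, tree)


x = Array(1.0)
traced = make_traced_pytree({"x": x})
print(x.traced, traced["x"].traced)  # False True
make_untraced_pytree(traced)
print(traced["x"].traced)  # False
```

Copying before flagging means the caller's original arrays are never marked as traced, so tracing one function call cannot leak state into arrays the caller still holds.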

nabla.transforms.utils.make_staged_pytree(args)

Enable staged execution for arrays to optimize performance.

Parameters:

args (list[Array]) – Arrays to enable staged execution for

nabla.transforms.utils.make_unstaged_pytree(args)

Disable staged execution for arrays.

Parameters:

args (list[Array]) – Arrays to disable staged execution for

class nabla.transforms.utils.Trace(inputs, outputs=None)

Bases: object

A simple trace container that holds the computation graph.

__init__(inputs, outputs=None)

classmethod trace_function(fn, inputs)

Create a trace by executing a function with tracing enabled.

This is the recommended way to create traces, because it sets up tracing before the function executes.

get_traced_nodes()

Get all nodes that belong to this trace in topological order.
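One common way to obtain the topological order that get_traced_nodes describes is a depth-first, post-order walk from the outputs back toward the inputs, stopping at the trace inputs. The sketch below assumes a hypothetical Node type with a `parents` list; it illustrates the ordering guarantee (every node appears after the nodes it consumes), not nabla's actual traversal.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based hashing so nodes can live in a set
class Node:
    """Hypothetical graph node: an op name plus the nodes it consumes."""

    name: str
    parents: list[Node] = field(default_factory=list)


def traced_nodes(inputs: list[Node], outputs: list[Node]) -> list[Node]:
    """Return nodes between inputs and outputs, each after all its parents."""
    input_set = set(inputs)
    order: list[Node] = []
    visited: set[Node] = set()

    def visit(node: Node) -> None:
        if node in visited:
            return
        visited.add(node)
        # Stop descending at trace inputs; anything upstream is outside the trace.
        if node not in input_set:
            for parent in node.parents:
                visit(parent)
        order.append(node)  # post-order: parents were appended first

    for out in outputs:
        visit(out)
    return order


x, y = Node("x"), Node("y")
mul = Node("mul", [x, y])
add = Node("add", [mul, x])
print([n.name for n in traced_nodes([x, y], [add])])  # ['x', 'y', 'mul', 'add']
```

The recursive walk is the simplest form; very deep graphs would need an explicit stack to avoid Python's recursion limit.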

nabla.transforms.utils.pullback(inputs, outputs, cotangents)

Compute a vector-Jacobian product (reverse-mode autodiff).

Returns gradients with exactly the same structure as inputs.

Parameters:
  • inputs (Any) – Input arrays or pytree of arrays

  • outputs (Any) – Output arrays or pytree of arrays

  • cotangents (Any) – Cotangent vectors or pytree of cotangents

Returns:

Gradients with respect to inputs, in the same structure as inputs

Return type:

Any
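To make the vector-Jacobian product concrete, here is a minimal reverse-mode sketch on scalars. It mirrors the (inputs, outputs, cotangents) shape of pullback but works on flat lists of a hypothetical Var class instead of Array pytrees; `Var` and `pullback_sketch` are illustrative names, not part of nabla. For f(x, y) = x*y + x, the gradient is (y + 1, x), so at x = 3, y = 4 with cotangent 1.0 the result is [5.0, 3.0].

```python
from __future__ import annotations


class Var:
    """Hypothetical scalar node recording (parent, local derivative) pairs."""

    def __init__(self, value: float, parents: tuple = ()) -> None:
        self.value = value
        self.parents = parents  # tuple of (Var, d self / d parent)

    def __add__(self, other: Var) -> Var:
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other: Var) -> Var:
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def pullback_sketch(inputs: list[Var], outputs: list[Var],
                    cotangents: list[float]) -> list[float]:
    """Reverse-mode VJP: propagate output cotangents back to the inputs."""
    # Topological order via depth-first post-order from the outputs.
    order: list[Var] = []
    seen: set[Var] = set()

    def visit(v: Var) -> None:
        if v in seen:
            return
        seen.add(v)
        for parent, _ in v.parents:
            visit(parent)
        order.append(v)

    for out in outputs:
        visit(out)

    grads: dict[Var, float] = {v: 0.0 for v in order}
    for out, ct in zip(outputs, cotangents):
        grads[out] += ct
    # Reverse topological order: each node's gradient is complete before it
    # is pushed to its parents (chain rule, summed over every use).
    for v in reversed(order):
        for parent, local in v.parents:
            grads[parent] += grads[v] * local
    # Inputs the outputs do not depend on get a zero gradient.
    return [grads.get(x, 0.0) for x in inputs]


x, y = Var(3.0), Var(4.0)
out = x * y + x
print(pullback_sketch([x, y], [out], [1.0]))  # [5.0, 3.0]
```

Because the result is linear in the cotangents, passing [2.0] instead of [1.0] doubles every gradient.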

nabla.transforms.utils.pushfwd(inputs, outputs, tangents)

Compute a Jacobian-vector product (forward-mode autodiff).

Returns output tangents with the same structure as outputs.

Parameters:
  • inputs (Any) – Input arrays or pytree of arrays

  • outputs (Any) – Output arrays or pytree of arrays

  • tangents (Any) – Tangent vectors or pytree of tangents

Returns:

Output tangents, in the same structure as outputs

Return type:

Any
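A Jacobian-vector product can be illustrated with dual numbers: each value carries a (primal, tangent) pair, and every operation applies its derivative rule to the tangent. This sketch takes a different route from what pushfwd's signature suggests. Instead of propagating tangents through already-computed inputs and outputs, it evaluates the function directly on dual numbers. `Dual` and `jvp_sketch` are hypothetical names. For f(x, y) = x*y + x at (3, 4), the tangent (1, 0) selects df/dx = y + 1 = 5 and (0, 1) selects df/dy = x = 3.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass
class Dual:
    """Hypothetical (primal, tangent) pair for forward-mode differentiation."""

    primal: float
    tangent: float

    def __add__(self, other: Dual) -> Dual:
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other: Dual) -> Dual:
        # Product rule: d(uv) = u dv + v du
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + other.primal * self.tangent)


def jvp_sketch(fn: Callable[..., Dual], primals: list[float],
               tangents: list[float]) -> tuple[float, float]:
    """Evaluate fn on dual numbers; the output tangent equals J @ tangents."""
    out = fn(*(Dual(p, t) for p, t in zip(primals, tangents)))
    return out.primal, out.tangent


def f(x: Dual, y: Dual) -> Dual:
    return x * y + x


print(jvp_sketch(f, [3.0, 4.0], [1.0, 0.0]))  # (15.0, 5.0)
print(jvp_sketch(f, [3.0, 4.0], [0.0, 1.0]))  # (15.0, 3.0)
```

Forward mode yields one column of the Jacobian (one directional derivative) per pass, whereas the reverse-mode pullback yields one row per cotangent; that is why forward mode suits few inputs and reverse mode suits few outputs.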

nabla.transforms.utils.xpr(fn, *primals)

Get a JAX-like string representation of the function’s computation graph.

Parameters:
  • fn (Callable[[...], Any]) – Function to trace; it must accept its inputs as positional arguments

  • *primals – Positional arguments to the function (can be arbitrary pytrees)

Returns:

JAX-like string representation of the computation graph

Return type:

str

Note

This follows the same flexible API as vjp, jvp, and vmap:

  • Accepts functions with any number of positional arguments

  • For functions requiring keyword arguments, use functools.partial or a lambda
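The functools.partial and lambda workarounds from the note can be sketched without nabla. The hypothetical `positional_only_api` below stands in for an xpr-style function that forwards only `*primals` positionally; it is not nabla's xpr. Binding the keyword argument ahead of time produces a callable that such an API can invoke.

```python
import functools
from typing import Any, Callable


def positional_only_api(fn: Callable[..., Any], *primals: Any) -> Any:
    """Hypothetical stand-in for an xpr-style API: forwards positional args only."""
    return fn(*primals)


def scaled_sum(x: float, y: float, *, scale: float) -> float:
    # `scale` is keyword-only, so it cannot be passed through *primals.
    return scale * (x + y)


# Option 1: bind the keyword argument with functools.partial.
with_partial = functools.partial(scaled_sum, scale=2.0)
print(positional_only_api(with_partial, 1.0, 2.0))  # 6.0

# Option 2: wrap the call in a lambda that takes only positional inputs.
print(positional_only_api(lambda x, y: scaled_sum(x, y, scale=2.0), 1.0, 2.0))  # 6.0
```

functools.partial keeps the bound value inspectable (with_partial.keywords), while a lambda is handier when the wrapper also needs to reorder or combine arguments.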