PyTree Utilities

tree_flatten

def tree_flatten(tree: 'Any', is_leaf: 'Callable[[Any], bool] | None' = None) -> 'tuple[list[Any], PyTreeDef]':

Flatten a pytree into a list of leaves and a PyTreeDef describing its structure.


tree_unflatten

def tree_unflatten(treedef: 'PyTreeDef', leaves: 'list[Any]') -> 'Any':

Reconstruct a pytree from a PyTreeDef and a list of leaves.


tree_map

def tree_map(fn: 'Callable[..., Any]', tree: 'Any', *rest: 'Any', is_leaf: 'Callable[[Any], bool] | None' = None) -> 'Any':

Apply a function to every leaf of a pytree.


tree_leaves

def tree_leaves(tree: 'Any', is_leaf: 'Callable[[Any], bool] | None' = None) -> 'list[Any]':

Get all leaves of a pytree as a flat list. This is an optimized variant that does not build a treedef.


tree_structure

def tree_structure(tree: 'Any', is_leaf: 'Callable[[Any], bool] | None' = None) -> 'PyTreeDef':

Get the PyTreeDef describing the structure of a pytree.


PyTreeDef

class PyTreeDef(kind: 'int', meta: 'Any', children: 'tuple[PyTreeDef, ...]', num_leaves: 'int') -> 'None':

Immutable definition of a pytree’s structure.


register_pytree_node

def register_pytree_node(cls: 'type', flatten_fn: 'Callable[[Any], tuple[list[Any], Any]]', unflatten_fn: 'Callable[[Any, list[Any]], Any]') -> 'None':

Register a custom class as a pytree container node.


tensor_leaves

def tensor_leaves(tree: 'Any') -> 'list[Tensor]':

Get only the Tensor leaves of a pytree.


traced

def traced(tree: 'Any') -> 'Any':

Mark all tensors in a pytree as traced.


untraced

def untraced(tree: 'Any') -> 'Any':

Mark all tensors in a pytree as untraced.


with_batch_dims

def with_batch_dims(tree: 'Any', delta: 'int') -> 'Any':

Adjust `batch_dims` by `delta` on all tensors in a pytree.