Vmap

nabla.transforms.vmap.vmap(func=None, in_axes=0, out_axes=0)

Creates a function that maps func over the specified axes of its inputs, which may be pytrees.

vmap transforms a function written for a single example into a function that operates on a batch of examples. It does this by adding a batch dimension to every operation inside the function, so the whole batch is processed by vectorized operations instead of an explicit per-example loop.
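The sketch below illustrates the contract described above for the default in_axes=0 and out_axes=0. It uses NumPy rather than Nabla and is not Nabla's implementation; per_example, batched_by_loop, and batched_vectorized are hypothetical names. A batched function should agree with applying the per-example function to each slice along axis 0 and stacking the results.

```python
import numpy as np

def per_example(x):
    # Written for a single example: x has shape (3,), the result is a scalar.
    return np.dot(x, x)

def batched_by_loop(xs):
    # Reference semantics for in_axes=0, out_axes=0: call per_example on
    # each slice along axis 0, then stack the results along a new axis 0.
    return np.stack([per_example(x) for x in xs], axis=0)

def batched_vectorized(xs):
    # Hand-written batched version: every operation carries the batch axis,
    # so the whole batch is handled by one vectorized reduction.
    return np.sum(xs * xs, axis=1)

xs = np.arange(12.0).reshape(4, 3)  # batch of 4 examples, each of shape (3,)
print(batched_by_loop(xs))     # values: 5, 50, 149, 302
print(batched_vectorized(xs))  # same values, computed without a Python loop
```

Per the description above, vmap is meant to derive something like batched_vectorized automatically from a function written like per_example, rather than running the loop in batched_by_loop.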

Parameters:
  • func (Callable | None) – The function to be vectorized. It should be written as if it operates on a single example.

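  • in_axes (int | None | list | tuple) – Specifies which axis of each input to map over (default 0). Can be an integer, None, or a pytree (list or tuple) of these values matching the structure of the inputs. None indicates that the corresponding input is not mapped: it is broadcast, i.e. passed unchanged to every per-example call.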

  • out_axes (int | None | list | tuple) – Specifies where to place the batch axis in the output(s) (default 0). Accepts the same forms as in_axes.
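To make in_axes and out_axes concrete, here is a loop-based reference for their semantics. It is a NumPy sketch, not Nabla's implementation; reference_vmap and affine are hypothetical names. For brevity it handles only flat positional array arguments, int or None entries in in_axes, an integer out_axes, and a single array output (no nested pytrees).

```python
import numpy as np

def reference_vmap(func, in_axes=0, out_axes=0):
    """Loop-based sketch of vmap semantics (hypothetical helper, not Nabla's API)."""
    def batched(*args):
        # A single int/None applies to every positional argument.
        if isinstance(in_axes, (list, tuple)):
            axes = tuple(in_axes)
        else:
            axes = (in_axes,) * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes must have one entry per positional argument")
        # Batch size comes from the mapped inputs; they must all agree.
        sizes = {np.shape(a)[ax] for a, ax in zip(args, axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError("at least one input must be mapped, with one shared batch size")
        (size,) = sizes
        outs = []
        for i in range(size):
            # Mapped inputs are sliced at index i along their axis;
            # None-axis inputs are broadcast, i.e. passed through unchanged.
            sliced = [a if ax is None else np.take(a, i, axis=ax)
                      for a, ax in zip(args, axes)]
            outs.append(func(*sliced))
        # Stack per-example results, inserting the batch axis at out_axes.
        return np.stack(outs, axis=out_axes)
    return batched

def affine(w, x, b):
    # Single example: w is (2, 3), x is (3,), b is (2,); result is (2,).
    return w @ x + b

w = np.ones((2, 3))
b = np.array([10.0, 20.0])
xs = np.arange(12.0).reshape(3, 4)  # 4 examples stored along axis 1
f = reference_vmap(affine, in_axes=(None, 1, None), out_axes=1)
out = f(w, xs, b)
print(out.shape)  # (2, 4)
```

Here w and b are shared across the batch (in_axes entry None), the examples are read from axis 1 of xs, and out_axes=1 places the batch axis last in the result, so column j of out equals affine(w, xs[:, j], b).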

Returns:

A vectorized function with the same input/output structure as func.

Return type:

Callable[[…], Any]