Vmap
- nabla.transforms.vmap.vmap(func=None, in_axes=0, out_axes=0)
Creates a function that maps func over the specified axes of its pytree inputs.
vmap is a transformation that turns a function written for a single data point into one that operates on a batch of data points. Rather than looping over the examples, it adds a batch dimension to every operation inside the function, so the whole batch is processed in one efficient, vectorized pass.
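To make the semantics concrete, here is a naive, loop-based reference written with NumPy. It slices each input along its in_axes entry, calls the per-example function once per batch element, and stacks the results along out_axes. This is only a sketch of what the transformed function computes, not how Nabla implements vmap (which, as described above, batches the operations themselves instead of looping). The name reference_vmap is hypothetical, and the sketch assumes flat positional array arguments (no nested pytrees) and an integer out_axes.

```python
import numpy as np


def reference_vmap(func, in_axes=0, out_axes=0):
    """Hypothetical loop-based reference for vmap semantics.

    Supports only flat positional array arguments (not nested pytrees)
    and an integer out_axes.
    """

    def batched(*args):
        # Normalize in_axes to one entry per positional argument.
        if isinstance(in_axes, (list, tuple)):
            axes = list(in_axes)
        else:
            axes = [in_axes] * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes needs one entry per positional argument")

        # Every mapped argument must agree on the batch size.
        sizes = {np.shape(a)[ax] for a, ax in zip(args, axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError("need mapped arguments with one common batch size")
        (n,) = sizes

        results = []
        for i in range(n):
            # Mapped inputs are sliced at index i; None inputs are reused as-is.
            example = [
                a if ax is None else np.take(a, i, axis=ax)
                for a, ax in zip(args, axes)
            ]
            results.append(func(*example))

        # Stack the per-example outputs, inserting the batch axis at out_axes.
        return np.stack(results, axis=out_axes)

    return batched


def dot(x, y):
    # Written for a single pair of vectors.
    return np.dot(x, y)


xs = np.arange(6.0).reshape(3, 2)  # a batch of three 2-vectors along axis 0
w = np.array([1.0, 10.0])          # one vector shared by the whole batch

batched_dot = reference_vmap(dot, in_axes=(0, None))
print(batched_dot(xs, w))  # [10. 32. 54.]

# out_axes=1 places the batch axis second: result shape (2, 3) instead of (3, 2).
scale = reference_vmap(lambda x, s: x * s, in_axes=(0, 0), out_axes=1)
ys = scale(xs, np.array([1.0, 2.0, 3.0]))
print(ys.shape)  # (2, 3)
```

Here in_axes=(0, None) expresses broadcasting: w is passed unchanged to every call, while xs is split along axis 0. The batched result matches what a hand-vectorized expression such as xs @ w would give.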
- Parameters:
func (Callable | None) – The function to be vectorized. It should be written as if it operates on a single example.
in_axes (int | None | list | tuple) – Specifies which axis of each input to map over. Can be an integer, None, or a pytree of these values. None means the corresponding input is not mapped but broadcast: the same value is used for every element of the batch.
out_axes (int | None | list | tuple) – Specifies where to place the batch axis in the output(s).
- Returns:
A vectorized function with the same input/output structure as func.
- Return type: