vmap#

Signature#

nabla.vmap(func: collections.abc.Callable | None = None, in_axes: Union[int, NoneType, list, tuple] = 0, out_axes: Union[int, NoneType, list, tuple] = 0) -> collections.abc.Callable[..., typing.Any]

Description#

Creates a function that maps func over the specified axes of its pytree arguments.

vmap is a transformation that converts a function written for a single data point into one that operates on a batch of data points. It does this by adding a batch dimension to every operation inside the function, which allows the batch to be executed efficiently and in parallel.

Parameters#

  • func: The function to be vectorized. It should be written as if it operates on a single example.

  • in_axes: Specifies which axis of the input(s) to map over. Can be an integer, None, or a pytree of these values. None indicates that the corresponding input should be broadcast.

  • out_axes: Specifies where to place the batch axis in the output(s).
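The meaning of in_axes can be illustrated with a plain-Python reference loop. This is only a sketch of the semantics, not how nabla implements vmap: naive_vmap is a hypothetical helper written for this example, it works on nested lists, and it supports only axis 0 or None. The real transformation batches the operations themselves instead of calling func once per example.

```python
def naive_vmap(func, in_axes=0):
    """Hypothetical loop-based reference for vmap semantics (axis 0 or None only)."""
    def batched(*args):
        # Normalize in_axes to one entry per positional argument.
        if isinstance(in_axes, (list, tuple)):
            axes = tuple(in_axes)
        else:
            axes = (in_axes,) * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes must match the number of arguments")
        if any(ax not in (0, None) for ax in axes):
            raise ValueError("this sketch only supports axis 0 or None")
        # The batch size comes from the mapped arguments, which must agree.
        sizes = {len(arg) for arg, ax in zip(args, axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("mapped arguments must share one batch size")
        (batch_size,) = sizes
        # Call func once per example; None-axis arguments are passed whole.
        return [
            func(*(arg[i] if ax == 0 else arg for arg, ax in zip(args, axes)))
            for i in range(batch_size)
        ]
    return batched


def dot_product(a, b):
    # Written for a single pair of vectors.
    return sum(x * y for x, y in zip(a, b))


a_batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 vectors of length 2
w = [10.0, 1.0]                                  # one shared vector

# Both arguments mapped over axis 0: example i of each is paired.
pairwise = naive_vmap(dot_product, in_axes=(0, 0))(a_batch, a_batch)

# Second argument broadcast: every example sees the same w.
shared = naive_vmap(dot_product, in_axes=(0, None))(a_batch, w)
```

With in_axes=(0, 0), the i-th result uses the i-th slice of both inputs. With in_axes=(0, None), the second input is reused unchanged for every slice of the first, which is the broadcast behavior described for None.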

Returns#

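A vectorized function with the same input/output structure as func, in which mapped inputs carry a batch axis at the position given by in_axes and outputs carry one at the position given by out_axes.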

Examples#

import nabla as nb

# Written for a single pair of vectors: returns their dot product
def dot_product(a, b):
    return nb.sum(a * b)

# Map over the first axis (axis 0) of both arguments
batch_dot = nb.vmap(dot_product, in_axes=(0, 0))

a_batch = nb.randn((10, 5))  # 10 vectors of length 5
b_batch = nb.randn((10, 5))
results = batch_dot(a_batch, b_batch)  # 10 dot products

See Also#

  • jit - Just-in-time compilation

  • vjp, jvp - Automatic differentiation