Function Transformations

Function transformations for compilation, vectorization, and automatic differentiation.

Quick Reference

djit

nabla.djit(func: Optional[collections.abc.Callable[..., Any]] = None, show_graph: bool = False) -> collections.abc.Callable[..., typing.Any]

Just-in-time compilation variant that takes the same arguments as jit except static.

grad

nabla.grad(fun: collections.abc.Callable | None = None, argnums: int | collections.abc.Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: collections.abc.Sequence = (), mode: str = 'reverse') -> collections.abc.Callable[..., typing.Any]

Automatic differentiation to compute gradients.
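
To illustrate what a gradient transformation returns and how argnums selects the differentiated argument, here is a minimal pure-Python sketch. It approximates the derivative with central differences and is not how Nabla computes gradients; numerical_grad and its eps parameter are hypothetical names used only for this illustration.

```python
from typing import Callable


def numerical_grad(fun: Callable[..., float], argnums: int = 0,
                   eps: float = 1e-6) -> Callable[..., float]:
    """Hypothetical helper: approximate d fun / d args[argnums].

    Uses central differences, so the result is an approximation,
    unlike automatic differentiation, which is exact up to rounding.
    """
    def grad_fn(*args: float) -> float:
        plus = list(args)
        minus = list(args)
        plus[argnums] += eps    # nudge only the selected argument
        minus[argnums] -= eps
        return (fun(*plus) - fun(*minus)) / (2 * eps)
    return grad_fn


def f(x: float, y: float) -> float:
    return x * x * y


df_dx = numerical_grad(f, argnums=0)  # analytic value: 2*x*y
df_dy = numerical_grad(f, argnums=1)  # analytic value: x*x

print(df_dx(3.0, 2.0))  # approximately 12
print(df_dy(3.0, 2.0))  # approximately 9
```

As with the signature above, the transformation takes a function and returns a new function with the same arguments.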

jacfwd

nabla.jacfwd(func: collections.abc.Callable[..., typing.Any], argnums: int | tuple[int, ...] | list[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> collections.abc.Callable[..., typing.Any]

Jacobian of func computed with forward-mode automatic differentiation.

jacrev

nabla.jacrev(func: collections.abc.Callable[..., typing.Any], argnums: int | tuple[int, ...] | list[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> collections.abc.Callable[..., typing.Any]

Jacobian of func computed with reverse-mode automatic differentiation.

jit

nabla.jit(func: Optional[collections.abc.Callable[..., Any]] = None, static: bool = True, show_graph: bool = False) -> collections.abc.Callable[..., typing.Any]

Just-in-time compilation for performance optimization.

jvp

nabla.jvp(func: collections.abc.Callable[..., typing.Any], primals, tangents, has_aux: bool = False) -> tuple[typing.Any, typing.Any] | tuple[typing.Any, typing.Any, typing.Any]

Jacobian-vector product for forward-mode automatic differentiation.
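
The idea behind a Jacobian-vector product can be sketched with dual numbers: each value carries its tangent alongside it, and every operation updates both. The sketch below is a conceptual illustration in plain Python, not Nabla's implementation; Dual, _lift, and toy_jvp are hypothetical names, and only addition and multiplication are supported.

```python
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple


@dataclass(frozen=True)
class Dual:
    """A value paired with its tangent (directional derivative)."""
    primal: float
    tangent: float

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.primal + other.primal,
                    self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (a*b)' = a*b' + a'*b
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def _lift(value) -> Dual:
    """Treat plain numbers as constants with zero tangent."""
    return value if isinstance(value, Dual) else Dual(float(value), 0.0)


def toy_jvp(func: Callable, primals: Sequence[float],
            tangents: Sequence[float]) -> Tuple[float, float]:
    """Hypothetical helper: return (func(*primals), J @ tangents)."""
    duals = [Dual(p, t) for p, t in zip(primals, tangents)]
    out = _lift(func(*duals))
    return out.primal, out.tangent


def f(x, y):
    return x * x * y + 3.0 * y


value, directional = toy_jvp(f, (3.0, 2.0), (1.0, 0.0))
print(value, directional)  # 24.0 12.0
```

The output value and the directional derivative come out of a single evaluation of the function, which is the defining property of forward mode.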

vjp

nabla.vjp(func: collections.abc.Callable[..., typing.Any], *primals, has_aux: bool = False) -> tuple[typing.Any, collections.abc.Callable] | tuple[typing.Any, collections.abc.Callable, typing.Any]

Vector-Jacobian product for reverse-mode automatic differentiation.
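
The return shape in the signature, an output paired with a callable, can be illustrated as follows. This sketch builds the pullback from a finite-difference Jacobian, which is only an approximation and is not how reverse-mode differentiation works internally; toy_vjp, pullback, and eps are hypothetical names, and the function is assumed to take scalars and return a list of scalars.

```python
from typing import Callable, List, Sequence, Tuple


def toy_vjp(func: Callable[..., List[float]], *primals: float,
            eps: float = 1e-6):
    """Hypothetical helper: return (func(*primals), pullback).

    pullback(cotangent) returns one value per primal, equal to
    cotangent^T @ J, where J is approximated by central differences.
    """
    out = func(*primals)

    def pullback(cotangent: Sequence[float]) -> Tuple[float, ...]:
        grads = []
        for i in range(len(primals)):
            plus = list(primals)
            minus = list(primals)
            plus[i] += eps
            minus[i] -= eps
            # Column i of the Jacobian: d out / d primals[i]
            column = [(a - b) / (2 * eps)
                      for a, b in zip(func(*plus), func(*minus))]
            grads.append(sum(c * j for c, j in zip(cotangent, column)))
        return tuple(grads)

    return out, pullback


def f(x: float, y: float) -> List[float]:
    return [x * y, x + y]


out, pullback = toy_vjp(f, 3.0, 2.0)
gx, gy = pullback([1.0, 2.0])
print(out)     # [6.0, 5.0]
print(gx, gy)  # approximately 4 and 5
```

A true reverse-mode implementation obtains the same result in one backward pass instead of one function evaluation pair per input.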

vmap

nabla.vmap(func: collections.abc.Callable | None = None, in_axes: Union[int, NoneType, list, tuple] = 0, out_axes: Union[int, NoneType, list, tuple] = 0) -> collections.abc.Callable[..., typing.Any]

Vectorization transformation for batching operations.
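
The role of in_axes can be illustrated with a loop-based sketch: an entry of 0 maps over the leading axis of that argument, and None passes the argument unchanged to every call. This is a conceptual illustration on Python lists, not Nabla's implementation; toy_vmap is a hypothetical name, and only axis 0 and None are supported here.

```python
from typing import Any, Callable, List


def toy_vmap(func: Callable[..., Any], in_axes=0) -> Callable[..., List[Any]]:
    """Hypothetical helper: apply func once per batch element."""
    def batched(*args: Any) -> List[Any]:
        if isinstance(in_axes, (list, tuple)):
            axes = list(in_axes)
        else:
            axes = [in_axes] * len(args)  # same axis for every argument
        if len(axes) != len(args):
            raise ValueError("in_axes needs one entry per positional argument")
        if any(ax not in (0, None) for ax in axes):
            raise ValueError("this sketch only supports axis 0 or None")
        sizes = {len(arg) for arg, ax in zip(args, axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("mapped arguments must share one batch size")
        (batch,) = sizes
        return [
            func(*[arg[i] if ax == 0 else arg for arg, ax in zip(args, axes)])
            for i in range(batch)
        ]
    return batched


def scale(x: float, w: float) -> float:
    return x * w


# Map over x, broadcast the scalar w to every call.
print(toy_vmap(scale, in_axes=(0, None))([1.0, 2.0, 3.0], 10.0))
# [10.0, 20.0, 30.0]

# Map over both arguments in lockstep.
print(toy_vmap(scale)([1.0, 2.0], [10.0, 20.0]))
# [10.0, 40.0]
```

A real vectorizing transformation rewrites the computation to operate on whole batches at once rather than looping, but the input and output contract is the same.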

xpr

nabla.xpr(fn: 'Callable[..., Any]', *primals) -> 'str'

Return a string representation of the expression graph built from fn and the given primals.