Jit

nabla.transforms.jit.jit(func=None, static=True, show_graph=False)

Just-in-time (JIT) compile a function to speed up its execution. jit can be called directly, as in jit(func), or used as a decorator, @jit.

Parameters:

func (Callable[[...], Any] | None) – The function to JIT-compile. It should take positional arguments only.

Returns:

The JIT-compiled version of func.

Return type:

Callable[[...], Any]

Note

This follows JAX’s jit API:

  • Accepts positional arguments only

  • For functions that require keyword arguments, bind them first with functools.partial or a lambda

  • Supports both list-style arguments (legacy) and unpacked arguments (JAX-like)
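Because the compiled function is called with positional arguments only, any keyword arguments have to be fixed before compilation. The sketch below shows the two standard-library ways to do that. scale_and_shift is a hypothetical example function, and plain Python floats are used so the sketch runs without nabla; the resulting callables are what you would then pass to jit.

```python
import functools


def scale_and_shift(x, *, scale=1.0, shift=0.0):
    """Hypothetical function with keyword-only arguments."""
    return x * scale + shift


# Option 1: bind the keyword arguments with functools.partial.
# The result takes only the positional argument x.
scaled_partial = functools.partial(scale_and_shift, scale=2.0, shift=1.0)

# Option 2: wrap the call in a lambda that takes positional arguments only.
scaled_lambda = lambda x: scale_and_shift(x, scale=2.0, shift=1.0)

# Either callable could now be compiled, e.g. jit(scaled_partial) or jit(scaled_lambda).
print(scaled_partial(3.0))  # 7.0
print(scaled_lambda(3.0))   # 7.0
```

functools.partial keeps the bound values visible through its keywords attribute, which can help when debugging; a lambda is more flexible when the arguments need to be rearranged.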

Example

As a function call:

from nabla.transforms.jit import jit
fast_func = jit(my_func)

As a decorator:

@jit
def my_func(x):
    return x * 2
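The func=None default in the signature is a common Python idiom that lets one callable serve both as jit(func) and as a bare @jit decorator, and typically also as a decorator with options such as @jit(show_graph=True). That last form is an assumption based on the signature, not something this page states. The sketch below shows the idiom with a hypothetical toy_jit that performs no compilation; it is not nabla's implementation.

```python
import functools
from typing import Any, Callable, Optional


def toy_jit(func: Optional[Callable[..., Any]] = None, *, show_graph: bool = False) -> Any:
    """Hypothetical stand-in illustrating the func=None idiom; it does not compile anything."""

    def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(f)
        def wrapper(*args: Any) -> Any:
            # Positional arguments only, mirroring the documented calling convention.
            # A real JIT would trace and compile f here; the sketch simply calls it.
            return f(*args)

        wrapper.show_graph = show_graph  # keep the option so it can be inspected
        return wrapper

    if func is None:
        # Used with options, e.g. @toy_jit(show_graph=True): return a decorator.
        return decorate
    # Used as toy_jit(f) or as a bare @toy_jit.
    return decorate(func)


def my_func(x):
    return x * 2


fast_func = toy_jit(my_func)  # function-call style


@toy_jit  # bare decorator
def double(x):
    return x * 2


@toy_jit(show_graph=True)  # decorator with options
def triple(x):
    return x * 3


print(fast_func(4), double(5), triple(2))  # 8 10 6
```

functools.wraps copies the wrapped function's name and docstring, so decorated functions keep their identity in tracebacks and introspection.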

nabla.transforms.jit.djit(func=None, show_graph=False)

Dynamic variant of jit that JIT-compiles a function to speed up its execution. djit can be called directly, as in djit(func), or used as a decorator, @djit.

Parameters:

func (Callable[[...], Any] | None) – The function to JIT-compile. It should take positional arguments only.

Returns:

The JIT-compiled version of func.

Return type:

Callable[[...], Any]

Note

This follows JAX’s jit API:

  • Accepts positional arguments only

  • For functions that require keyword arguments, bind them first with functools.partial or a lambda

  • Supports both list-style arguments (legacy) and unpacked arguments (JAX-like)
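The note does not define the two argument styles. Assuming that list style means the function takes a single list of inputs and returns a list of outputs, while unpacked style means one positional parameter per input, the two conventions look as follows. Both function names are hypothetical, and plain Python numbers are used so the sketch runs without nabla; either function is the kind of callable you would pass to djit or jit.

```python
def add_mul_list_style(inputs):
    """List style (legacy, assumed): one list argument in, a list of results out."""
    x, y = inputs
    return [x + y, x * y]


def add_mul_unpacked(x, y):
    """Unpacked style (JAX-like): one positional parameter per input."""
    return x + y, x * y


# The same computation, called with each convention.
print(add_mul_list_style([2, 3]))  # [5, 6]
print(add_mul_unpacked(2, 3))      # (5, 6)
```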