Example 11: Custom Mojo Kernels

Nabla lets you write custom operations in Mojo, a high-performance language whose kernels the MAX engine compiles and runs as part of a graph. This means you can:

  1. Write a Mojo kernel (elementwise, reduction, etc.)

  2. Wrap it as a Nabla Operation in Python

  3. Use it like any built-in op, including with nb.grad, nb.vmap, etc.

Requirements: The modular package must be installed (pip install modular). Mojo kernels are compiled automatically by the MAX engine at graph execution time.

┌──────────────┐     ┌──────────────────┐     ┌────────────────┐
│  Mojo kernel │ ──▶ │ Python Operation │ ──▶ │  Nabla Tensor  │
│  (.mojo file)│     │ (UnaryOperation) │     │  computation   │
└──────────────┘     └──────────────────┘     └────────────────┘
[1]:
from pathlib import Path
import numpy as np

import nabla as nb

# Check if MAX/Mojo is available
try:
    from max.graph import TensorValue
    from nabla.ops import UnaryOperation, call_custom_kernel
    HAS_MOJO = True
    print("MAX SDK available — custom kernels enabled")
except ImportError:
    HAS_MOJO = False
    print("MAX SDK not installed — showing code patterns only")
MAX SDK available — custom kernels enabled

1. Writing a Mojo Kernel

A Mojo kernel is a struct registered under a name with @compiler.register("name"). Its static execute method is parameterized on the compile target and receives the output tensor, the input tensors, and a device context.

Here’s a simple kernel that adds 1 to every element:

# kernels/custom_kernel.mojo
import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor, foreach
from utils.index import IndexList


@compiler.register("add_one")
struct AddOneKernel:
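    @staticmethod
    fn execute[
        target: StaticString  # "cpu" or "gpu", supplied by MAX at compile time
    ](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ) raises:  # foreach may raise, so execute is declared `raises`
        # Elementwise body: load a SIMD vector of `width` elements at `idx`, add 1
        @parameter
        fn add_one[width: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, width]:
            return x.load[width](idx) + 1

        # Apply add_one over every index of `output`, vectorized for `target`
        foreach[add_one, target=target](output, ctx)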

Key points:

  • @compiler.register("add_one") — the name you’ll reference from Python

  • foreach auto-vectorizes the elementwise function across the tensor

  • InputTensor / OutputTensor handle memory layout automatically

  • The kernel directory also needs an __init__.mojo file (can be empty)
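Conceptually, `foreach` walks the output's index space. It calls the elementwise function on a `width`-wide chunk of indices and stores the returned vector into the output. The pure-Python sketch below models that loop for a flat 1-D buffer. `foreach_1d`, the default `simd_width` of 4, and the width-1 tail loop are assumptions for illustration only. The real `foreach` handles N-D tensors and target-specific vector widths.

```python
from typing import Callable, List


def foreach_1d(
    elem_fn: Callable[[int, int], List[float]],
    output: List[float],
    simd_width: int = 4,
) -> None:
    """Toy model of Mojo's `foreach` for a flat 1-D buffer (hypothetical).

    `elem_fn(width, idx)` returns `width` results for indices idx..idx+width-1,
    mimicking `fn add_one[width: Int](idx) -> SIMD[dtype, width]`.
    """
    n = len(output)
    idx = 0
    # Main loop: full vectors of `simd_width` lanes.
    while idx + simd_width <= n:
        output[idx:idx + simd_width] = elem_fn(simd_width, idx)
        idx += simd_width
    # Tail: this toy handles leftover elements one lane at a time (width = 1).
    while idx < n:
        output[idx:idx + 1] = elem_fn(1, idx)
        idx += 1


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
out = [0.0] * len(x)


def add_one(width: int, idx: int) -> List[float]:
    # Equivalent of `x.load[width](idx) + 1` in the Mojo kernel.
    return [v + 1 for v in x[idx:idx + width]]


foreach_1d(add_one, out)
print(out)  # [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
```

With six elements and a width of 4, the main loop fills indices 0 to 3 in one step, and the tail loop fills indices 4 and 5.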

2. Python Operation Wrapper

To use the Mojo kernel in Nabla, wrap it in a UnaryOperation subclass. Its kernel method receives MAX graph values (TensorValue), not Nabla tensors, and hands them to the Mojo kernel via call_custom_kernel:

[2]:
if HAS_MOJO:
    class AddOneOp(UnaryOperation):
        """Custom op: adds 1 to every element using a Mojo kernel."""

        @property
        def name(self) -> str:
            return "add_one"

        def kernel(self, args, kwargs):
            """Invoke the Mojo kernel via call_custom_kernel."""
            x = args[0]
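            # Directory containing the .mojo kernel files (plus __init__.mojo),
            # resolved relative to the current working directory
            kernel_dir = Path("../../tests/mojo/kernels")
            # The first argument must match the @compiler.register("add_one") name
            result = call_custom_kernel("add_one", kernel_dir, x, x.type)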
            return [result]  # Must return a list of TensorValues

        def _derivative(self, primals, output):
            """d(x+1)/dx = 1 — gradient passes through unchanged."""
            return nb.ones_like(primals)

    # Instantiate the op (ops are stateless singletons)
    _add_one_op = AddOneOp()

    def add_one(x):
        """Add 1 to every element using our custom Mojo kernel."""
        return _add_one_op([x], {})[0]

    print("AddOneOp registered")
AddOneOp registered

Important details:

Method                          Purpose
name                            Must match the @compiler.register("...") name
kernel(args, kwargs)            args is a list[TensorValue], kwargs is a dict
_derivative(primals, output)    Enables nb.grad: return \(\frac{\partial \text{out}}{\partial \text{in}}\)

For non-elementwise ops, override vjp_rule and jvp_rule directly instead of _derivative.
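For an elementwise op, a single derivative is enough because the Jacobian is diagonal. The VJP multiplies the incoming cotangent by the local derivative elementwise, and the JVP does the same to the tangent. The plain-Python sketch below spells that out. The `ElementwiseOp`, `AddOne`, and `Square` classes and their `vjp`/`jvp` methods are hypothetical and only illustrate the math. They are not Nabla's `UnaryOperation`, whose `vjp_rule`/`jvp_rule` signatures may differ.

```python
from typing import List


class ElementwiseOp:
    """Hypothetical elementwise op: out[i] = forward(x[i])."""

    def forward(self, x: float) -> float:
        raise NotImplementedError

    def derivative(self, x: float) -> float:
        """Local derivative d out[i] / d x[i] (the Jacobian diagonal)."""
        raise NotImplementedError

    def __call__(self, xs: List[float]) -> List[float]:
        return [self.forward(x) for x in xs]

    def vjp(self, xs: List[float], cotangent: List[float]) -> List[float]:
        # Reverse mode: cotangent_in[i] = cotangent_out[i] * f'(x[i]).
        return [g * self.derivative(x) for x, g in zip(xs, cotangent)]

    def jvp(self, xs: List[float], tangent: List[float]) -> List[float]:
        # Forward mode: tangent_out[i] = f'(x[i]) * tangent_in[i].
        return [self.derivative(x) * t for x, t in zip(xs, tangent)]


class AddOne(ElementwiseOp):
    def forward(self, x: float) -> float:
        return x + 1.0

    def derivative(self, x: float) -> float:
        # d(x + 1)/dx = 1, the same rule AddOneOp._derivative encodes.
        return 1.0


class Square(ElementwiseOp):
    def forward(self, x: float) -> float:
        return x * x

    def derivative(self, x: float) -> float:
        # d(x^2)/dx = 2x, so the cotangent is scaled per element.
        return 2.0 * x


xs = [1.0, 2.0, 3.0]
print(AddOne()(xs))                        # [2.0, 3.0, 4.0]
print(AddOne().vjp(xs, [0.5, 1.0, 2.0]))   # [0.5, 1.0, 2.0]
print(Square().vjp(xs, [1.0, 1.0, 1.0]))   # [2.0, 4.0, 6.0]
```

A non-elementwise op has a non-diagonal Jacobian, so no single per-element derivative exists. That is why such ops need their own VJP/JVP rules.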

3. Using the Custom Op

[3]:
if HAS_MOJO:
    x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    y = add_one(x)
    print(f"Input:  {x.to_numpy()}")
    print(f"Output: {y.to_numpy()}")  # [2.0, 3.0, 4.0]
else:
    print("Skipped — MAX SDK not available")
    print("Expected: add_one([1.0, 2.0, 3.0]) → [2.0, 3.0, 4.0]")
Input:  [1. 2. 3.]
Output: [2. 3. 4.]

4. Differentiating Through Custom Ops

Because we implemented _derivative, Nabla can differentiate through our custom kernel just like any built-in op:

[4]:
if HAS_MOJO:
    def f(x):
        """A function using our custom kernel."""
        return nb.sum(add_one(x) * x)  # sum((x+1) * x) = sum(x² + x)

    x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    grad_f = nb.grad(f)
    g = grad_f(x)

    print("f(x) = sum((x+1) * x)")
    print("f'(x) = 2x + 1")
    print(f"Input:    {x.to_numpy()}")
    print(f"Gradient: {g.to_numpy()}")  # [3.0, 5.0, 7.0]
else:
    print("Skipped — expected gradient of sum((x+1)*x) = 2x+1 = [3.0, 5.0, 7.0]")
f(x) = sum((x+1) * x)
f'(x) = 2x + 1
Input:    [1. 2. 3.]
Gradient: [3. 5. 7.]
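The printed gradient can be cross-checked without Nabla or MAX. The sketch below evaluates f(x) = sum((x + 1) * x) in plain Python and estimates each partial derivative with central finite differences. `f_ref` and `numerical_grad` are hypothetical helpers written for this check, and `eps` is an arbitrary step size. The estimates should agree with the analytic 2x + 1 up to rounding.

```python
from typing import Callable, List


def f_ref(xs: List[float]) -> float:
    # Pure-Python reference for f(x) = sum((x + 1) * x).
    return sum((x + 1.0) * x for x in xs)


def numerical_grad(
    f: Callable[[List[float]], float], xs: List[float], eps: float = 1e-4
) -> List[float]:
    """Central differences: (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)."""
    grad = []
    for i in range(len(xs)):
        plus = list(xs)
        minus = list(xs)
        plus[i] += eps
        minus[i] -= eps
        grad.append((f(plus) - f(minus)) / (2.0 * eps))
    return grad


xs = [1.0, 2.0, 3.0]
approx = numerical_grad(f_ref, xs)
analytic = [2.0 * x + 1.0 for x in xs]
print([round(g, 6) for g in approx])  # [3.0, 5.0, 7.0]
```

Because each term x² + x is quadratic, the central difference is exact apart from floating-point rounding. Any noticeable mismatch against nb.grad would therefore point at the derivative rule rather than at the check.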

5. Writing More Complex Kernels

The foreach pattern handles elementwise ops, but you can write any computation in Mojo. Here’s a sketch of a fused multiply-add kernel:

import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor, foreach
from utils.index import IndexList


@compiler.register("fused_mul_add")
struct FusedMulAdd:
    @staticmethod
    fn execute[
        target: StaticString
    ](
        output: OutputTensor,
        # a, b, c are read at the same index as output: no broadcasting here
        a: InputTensor[dtype = output.dtype, rank = output.rank],
        b: InputTensor[dtype = output.dtype, rank = output.rank],
        c: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ) raises:
        # output = a * b + c, computed elementwise in a single pass
        @parameter
        fn fma[width: Int](idx: IndexList[a.rank]) -> SIMD[a.dtype, width]:
            return a.load[width](idx) * b.load[width](idx) + c.load[width](idx)

        foreach[fma, target=target](output, ctx)
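The benefit of fusing is that a * b + c is computed in one pass with no intermediate buffer. Composing a separate multiply and add would materialize a * b first. The plain-Python sketch below contrasts the two. `unfused_mul_add` and `fused_mul_add` are illustrative names only, and list comprehensions stand in for what would be kernel launches on real hardware.

```python
from typing import List


def unfused_mul_add(a: List[float], b: List[float], c: List[float]) -> List[float]:
    # Two passes: the product is materialized as a temporary list first.
    tmp = [x * y for x, y in zip(a, b)]
    return [t + z for t, z in zip(tmp, c)]


def fused_mul_add(a: List[float], b: List[float], c: List[float]) -> List[float]:
    # One pass: each output element is computed directly from the three inputs,
    # mirroring `a.load(idx) * b.load(idx) + c.load(idx)` in the Mojo kernel.
    return [x * y + z for x, y, z in zip(a, b, c)]


a, b, c = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [0.5, 0.5, 0.5]
print(fused_mul_add(a, b, c))  # [4.5, 10.5, 18.5]
```

"Fused" here means one kernel and one pass over memory. It does not necessarily mean a single-rounding hardware FMA instruction: the Python version rounds after the multiply and again after the add.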

For multi-input ops, extend Operation instead of UnaryOperation and implement vjp_rule / jvp_rule directly.
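For out = a * b + c, the elementwise partial derivatives are ∂out/∂a = b, ∂out/∂b = a, and ∂out/∂c = 1. A multi-input VJP must therefore return three cotangents, and the JVP must combine three tangents by the product rule. The functions and list-based signatures below are hypothetical. They show the math an Operation subclass for fused_mul_add would need to implement, not Nabla's actual vjp_rule/jvp_rule signatures.

```python
from typing import List, Tuple

Vec = List[float]


def fma_vjp(a: Vec, b: Vec, c: Vec, cotangent: Vec) -> Tuple[Vec, Vec, Vec]:
    # Reverse mode for out = a * b + c (elementwise):
    # grad_a = g * b, grad_b = g * a, grad_c = g.
    # `c` is unused because out is affine in c (its derivative is 1).
    grad_a = [g * y for g, y in zip(cotangent, b)]
    grad_b = [g * x for g, x in zip(cotangent, a)]
    grad_c = list(cotangent)
    return grad_a, grad_b, grad_c


def fma_jvp(a: Vec, b: Vec, c: Vec, ta: Vec, tb: Vec, tc: Vec) -> Vec:
    # Forward mode (product rule): t_out = ta * b + a * tb + tc.
    return [dx * y + x * dy + dz for x, y, dx, dy, dz in zip(a, b, ta, tb, tc)]


a, b, c = [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]
print(fma_vjp(a, b, c, [1.0, 1.0]))  # ([3.0, 4.0], [1.0, 2.0], [1.0, 1.0])
print(fma_jvp(a, b, c, [1.0, 0.0], [0.0, 1.0], [0.0, 0.0]))  # [3.0, 2.0]
```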


Key takeaways:

  • Custom Mojo kernels compile automatically via the MAX engine

  • Wrap kernels as UnaryOperation (or Operation) subclasses

  • call_custom_kernel handles kernel loading and invocation

  • Implement _derivative (elementwise) or vjp_rule/jvp_rule for autodiff

  • Custom ops compose with all Nabla transforms (grad, vmap, compile)