Example 12: Custom Mojo Kernels#

Nabla lets you write custom operations in Mojo — a high-performance language whose kernels plug directly into MAX graphs. This means you can:

  1. Write a Mojo kernel (elementwise, reduction, etc.)

  2. Wrap it as a Nabla Operation in Python

  3. Use it like any built-in op — including with nb.grad, nb.vmap, etc.

[3]:
from pathlib import Path
import numpy as np

import nabla as nb
from max.graph import TensorValue
from nabla.ops import UnaryOperation, call_custom_kernel

1. Writing a Mojo Kernel#

A Mojo kernel is a struct registered with @compiler.register("name"). The execute method receives input/output tensors and a device context.

Here’s a simple kernel that adds 1 to every element:

# kernels/custom_kernel.mojo
import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor, foreach
from utils.index import IndexList


@compiler.register("add_one")
struct AddOneKernel:
    @staticmethod
    fn execute[
        target: StaticString
    ](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        fn add_one[width: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, width]:
            return x.load[width](idx) + 1

        foreach[add_one, target=target](output, ctx)

Key points:

  • @compiler.register("add_one") — the name you’ll reference from Python

  • foreach auto-vectorizes the elementwise function across the tensor

  • InputTensor / OutputTensor handle memory layout automatically

  • The kernel directory also needs an __init__.mojo file (can be empty)
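Before wiring the kernel into Nabla, it helps to keep a plain NumPy reference of what the kernel should compute, so the Mojo output can be sanity-checked later. This is a hypothetical helper for testing, not part of Nabla or MAX:

```python
import numpy as np

def add_one_reference(x: np.ndarray) -> np.ndarray:
    """NumPy reference for the add_one kernel: elementwise x + 1."""
    return x + 1

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
print(add_one_reference(x))  # [2. 3. 4.]
```

Comparing the custom op's output against a reference like this is a cheap way to catch layout or dtype mistakes in the kernel.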

2. Python Operation Wrapper#

To use the Mojo kernel in Nabla, wrap it as a UnaryOperation subclass. The kernel method bridges Python tensors to the Mojo function:

[4]:
class AddOneOp(UnaryOperation):
    """Custom op: adds 1 to every element using a Mojo kernel."""

    @property
    def name(self) -> str:
        return "add_one"

    def kernel(self, args, kwargs):
        """Invoke the Mojo kernel via call_custom_kernel."""
        x = args[0]
        # Point to the directory containing the .mojo kernel files
        kernel_dir = Path("../../tests/mojo/kernels")
        result = call_custom_kernel("add_one", kernel_dir, x, x.type)
        return [result]  # Must return a list of TensorValues

    def _derivative(self, primals, output):
        """d(x+1)/dx = 1 — gradient passes through unchanged."""
        return nb.ones_like(primals)

# Instantiate the op (ops are stateless singletons)
_add_one_op = AddOneOp()

def add_one(x):
    """Add 1 to every element using our custom Mojo kernel."""
    return _add_one_op([x], {})[0]

print("AddOneOp registered")
AddOneOp registered

Important details:

  • name — must match the @compiler.register("...") name

  • kernel(args, kwargs) — args is a list[TensorValue], kwargs is a dict

  • _derivative(primals, output) — enables nb.grad; return \(\frac{\partial \text{out}}{\partial \text{in}}\)

For non-elementwise ops, override vjp_rule and jvp_rule directly instead of _derivative.
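To see why returning ones gives a pass-through gradient: an elementwise op has a diagonal Jacobian, so its vector-Jacobian product collapses to an elementwise multiply of the incoming cotangent with the local derivative. A NumPy illustration of that reduction (this sketches the math, not Nabla's internals):

```python
import numpy as np

def elementwise_vjp(cotangent: np.ndarray, local_deriv: np.ndarray) -> np.ndarray:
    # Elementwise ops have a diagonal Jacobian, so the vector-Jacobian
    # product is just an elementwise multiply.
    return cotangent * local_deriv

cotangent = np.array([0.5, 1.0, 2.0])  # upstream gradient flowing in
local = np.ones(3)                     # d(x + 1)/dx = 1, as in AddOneOp._derivative
print(elementwise_vjp(cotangent, local))  # cotangent passes through unchanged
```

This is why `_derivative` only needs the local derivative: the framework supplies the cotangent and does the multiply.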

3. Using the Custom Op#

[6]:
x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
y = add_one(x)
print(f"Input:  {x.to_numpy()}")
print(f"Output: {y.to_numpy()}")  # [2.0, 3.0, 4.0]
Input:  [1. 2. 3.]
Output: [2. 3. 4.]

4. Differentiating Through Custom Ops#

Because we implemented _derivative, Nabla can differentiate through our custom kernel just like any built-in op:

[7]:
def f(x):
    """A function using our custom kernel."""
    return nb.sum(add_one(x) * x)  # sum((x+1) * x) = sum(x² + x)

x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
grad_f = nb.grad(f)
g = grad_f(x)

print(f"f(x) = sum((x+1) * x)")
print(f"f'(x) = 2x + 1")
print(f"Input:    {x.to_numpy()}")
print(f"Gradient: {g.to_numpy()}")  # [3.0, 5.0, 7.0]
f(x) = sum((x+1) * x)
f'(x) = 2x + 1
Input:    [1. 2. 3.]
Gradient: [3. 5. 7.]
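The analytic gradient 2x + 1 can also be cross-checked without Nabla, using central finite differences on a NumPy version of f — a standard sanity check when writing custom derivative rules:

```python
import numpy as np

def f_np(x: np.ndarray) -> float:
    return float(np.sum((x + 1) * x))  # same f as above, in plain NumPy

x = np.array([1.0, 2.0, 3.0])
eps = 1e-5
# Central differences approximate df/dx_i one coordinate at a time.
fd_grad = np.array([
    (f_np(x + eps * e) - f_np(x - eps * e)) / (2 * eps)
    for e in np.eye(len(x))
])
print(fd_grad)  # close to [3. 5. 7.], i.e. 2x + 1
```

If the finite-difference estimate disagrees with `nb.grad`, the bug is almost always in `_derivative`.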

Key takeaways:

  • Custom Mojo kernels compile automatically via the MAX engine

  • Wrap kernels as UnaryOperation (or Operation) subclasses

  • call_custom_kernel handles kernel loading and invocation

  • Implement _derivative (elementwise) or vjp_rule/jvp_rule for autodiff

  • Custom ops compose with all Nabla transforms (grad, vmap, compile)