Example 12: Custom Mojo Kernels#
Nabla lets you write custom operations in Mojo — a high-performance language that compiles to MAX graphs. This means you can:
- Write a Mojo kernel (elementwise, reduction, etc.)
- Wrap it as a Nabla `Operation` in Python
- Use it like any built-in op — including with `nb.grad`, `nb.vmap`, etc.
[3]:
from pathlib import Path
import numpy as np
import nabla as nb
from max.graph import TensorValue
from nabla.ops import UnaryOperation, call_custom_kernel
1. Writing a Mojo Kernel#
A Mojo kernel is a struct registered with `@compiler.register("name")`. The `execute` method receives input/output tensors and a device context.
Here’s a simple kernel that adds 1 to every element:
# kernels/custom_kernel.mojo
import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor, foreach
from utils.index import IndexList

@compiler.register("add_one")
struct AddOneKernel:
    @staticmethod
    fn execute[
        target: StaticString
    ](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        fn add_one[width: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, width]:
            return x.load[width](idx) + 1

        foreach[add_one, target=target](output, ctx)
Key points:
- `@compiler.register("add_one")` — the name you'll reference from Python
- `foreach` auto-vectorizes the elementwise function across the tensor
- `InputTensor`/`OutputTensor` handle memory layout automatically
- The kernel directory also needs an `__init__.mojo` file (can be empty)
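Before wiring the kernel into Python, it can help to pin down the expected semantics with a plain NumPy reference (the helper name `add_one_ref` is ours, not part of Nabla or MAX — it just mirrors what the SIMD `foreach` loop computes):

```python
import numpy as np

def add_one_ref(x: np.ndarray) -> np.ndarray:
    """NumPy reference for the `add_one` Mojo kernel: elementwise x + 1."""
    return x + 1.0

# The reference preserves shape and dtype, matching the kernel's
# InputTensor/OutputTensor contract (same dtype, same rank).
out = add_one_ref(np.array([1.0, 2.0, 3.0], dtype=np.float32))
print(out)  # [2. 3. 4.]
```

A reference like this is handy for unit-testing the real kernel's output later.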
2. Python Operation Wrapper#
To use the Mojo kernel in Nabla, wrap it as a `UnaryOperation` subclass. The `kernel` method bridges Python tensors to the Mojo function:
[4]:
class AddOneOp(UnaryOperation):
    """Custom op: adds 1 to every element using a Mojo kernel."""

    @property
    def name(self) -> str:
        return "add_one"

    def kernel(self, args, kwargs):
        """Invoke the Mojo kernel via call_custom_kernel."""
        x = args[0]
        # Point to the directory containing the .mojo kernel files
        kernel_dir = Path("../../tests/mojo/kernels")
        # The kernel name must match @compiler.register("add_one")
        result = call_custom_kernel("add_one", kernel_dir, x, x.type)
        return [result]  # Must return a list of TensorValues

    def _derivative(self, primals, output):
        """d(x+1)/dx = 1 — gradient passes through unchanged."""
        return nb.ones_like(primals)

# Instantiate the op (ops are stateless singletons)
_add_one_op = AddOneOp()

def add_one(x):
    """Add 1 to every element using our custom Mojo kernel."""
    return _add_one_op([x], {})[0]

print("AddOneOp registered")
AddOneOp registered
Important details:
| Method | Purpose |
|---|---|
| `name` | Must match the kernel name registered with `@compiler.register` |
| `kernel` | Invokes the Mojo kernel via `call_custom_kernel` |
| `_derivative` | Enables autodiff (`nb.grad`, etc.) |
For non-elementwise ops, override `vjp_rule` and `jvp_rule` directly instead of `_derivative`.
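To illustrate what a custom `vjp_rule` has to compute for a non-elementwise op, here is the vector-Jacobian product of a sum reduction written in plain NumPy (a standalone sketch of the math only — `sum_vjp` is our own name and does not show Nabla's actual `vjp_rule` signature):

```python
import numpy as np

def sum_vjp(x, cotangent):
    # VJP of a sum reduction: the Jacobian of sum(x) is a row of ones,
    # so J^T @ g broadcasts the scalar cotangent back to x's shape.
    return np.full_like(x, cotangent)

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
print(sum_vjp(x, 5.0))  # [5. 5. 5.]
```

The elementwise `_derivative` shortcut works above because the Jacobian of `add_one` is diagonal; once an op mixes elements (reductions, matmuls), you must spell out how cotangents map back to input shapes, as here.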
3. Using the Custom Op#
[6]:
x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
y = add_one(x)
print(f"Input: {x.to_numpy()}")
print(f"Output: {y.to_numpy()}") # [2.0, 3.0, 4.0]
Input: [1. 2. 3.]
Output: [2. 3. 4.]
4. Differentiating Through Custom Ops#
Because we implemented `_derivative`, Nabla can differentiate through our custom kernel just like any built-in op:
[7]:
def f(x):
    """A function using our custom kernel."""
    return nb.sum(add_one(x) * x)  # sum((x+1) * x) = sum(x² + x)
x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
grad_f = nb.grad(f)
g = grad_f(x)
print(f"f(x) = sum((x+1) * x)")
print(f"f'(x) = 2x + 1")
print(f"Input: {x.to_numpy()}")
print(f"Gradient: {g.to_numpy()}") # [3.0, 5.0, 7.0]
f(x) = sum((x+1) * x)
f'(x) = 2x + 1
Input: [1. 2. 3.]
Gradient: [3. 5. 7.]
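The analytic gradient above can be sanity-checked without Nabla at all, using central finite differences on the same scalar function (pure NumPy; `f_np` and `fd_grad` are our own helper names):

```python
import numpy as np

def f_np(x):
    # Same function as above, in NumPy: sum((x + 1) * x) = sum(x^2 + x)
    return np.sum((x + 1.0) * x)

def fd_grad(f, x, eps=1e-3):
    # Central finite differences, one coordinate at a time.
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return g

x = np.array([1.0, 2.0, 3.0])
print(fd_grad(f_np, x))  # ≈ [3. 5. 7.]
print(2 * x + 1)         # analytic: [3. 5. 7.]
```

For this quadratic, central differences are exact up to floating-point noise, so both lines agree with the `nb.grad` result.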
Key takeaways:
- Custom Mojo kernels compile automatically via the MAX engine
- Wrap kernels as `UnaryOperation` (or `Operation`) subclasses
- `call_custom_kernel` handles kernel loading and invocation
- Implement `_derivative` (elementwise) or `vjp_rule`/`jvp_rule` for autodiff
- Custom ops compose with all Nabla transforms (`grad`, `vmap`, `compile`)