Example 11: Custom Mojo Kernels#
Nabla lets you write custom operations in Mojo — a high-performance language that compiles to MAX graphs. This means you can:
- Write a Mojo kernel (elementwise, reduction, etc.)
- Wrap it as a Nabla Operation in Python
- Use it like any built-in op — including with nb.grad, nb.vmap, etc.
Requirements: The modular package must be installed (pip install modular). Mojo kernels are compiled automatically by the MAX engine at graph execution time.
┌──────────────┐     ┌──────────────────┐     ┌────────────────┐
│ Mojo kernel  │ ──▶ │ Python Operation │ ──▶ │ Nabla Tensor   │
│ (.mojo file) │     │ (UnaryOperation) │     │ computation    │
└──────────────┘     └──────────────────┘     └────────────────┘
[1]:
from pathlib import Path

import numpy as np
import nabla as nb

# Check if MAX/Mojo is available
try:
    from max.graph import TensorValue
    from nabla.ops import UnaryOperation, call_custom_kernel

    HAS_MOJO = True
    print("MAX SDK available — custom kernels enabled")
except ImportError:
    HAS_MOJO = False
    print("MAX SDK not installed — showing code patterns only")
MAX SDK available — custom kernels enabled
1. Writing a Mojo Kernel#
A Mojo kernel is a struct registered with @compiler.register("name"). The execute method receives input/output tensors and a device context.
Here’s a simple kernel that adds 1 to every element:
# kernels/custom_kernel.mojo
import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor, foreach
from utils.index import IndexList


@compiler.register("add_one")
struct AddOneKernel:
    @staticmethod
    fn execute[
        target: StaticString
    ](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        fn add_one[width: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, width]:
            return x.load[width](idx) + 1

        foreach[add_one, target=target](output, ctx)
Key points:
@compiler.register("add_one")— the name you’ll reference from Pythonforeachauto-vectorizes the elementwise function across the tensorInputTensor/OutputTensorhandle memory layout automaticallyThe kernel directory also needs an
__init__.mojofile (can be empty)
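For reference, a minimal layout for that directory might look like the sketch below. The names match this example's files; the only hard requirement is that the path you pass to call_custom_kernel later points at this directory.

kernels/
    __init__.mojo       # can be empty; marks the directory as a Mojo kernel package
    custom_kernel.mojo  # contains AddOneKernel, registered as "add_one"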
2. Python Operation Wrapper#
To use the Mojo kernel in Nabla, wrap it as a UnaryOperation subclass. The kernel method bridges Python tensors to the Mojo function:
[2]:
if HAS_MOJO:

    class AddOneOp(UnaryOperation):
        """Custom op: adds 1 to every element using a Mojo kernel."""

        @property
        def name(self) -> str:
            return "add_one"

        def kernel(self, args, kwargs):
            """Invoke the Mojo kernel via call_custom_kernel."""
            x = args[0]
            # Point to the directory containing the .mojo kernel files
            kernel_dir = Path("../../tests/mojo/kernels")
            result = call_custom_kernel("add_one", kernel_dir, x, x.type)
            return [result]  # Must return a list of TensorValues

        def _derivative(self, primals, output):
            """d(x+1)/dx = 1 — gradient passes through unchanged."""
            return nb.ones_like(primals)

    # Instantiate the op (ops are stateless singletons)
    _add_one_op = AddOneOp()

    def add_one(x):
        """Add 1 to every element using our custom Mojo kernel."""
        return _add_one_op([x], {})[0]

    print("AddOneOp registered")
AddOneOp registered
Important details:
| Method | Purpose |
|---|---|
| name | Must match the kernel name registered with @compiler.register |
| kernel | Invokes the Mojo kernel via call_custom_kernel; must return a list of TensorValues |
| _derivative | Enables differentiation through the op (nb.grad, vjp, jvp) for elementwise kernels |
For non-elementwise ops, override vjp_rule and jvp_rule directly instead of _derivative.
3. Using the Custom Op#
[3]:
if HAS_MOJO:
    x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    y = add_one(x)

    print(f"Input: {x.to_numpy()}")
    print(f"Output: {y.to_numpy()}")  # [2.0, 3.0, 4.0]
else:
    print("Skipped — MAX SDK not available")
    print("Expected: add_one([1.0, 2.0, 3.0]) → [2.0, 3.0, 4.0]")
Input: [1. 2. 3.]
Output: [2. 3. 4.]
4. Differentiating Through Custom Ops#
Because we implemented _derivative, Nabla can differentiate through our custom kernel just like any built-in op:
[4]:
if HAS_MOJO:

    def f(x):
        """A function using our custom kernel."""
        return nb.sum(add_one(x) * x)  # sum((x+1) * x) = sum(x² + x)

    x = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    grad_f = nb.grad(f)
    g = grad_f(x)

    print(f"f(x) = sum((x+1) * x)")
    print(f"f'(x) = 2x + 1")
    print(f"Input: {x.to_numpy()}")
    print(f"Gradient: {g.to_numpy()}")  # [3.0, 5.0, 7.0]
else:
    print("Skipped — expected gradient of sum((x+1)*x) = 2x+1 = [3.0, 5.0, 7.0]")
f(x) = sum((x+1) * x)
f'(x) = 2x + 1
Input: [1. 2. 3.]
Gradient: [3. 5. 7.]
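If you want an extra sanity check on a custom derivative, you can compare the autodiff result against a central finite-difference estimate. The sketch below reuses only the APIs already shown above (nb.Tensor.from_dlpack, nb.grad, to_numpy) plus plain NumPy; the step size eps is an arbitrary choice.

if HAS_MOJO:
    def f(x):
        return nb.sum(add_one(x) * x)  # same function as above

    x_np = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    autodiff_grad = nb.grad(f)(nb.Tensor.from_dlpack(x_np)).to_numpy()

    # Central differences: (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
    eps = 1e-2
    fd_grad = np.zeros_like(x_np)
    for i in range(x_np.size):
        e = np.zeros_like(x_np)
        e[i] = eps
        f_plus = f(nb.Tensor.from_dlpack(x_np + e)).to_numpy()
        f_minus = f(nb.Tensor.from_dlpack(x_np - e)).to_numpy()
        fd_grad[i] = (f_plus - f_minus) / (2 * eps)

    print("autodiff:   ", autodiff_grad)  # ~[3. 5. 7.]
    print("finite diff:", fd_grad)        # should agree to a few decimals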
5. Writing More Complex Kernels#
The foreach pattern handles elementwise ops, but you can write any computation in Mojo. Here’s a sketch of a fused multiply-add kernel:
@compiler.register("fused_mul_add")
struct FusedMulAdd:
@staticmethod
fn execute[
target: StaticString
](
output: OutputTensor,
a: InputTensor[dtype = output.dtype, rank = output.rank],
b: InputTensor[dtype = output.dtype, rank = output.rank],
c: InputTensor[dtype = output.dtype, rank = output.rank],
ctx: DeviceContextPtr,
):
# output = a * b + c
@parameter
fn fma[width: Int](idx: IndexList[a.rank]) -> SIMD[a.dtype, width]:
return a.load[width](idx) * b.load[width](idx) + c.load[width](idx)
foreach[fma, target=target](output, ctx)
For multi-input ops, extend Operation instead of UnaryOperation and implement vjp_rule / jvp_rule directly.
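As a rough illustration of the Python side, here is a hedged sketch of such a wrapper. The Operation import path and the vjp_rule signature used below are assumptions made for illustration (check Nabla's ops module for the real interface); only the derivative formulas for output = a * b + c are fixed by the math: the incoming cotangent g flows back as g * b to a, g * a to b, and unchanged to c.

if HAS_MOJO:
    # Sketch only: Operation import path and vjp_rule signature are assumed,
    # mirroring the UnaryOperation pattern shown earlier.
    from nabla.ops import Operation

    class FusedMulAddOp(Operation):
        """Sketch: output = a * b + c via the fused_mul_add Mojo kernel."""

        @property
        def name(self) -> str:
            return "fused_mul_add"

        def kernel(self, args, kwargs):
            # Invoke the Mojo kernel analogously to AddOneOp.kernel above;
            # the multi-input call_custom_kernel arguments depend on Nabla's API.
            ...

        def vjp_rule(self, primals, cotangent, output):  # signature assumed
            a, b, c = primals
            g = cotangent
            # VJP of a*b + c: d/da = g*b, d/db = g*a, d/dc = g
            return [g * b, g * a, g]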
Key takeaways:
- Custom Mojo kernels compile automatically via the MAX engine
- Wrap kernels as UnaryOperation (or Operation) subclasses
- call_custom_kernel handles kernel loading and invocation
- Implement _derivative (elementwise) or vjp_rule / jvp_rule for autodiff
- Custom ops compose with all Nabla transforms (grad, vmap, compile), as sketched below for vmap
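As a closing example of that last point, the custom op can be batched with nb.vmap like any built-in op. This sketch assumes nb.vmap follows the same call pattern as nb.grad above; the batch values are arbitrary.

if HAS_MOJO:
    batch = nb.Tensor.from_dlpack(
        np.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]], dtype=np.float32)
    )
    # Map add_one over the leading batch dimension
    batched_add_one = nb.vmap(add_one)
    print(batched_add_one(batch).to_numpy())
    # Expected: [[ 2.  3.  4.]
    #            [11. 21. 31.]]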