Example 4: Transforms and @nb.compile#
Nabla’s transforms are higher-order functions that take a function and return a new function with modified behavior. They are fully composable and work with any Nabla operation, including nn.Modules.
| Transform | What it does |
|---|---|
| vmap | Auto-vectorize over a batch dimension |
| grad | Compute gradients (reverse-mode) |
| jacrev | Full Jacobian via reverse-mode |
| jacfwd | Full Jacobian via forward-mode |
| @nb.compile | Compile the computation graph to a MAX graph |
[1]:
import numpy as np
import nabla as nb
print("Nabla Transforms & Compile Example")
Nabla Transforms & Compile Example
1. vmap — Automatic Vectorization#
vmap transforms a function that operates on a single example into one that operates on a batch — without writing any batching logic yourself.
[2]:
def single_dot(x, y):
"""Dot product of two vectors (no batch dimension)."""
return nb.reduce_sum(x * y)
# A batch of 5 vector pairs (without vmap you would loop over each pair manually)
x_batch = nb.uniform((5, 3))
y_batch = nb.uniform((5, 3))
# With vmap: automatic vectorization!
batched_dot = nb.vmap(single_dot, in_axes=(0, 0))
result = batched_dot(x_batch, y_batch)
print(f"Batched dot products (5 pairs of 3D vectors):")
print(result)
print(f"Shape: {result.shape}")
Batched dot products (5 pairs of 3D vectors):
Tensor([0.9671 1.0155 1.0364 0.5181 0.8411] : f32[5])
Shape: [Dim(5)]
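For reference, here is roughly the batching logic that vmap writes for you: a minimal NumPy sketch of the manual loop, fed the same data so the two results are directly comparable (the variable names are illustrative):

x_np = np.random.rand(5, 3).astype(np.float32)
y_np = np.random.rand(5, 3).astype(np.float32)
# Manual batching: loop over the leading axis and dot each pair
manual = np.array([x_np[i] @ y_np[i] for i in range(5)])
# vmap produces the same 5 dot products without the loop
auto = batched_dot(nb.Tensor.from_dlpack(x_np), nb.Tensor.from_dlpack(y_np))
print(manual)  # should match `auto` element-wise, up to float precision
print(auto)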
in_axes and out_axes#
in_axes controls which axis of each argument is the batch axis. out_axes controls where to place the batch axis in the output. Use None for arguments that should be broadcast (not batched).
[3]:
def weighted_sum(x, w):
"""Weighted sum: w * x, summed."""
return nb.reduce_sum(w * x)
# x is batched (axis 0), w is shared across the batch
batch_fn = nb.vmap(weighted_sum, in_axes=(0, None))
x_batch = nb.uniform((4, 3))
w = nb.Tensor.from_dlpack(np.array([1.0, 2.0, 3.0], dtype=np.float32))
result = batch_fn(x_batch, w)
print(f"Batched weighted sum (shared weights):")
print(result)
print(f"Shape: {result.shape}")
Batched weighted sum (shared weights):
Tensor([2.8038 5.0148 2.5579 2.3182] : f32[4])
Shape: [Dim(4)]
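out_axes is described above but not exercised. A minimal sketch of where it matters, assuming an integer out_axes is accepted (as in JAX), for a function whose per-example output is itself a vector:

def scale(x, w):
    """Per-example output is a 3-vector, not a scalar."""
    return w * x
# Batch over axis 0 of x, but place the batch axis last in the output
scale_batched = nb.vmap(scale, in_axes=(0, None), out_axes=1)
out = scale_batched(x_batch, w)  # x_batch: (4, 3) and w: (3,) from the cell above
print(out.shape)  # expected batch axis last: (3, 4) instead of (4, 3)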
2. vmap of grad — Per-Example Gradients#
Composing vmap with grad gives per-example gradients — something that’s difficult to do efficiently in most frameworks.
[4]:
def per_sample_loss(x, w):
"""Loss for a single sample: (w @ x)^2."""
return nb.reduce_sum(w * x) ** 2
# grad of the loss w.r.t. w for a single sample
grad_single = nb.grad(per_sample_loss, argnums=1)
# vmap over samples — per-example gradients!
per_example_grad = nb.vmap(grad_single, in_axes=(0, None))
x_batch = nb.Tensor.from_dlpack(
np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=np.float32)
)
w = nb.Tensor.from_dlpack(np.array([2.0, 3.0], dtype=np.float32))
grads = per_example_grad(x_batch, w)
print("Per-example gradients (3 samples, 2 weights):")
print(grads)
print(f"Shape: {grads.shape}")
Per-example gradients (3 samples, 2 weights):
Tensor(
[[ 4. 0.]
[ 0. 6.]
[10. 10.]] : f32[3,2]
)
Shape: [Dim(3), Dim(2)]
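A quick hand check: for L(w) = (w·x)^2, the gradient is dL/dw = 2(w·x)x. Computing that analytically in NumPy for the three samples reproduces the values above:

x_np = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=np.float32)
w_np = np.array([2.0, 3.0], dtype=np.float32)
# 2 * (w · x) * x, evaluated row by row
expected = 2.0 * (x_np @ w_np)[:, None] * x_np
print(expected)  # [[4, 0], [0, 6], [10, 10]] -- matches the vmap(grad) output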
3. jacrev and jacfwd — Full Jacobians#
Recall from Example 2: jacrev and jacfwd compute full Jacobian matrices. Here we show them applied to a more interesting function.
[5]:
def neural_layer(x):
"""A simple neural network layer: tanh(xW + b)."""
W = nb.Tensor.from_dlpack(
np.array([[1.0, 0.3, -0.2], [-0.5, 0.8, 0.6]], dtype=np.float32)
)
b = nb.Tensor.from_dlpack(np.array([0.1, -0.1, 0.2], dtype=np.float32))
return nb.tanh(x @ W + b)
x = nb.Tensor.from_dlpack(np.array([1.0, 0.5], dtype=np.float32))
J_rev = nb.jacrev(neural_layer)(x)
J_fwd = nb.jacfwd(neural_layer)(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {neural_layer(x).shape}")
print(f"\nJacobian via jacrev (shape {J_rev.shape}):")
print(J_rev)
print(f"\nJacobian via jacfwd (shape {J_fwd.shape}):")
print(J_fwd)
Input shape: [Dim(2)]
Output shape: [Dim(3)]
Jacobian via jacrev (shape [Dim(3), Dim(2)]):
Tensor(
[[ 0.5224 -0.2612]
[ 0.2135 0.5693]
[-0.183 0.5491]] : f32[3,2]
)
Jacobian via jacfwd (shape [Dim(3), Dim(2)]):
Tensor(
[[ 0.5224 -0.2612]
[ 0.2135 0.5693]
[-0.183 0.5491]] : f32[3,2]
)
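Both modes agree, as they should. As a rule of thumb, jacrev is cheaper for functions with many inputs and few outputs, jacfwd for few inputs and many outputs. For this small layer the Jacobian can also be written down by hand, J[j, i] = (1 - tanh(z_j)^2) * W[i, j] with z = xW + b; a NumPy cross-check:

W_np = np.array([[1.0, 0.3, -0.2], [-0.5, 0.8, 0.6]], dtype=np.float32)
b_np = np.array([0.1, -0.1, 0.2], dtype=np.float32)
x_np = np.array([1.0, 0.5], dtype=np.float32)
y_np = np.tanh(x_np @ W_np + b_np)
J_np = (1.0 - y_np**2)[:, None] * W_np.T  # shape (3, 2)
print(J_np)  # agrees with J_rev and J_fwd above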
4. Composing Jacobians — Hessians#
Since transforms compose, we can compute Hessians by nesting:
[6]:
def energy(x):
"""Energy function: E(x) = 0.5 * x^T A x where A = [[2, 1], [1, 3]]."""
A = nb.Tensor.from_dlpack(
np.array([[2.0, 1.0], [1.0, 3.0]], dtype=np.float32)
)
return 0.5 * nb.reduce_sum(x * (A @ x))
x = nb.Tensor.from_dlpack(np.array([1.0, 2.0], dtype=np.float32))
print(f"E(x) = 0.5 * x^T @ A @ x, where A = [[2,1],[1,3]]")
print(f"E([1,2]) = {energy(x)}")
print(f"Gradient: {nb.grad(energy)(x)}")
print(f" (should be Ax = [4, 7])")
H = nb.jacfwd(nb.grad(energy))(x)
print(f"\nHessian (should be A = [[2,1],[1,3]]):")
print(H)
E(x) = 0.5 * x^T @ A @ x, where A = [[2,1],[1,3]]
E([1,2]) = Tensor(9. : f32[])
Gradient: Tensor([4. 7.] : f32[2])
(should be Ax = [4, 7])
Hessian (should be A = [[2,1],[1,3]]):
Tensor(
[[2. 1.]
[1. 3.]] : f32[2,2]
)
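The nesting order is flexible. Since jacrev of a scalar function is just its gradient, the compositions mentioned in the summary below are, as a sketch, equivalent ways to get the same symmetric matrix here:

H_fwd_rev = nb.jacfwd(nb.jacrev(energy))(x)  # forward-over-reverse
H_rev_fwd = nb.jacrev(nb.jacfwd(energy))(x)  # reverse-over-forward
print(H_fwd_rev)
print(H_rev_fwd)  # both should equal A for this quadratic energy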
5. @nb.compile — Graph Compilation#
@nb.compile traces a function, captures its computation graph, and compiles it into an optimized MAX graph. Subsequent calls with the same tensor shapes/dtypes hit a cache — dramatically speeding up execution.
[7]:
import time
def slow_fn(x, y):
"""A function with many operations."""
for _ in range(5):
x = nb.relu(x @ y + x)
return nb.reduce_sum(x)
@nb.compile
def fast_fn(x, y):
"""Same function, but compiled."""
for _ in range(5):
x = nb.relu(x @ y + x)
return nb.reduce_sum(x)
x = nb.uniform((32, 32))
y = nb.uniform((32, 32))
Benchmarking Eager vs Compiled#
The first call to a compiled function triggers tracing and compilation (warmup). Subsequent calls reuse the cached compiled graph — skipping most of the per-op Python dispatch overhead:
[8]:
# Warmup compiled version (first call traces and compiles)
_ = fast_fn(x, y)
# Benchmark eager
start = time.perf_counter()
for _ in range(20):
_ = slow_fn(x, y)
eager_time = time.perf_counter() - start
# Benchmark compiled
start = time.perf_counter()
for _ in range(20):
_ = fast_fn(x, y)
compiled_time = time.perf_counter() - start
print(f"Eager: {eager_time:.4f}s")
print(f"Compiled: {compiled_time:.4f}s")
print(f"Speedup: {eager_time / max(compiled_time, 1e-9):.1f}x")
Eager: 0.0183s
Compiled: 0.0034s
Speedup: 5.4x
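As noted above, the compile cache is keyed on input shapes and dtypes, so calling the compiled function with a new shape pays the trace-and-compile cost once more on that first call. A sketch (no timings claimed here):

x64 = nb.uniform((64, 64))
y64 = nb.uniform((64, 64))
_ = fast_fn(x64, y64)  # new (shape, dtype) signature -> traced and compiled again
_ = fast_fn(x64, y64)  # hits the cache from the previous line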
6. Compiled Training Loop#
The real power of @nb.compile is compiling entire training steps. When used with value_and_grad and adamw_update, the forward pass, backward pass, and optimizer step are all fused into a single compiled graph.
[9]:
class TinyMLP(nb.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nb.nn.Linear(4, 16)
self.fc2 = nb.nn.Linear(16, 1)
def forward(self, x):
return self.fc2(nb.relu(self.fc1(x)))
Compiling the Full Training Step#
The @nb.compile decorator compiles forward + backward + optimizer update into a single fused graph — the entire training step runs with essentially no per-op Python overhead:
[10]:
def my_loss_fn(model, x, y):
return nb.nn.functional.mse_loss(model(x), y)
@nb.compile
def train_step(model, opt_state, x, y):
"""Compiled training step: forward + backward + optimizer update."""
loss, grads = nb.value_and_grad(my_loss_fn, argnums=0)(model, x, y)
model, opt_state = nb.nn.optim.adamw_update(
model, grads, opt_state, lr=1e-2
)
return model, opt_state, loss
[11]:
# Setup data and model
np.random.seed(0)
X = nb.Tensor.from_dlpack(np.random.randn(100, 4).astype(np.float32))
y = nb.Tensor.from_dlpack(np.random.randn(100, 1).astype(np.float32))
model = TinyMLP()
opt_state = nb.nn.optim.adamw_init(model)
print(f"Compiled training loop:")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)
for step in range(50):
model, opt_state, loss = train_step(model, opt_state, X, y)
if (step + 1) % 10 == 0:
print(f"{step + 1:<8} {loss.item():<12.6f}")
Compiled training loop:
Step Loss
----------------------
10 0.943157
20 0.891157
30 0.838542
40 0.797913
50 0.764449
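Note that train_step returns the updated model and optimizer state (which is why the loop reassigns them each iteration). The returned model can then be used directly for eager evaluation, for example:

preds = model(X)  # forward pass with the trained weights, outside the compiled step
print(nb.nn.functional.mse_loss(preds, y))  # should be close to the last printed loss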
7. Compiled Training with JAX-Style Params#
@nb.compile works equally well with dict-based parameters.
[12]:
from nabla.nn.functional import xavier_normal
def init_params():
return {
"w1": xavier_normal((4, 16)),
"b1": nb.zeros((1, 16)),
"w2": xavier_normal((16, 1)),
"b2": nb.zeros((1, 1)),
}
def forward(params, x):
h = nb.relu(x @ params["w1"] + params["b1"])
return h @ params["w2"] + params["b2"]
def jax_loss_fn(params, x, y):
pred = forward(params, x)
diff = pred - y
return nb.mean(diff * diff)
Compiled Training with Parameter Dicts#
Compilation works the same way with JAX-style parameter dicts — the entire value_and_grad → optimizer update flow gets fused into one graph:
[13]:
@nb.compile
def jax_train_step(params, opt_state, x, y):
loss, grads = nb.value_and_grad(jax_loss_fn, argnums=0)(params, x, y)
params, opt_state = nb.nn.optim.adamw_update(
params, grads, opt_state, lr=1e-2
)
return params, opt_state, loss
params = init_params()
opt_state = nb.nn.optim.adamw_init(params)
print(f"Compiled JAX-style training:")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)
for step in range(50):
params, opt_state, loss = jax_train_step(params, opt_state, X, y)
if (step + 1) % 10 == 0:
print(f"{step + 1:<8} {loss.item():<12.6f}")
Compiled JAX-style training:
Step Loss
----------------------
10 0.943157
20 0.891157
30 0.838542
40 0.797913
50 0.764449
Summary#
| Transform | Usage | Key benefit |
|---|---|---|
| vmap | Auto-batch any function | No manual batching |
| vmap(grad(f)) | Per-example gradients | Efficient |
| jacrev / jacfwd | Full Jacobians | Compose for Hessians |
| @nb.compile | Compile the train step | 5–50x speedup |
All transforms compose freely with each other: compile(vmap(grad(f))), jacfwd(jacrev(f)), etc.
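For example, the per-example gradients from Section 2 can be wrapped into a single compiled function. A sketch, assuming nb.compile can also be applied as a plain function rather than only as a decorator:

fast_per_example_grad = nb.compile(
    nb.vmap(nb.grad(per_sample_loss, argnums=1), in_axes=(0, None))
)
xb = nb.Tensor.from_dlpack(np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=np.float32))
wv = nb.Tensor.from_dlpack(np.array([2.0, 3.0], dtype=np.float32))
print(fast_per_example_grad(xb, wv))  # same values as in Section 2, now compiled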