Example 13: CNN Training#
This notebook trains a small 2D CNN on a synthetic regression task and shows the same training step written twice — eager and compiled — then benchmarks them head-to-head.
| Version | Code | When to use |
|---|---|---|
| Eager (baseline) | plain function | debugging, single shots |
| Compiled | @nb.compile-decorated function | repeated training loops |
CNN architecture:
- Stage 1: conv2d (1→8 ch, 3×3) + ReLU + avg_pool2d 2×2
- Stage 2: conv2d (8→16 ch, 3×3) + ReLU + max_pool2d 2×2
- Head: flatten → matmul + bias + ReLU
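The 256-wide linear head follows from the shape trace below, a quick sanity check in plain Python (no Nabla calls involved):
[ ]:
# conv2d 3×3 with padding 1 and stride 1 preserves spatial size.
h, w, c = 16, 16, 8    # after stage-1 conv: 16×16, 1→8 channels
h, w = h // 2, w // 2  # after avg_pool2d 2×2: 8×8
c = 16                 # after stage-2 conv: 8×8, 8→16 channels
h, w = h // 2, w // 2  # after max_pool2d 2×2: 4×4
print(h * w * c)       # 256, matches the (256, 1) head weight in Section 4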
1. Imports#
[1]:
from __future__ import annotations
import time
import numpy as np
import nabla as nb
print("Nabla CNN — Eager vs. Compiled")
Nabla CNN — Eager vs. Compiled
2. Synthetic Dataset#
Task: predict the mean squared activation of the center 8×8 patch of a 16×16 grayscale image. This gives a cheap, differentiable regression target that depends on a spatial crop.
[2]:
def make_dataset(seed: int = 0, batch_size: int = 64):
rng = np.random.default_rng(seed)
x = rng.normal(size=(batch_size, 16, 16, 1)).astype(np.float32)
center = x[:, 4:12, 4:12, :]
    # Per-sample scalar target with shape (batch_size, 1) so it broadcasts
    # elementwise against the model's (batch_size, 1) output.
    y = np.mean(center ** 2, axis=(1, 2, 3)).astype(np.float32)[:, None]
return (
nb.Tensor.from_dlpack(x),
nb.Tensor.from_dlpack(y),
)
X, Y = make_dataset(seed=0)
print(f"Inputs: {X.shape}")
print(f"Targets: {Y.shape}")
Inputs: [Dim(64), Dim(16), Dim(16), Dim(1)]
Targets: [Dim(64), Dim(1)]
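Because the inputs are standard normal, each target is the mean of 64 squared N(0, 1) draws, so the targets should cluster around 1.0 with a standard deviation near sqrt(2/64) ≈ 0.18. A quick NumPy cross-check, recomputing the target outside Nabla:
[ ]:
# Recompute the targets in plain NumPy and check their statistics.
rng = np.random.default_rng(0)
x_np = rng.normal(size=(64, 16, 16, 1)).astype(np.float32)
y_np = np.mean(x_np[:, 4:12, 4:12, :] ** 2, axis=(1, 2, 3))
print(f"target mean: {y_np.mean():.3f} (expect ~1.0)")
print(f"target std:  {y_np.std():.3f} (expect ~{np.sqrt(2 / 64):.3f})")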
3. CNN Architecture#
The model is a pure function over a flat parameter list: no module classes, no hidden state. This is Nabla's functional API, analogous to JAX's.
[3]:
def cnn(x: nb.Tensor, params: list[nb.Tensor]) -> nb.Tensor:
w1, b1, w2, b2, wh, bh = params
# Stage 1: conv + ReLU + avg pool
y = nb.relu(nb.conv2d(x, w1, bias=b1, stride=(1, 1), padding=(1, 1, 1, 1)))
y = nb.avg_pool2d(y, kernel_size=(2, 2), stride=(2, 2), padding=0)
# Stage 2: conv + ReLU + max pool
y = nb.relu(nb.conv2d(y, w2, bias=b2, stride=(1, 1), padding=(1, 1, 1, 1)))
y = nb.max_pool2d(y, kernel_size=(2, 2), stride=(2, 2), padding=0)
# Head: flatten → linear
y = nb.reshape(y, (int(y.shape[0]), int(y.shape[1] * y.shape[2] * y.shape[3])))
return nb.relu(nb.matmul(y, wh) + bh)
print("cnn(x, params) defined")
cnn(x, params) defined
4. Parameter Initialization#
[4]:
def init_params(seed: int = 1) -> list[nb.Tensor]:
rng = np.random.default_rng(seed)
return [
nb.Tensor.from_dlpack((0.10 * rng.normal(size=(3, 3, 1, 8))).astype(np.float32)),
nb.Tensor.from_dlpack(np.zeros((8,), dtype=np.float32)),
nb.Tensor.from_dlpack((0.08 * rng.normal(size=(3, 3, 8, 16))).astype(np.float32)),
nb.Tensor.from_dlpack(np.zeros((16,), dtype=np.float32)),
nb.Tensor.from_dlpack((0.10 * rng.normal(size=(16 * 4 * 4, 1))).astype(np.float32)),
nb.Tensor.from_dlpack(np.zeros((1,), dtype=np.float32)),
]
demo_params = init_params()
print("Parameter shapes:")
for p in demo_params:
print(f" {p.shape}")
Parameter shapes:
[Dim(3), Dim(3), Dim(1), Dim(8)]
[Dim(8)]
[Dim(3), Dim(3), Dim(8), Dim(16)]
[Dim(16)]
[Dim(256), Dim(1)]
[Dim(1)]
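A one-off forward pass confirms the shapes line up end to end: 64 images in, one prediction per image out. A minimal check using only the functions defined above:
[ ]:
# Untrained forward pass; values are meaningless at init, only shapes matter.
preds = cnn(X, demo_params)
print(preds.shape)  # expected: [Dim(64), Dim(1)]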
5. Loss Function#
Plain MSE loss. Both training variants (eager and compiled) share this function unchanged.
[5]:
def loss_fn(params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor) -> nb.Tensor:
diff = cnn(x, params) - y
return nb.mean(diff * diff)
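Evaluating the loss once at the untrained parameters gives a baseline for the training curves in Section 8 (a quick sketch; the exact value depends on the random init):
[ ]:
# Loss at initialization, before any gradient step.
loss0 = loss_fn(demo_params, X, Y)
nb.realize_all(loss0)
print(f"initial loss: {float(loss0.item()):.6f}")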
6. Eager Training Step (Baseline)#
value_and_grad traces and executes the computation graph on every call. No caching, no compilation overhead — but also no reuse.
Use the eager step for debugging or when you only call it once.
[6]:
def eager_train_step(
params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor, lr: float = 3e-2
):
loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, x, y)
new_params = [p - lr * g for p, g in zip(params, grads)]
return new_params, loss
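One eager step in isolation shows the functional update pattern: parameters go in, fresh parameters come out, and nothing is mutated. A minimal sketch using the helpers defined above:
[ ]:
# Two consecutive eager steps; the loss at step 1 should dip below step 0.
p0 = init_params(seed=1)
p1, l0 = eager_train_step(p0, X, Y, lr=3e-2)
p2, l1 = eager_train_step(p1, X, Y, lr=3e-2)
nb.realize_all(l0, l1)
print(f"loss at step 0: {float(l0.item()):.6f}")
print(f"loss at step 1: {float(l1.item()):.6f}")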
7. Compiled Training Step#
Decorating the exact same logic with @nb.compile changes the execution model:
- First call: Nabla traces the Python function and compiles it to a MAX graph.
- All later calls with the same input shapes/dtypes hit the cache and skip Python dispatch entirely.
The compiled and eager versions produce identical numerical results.
[7]:
@nb.compile
def compiled_train_step(
params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor, lr: float = 3e-2
):
loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, x, y)
new_params = [p - lr * g for p, g in zip(params, grads)]
return new_params, loss
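Cache behavior can be inspected through the wrapper's stats attribute, which the benchmark prints in Section 9. A minimal sketch, assuming nb._clear_caches() also resets those counters:
[ ]:
nb._clear_caches()                         # assumption: resets the hit/miss counters
p = init_params(seed=1)
p, l = compiled_train_step(p, X, Y, 3e-2)  # first call: trace + compile (miss)
p, l = compiled_train_step(p, X, Y, 3e-2)  # same shapes/dtypes: cache hit
print(compiled_train_step.stats)           # expect misses=1, hits=1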
8. Head-to-Head: Eager vs. Compiled#
We train two identical models from the same seed for 60 steps. The first step is a warmup (for the compiled version this triggers trace + compile) and is excluded from the timing measurement.
What to observe:
- Loss curves should be numerically identical.
- Compiled should report a lower average ms/step.
[8]:
def run(step_fn, label: str, steps: int = 60, lr: float = 3e-2, seed: int = 0):
nb._clear_caches()
x, y = make_dataset(seed=seed)
params = init_params(seed=seed + 1)
# Warmup (for compiled: triggers trace + compile)
params, loss_warmup = step_fn(params, x, y, lr)
nb.realize_all(loss_warmup, *params)
print(f"\n{label}")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)
losses = []
t0 = time.perf_counter()
for step in range(steps):
params, loss = step_fn(params, x, y, lr)
nb.realize_all(loss, *params)
loss_value = float(loss.item())
losses.append(loss_value)
if (step + 1) % 10 == 0:
print(f"{step + 1:<8} {loss_value:<12.6f}")
avg_ms = (time.perf_counter() - t0) / steps * 1000.0
print(f"Avg step: {avg_ms:.1f} ms/step")
return {
"avg_ms": avg_ms,
"initial_loss": losses[0],
"final_loss": losses[-1],
"losses": losses,
}
9. Execute Benchmark#
Run both variants with identical data and initialization. This gives an apples-to-apples performance comparison.
[9]:
eager_result = run(eager_train_step, "Eager baseline", steps=60, lr=3e-2, seed=0)
compiled_result = run(compiled_train_step, "Compiled (@nb.compile)", steps=60, lr=3e-2, seed=0)
# Backward compatibility: in a partially re-run kernel, an older run() may still return a bare float
if isinstance(eager_result, float):
eager_result = {"avg_ms": eager_result, "final_loss": float("nan")}
if isinstance(compiled_result, float):
compiled_result = {"avg_ms": compiled_result, "final_loss": float("nan")}
speedup = eager_result["avg_ms"] / max(compiled_result["avg_ms"], 1e-9)
print("\nSummary")
print("-" * 40)
print(f"Eager avg step: {eager_result['avg_ms']:.2f} ms")
print(f"Compiled avg step: {compiled_result['avg_ms']:.2f} ms")
print(f"Speedup: {speedup:.2f}x")
print(
f"Loss check (eager final / compiled final): "
f"{eager_result.get('final_loss', float('nan')):.6f} / "
f"{compiled_result.get('final_loss', float('nan')):.6f}"
)
print(f"Compiled cache stats: {compiled_train_step.stats}")
Eager baseline
Step Loss
----------------------
10 0.029387
20 0.029257
30 0.029142
40 0.029040
50 0.028946
60 0.028860
Avg step: 207.1 ms/step
Compiled (@nb.compile)
Step Loss
----------------------
10 0.029387
20 0.029257
30 0.029142
40 0.029040
50 0.028946
60 0.028860
Avg step: 28.5 ms/step
Summary
----------------------------------------
Eager avg step: 207.10 ms
Compiled avg step: 28.51 ms
Speedup: 7.26x
Loss check (eager final / compiled final): 0.028860 / 0.028860
Compiled cache stats: CompilationStats(hits=60, misses=1, fallbacks=0, hit_rate=98.4%)
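The printed checkpoints agree to six decimal places; to verify the "numerically identical" claim across all 60 steps, the full loss curves can be compared elementwise (np.allclose rather than strict equality, since bitwise agreement can depend on the backend):
[ ]:
# Compare the complete loss curves, not just the printed checkpoints.
same = np.allclose(eager_result["losses"], compiled_result["losses"])
print(f"loss curves match: {same}")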