Example 13: CNN Training#

This notebook trains a small 2D CNN on a synthetic regression task and shows the same training step written twice — eager and compiled — then benchmarks them head-to-head.

| Version | Code | When to use |
| --- | --- | --- |
| Eager (baseline) | plain function | debugging, one-off calls |
| Compiled | @nb.compile | repeated training loops |

CNN architecture:

  • Stage 1: conv2d (1→8 ch, 3×3) + ReLU + avg_pool2d 2×2

  • Stage 2: conv2d (8→16 ch, 3×3) + ReLU + max_pool2d 2×2

  • Head: flatten → matmul + bias + ReLU
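The spatial sizes above follow from the standard output-size formula, out = (in + pad_total − k) // stride + 1. The quick check below (plain Python, no framework required) walks a 16×16 input through both stages and confirms the 256-wide flatten:

```python
def conv_out(size: int, k: int, pad_total: int, stride: int) -> int:
    """Standard output-size formula for conv and pool layers."""
    return (size + pad_total - k) // stride + 1

s = 16                                         # input: 16x16, 1 channel
s = conv_out(s, k=3, pad_total=2, stride=1)    # conv 3x3, pad 1 each side -> 16
s = conv_out(s, k=2, pad_total=0, stride=2)    # avg pool 2x2 -> 8
s = conv_out(s, k=3, pad_total=2, stride=1)    # conv 3x3, pad 1 each side -> 8
s = conv_out(s, k=2, pad_total=0, stride=2)    # max pool 2x2 -> 4

flat = s * s * 16                              # 16 channels after stage 2
print(s, flat)                                 # -> 4 256
```

The final 256 matches the input width of the head's weight matrix in the initializer below.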

1. Imports#

[1]:
from __future__ import annotations

import time

import numpy as np

import nabla as nb

print("Nabla CNN — Eager vs. Compiled")
Nabla CNN — Eager vs. Compiled

2. Synthetic Dataset#

Task: predict the mean squared activation of the center 8×8 patch of a 16×16 grayscale image. This gives a cheap, differentiable regression target that depends on a spatial crop.

[2]:
def make_dataset(seed: int = 0, batch_size: int = 64):
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(batch_size, 16, 16, 1)).astype(np.float32)
    center = x[:, 4:12, 4:12, :]
    # Reduce to shape (batch_size, 1) so targets line up with the model's
    # (batch_size, 1) output and the loss involves no accidental broadcasting.
    y = np.mean(center ** 2, axis=(1, 2, 3)).astype(np.float32).reshape(batch_size, 1)
    return (
        nb.Tensor.from_dlpack(x),
        nb.Tensor.from_dlpack(y),
    )


X, Y = make_dataset(seed=0)
print(f"Inputs:  {X.shape}")
print(f"Targets: {Y.shape}")
Inputs:  [Dim(64), Dim(16), Dim(16), Dim(1)]
Targets: [Dim(64), Dim(1)]
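As a sanity check on the target definition, the same reduction can be reproduced in plain NumPy (no Nabla needed) and cross-checked against an explicit loop:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16, 16, 1)).astype(np.float32)

# Target: mean squared activation of the central 8x8 patch.
center = x[:, 4:12, 4:12, :]
y = np.mean(center ** 2, axis=(1, 2, 3))

# Cross-check sample 0 against an explicit double loop over the patch.
acc = 0.0
for i in range(4, 12):
    for j in range(4, 12):
        acc += float(x[0, i, j, 0]) ** 2
manual = acc / 64.0
assert abs(manual - float(y[0])) < 1e-5
print("target check passed")
```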

3. CNN Architecture#

The model is a pure function over a flat parameter list: no module classes, no hidden state. This is Nabla’s functional API, analogous to JAX’s.

[3]:
def cnn(x: nb.Tensor, params: list[nb.Tensor]) -> nb.Tensor:
    w1, b1, w2, b2, wh, bh = params

    # Stage 1: conv + ReLU + avg pool
    y = nb.relu(nb.conv2d(x, w1, bias=b1, stride=(1, 1), padding=(1, 1, 1, 1)))
    y = nb.avg_pool2d(y, kernel_size=(2, 2), stride=(2, 2), padding=0)

    # Stage 2: conv + ReLU + max pool
    y = nb.relu(nb.conv2d(y, w2, bias=b2, stride=(1, 1), padding=(1, 1, 1, 1)))
    y = nb.max_pool2d(y, kernel_size=(2, 2), stride=(2, 2), padding=0)

    # Head: flatten → linear
    y = nb.reshape(y, (int(y.shape[0]), int(y.shape[1] * y.shape[2] * y.shape[3])))
    return nb.relu(nb.matmul(y, wh) + bh)


print("cnn(x, params) defined")
cnn(x, params) defined

4. Parameter Initialization#

[4]:
def init_params(seed: int = 1) -> list[nb.Tensor]:
    rng = np.random.default_rng(seed)

    return [
        nb.Tensor.from_dlpack((0.10 * rng.normal(size=(3, 3, 1, 8))).astype(np.float32)),
        nb.Tensor.from_dlpack(np.zeros((8,), dtype=np.float32)),
        nb.Tensor.from_dlpack((0.08 * rng.normal(size=(3, 3, 8, 16))).astype(np.float32)),
        nb.Tensor.from_dlpack(np.zeros((16,), dtype=np.float32)),
        nb.Tensor.from_dlpack((0.10 * rng.normal(size=(16 * 4 * 4, 1))).astype(np.float32)),
        nb.Tensor.from_dlpack(np.zeros((1,), dtype=np.float32)),
    ]


demo_params = init_params()
print("Parameter shapes:")
for p in demo_params:
    print(f"  {p.shape}")
Parameter shapes:
  [Dim(3), Dim(3), Dim(1), Dim(8)]
  [Dim(8)]
  [Dim(3), Dim(3), Dim(8), Dim(16)]
  [Dim(16)]
  [Dim(256), Dim(1)]
  [Dim(1)]
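From the shapes printed above, the total parameter count is easy to tally by hand, and a two-line check confirms the arithmetic:

```python
# Shapes taken from the printout above: two conv kernels, two conv
# biases, the head weight matrix, and the head bias.
shapes = [(3, 3, 1, 8), (8,), (3, 3, 8, 16), (16,), (256, 1), (1,)]

total = 0
for shape in shapes:
    n = 1
    for d in shape:
        n *= d        # product of dims = element count of that tensor
    total += n
print(total)           # -> 1505  (72 + 8 + 1152 + 16 + 256 + 1)
```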

5. Loss Function#

Plain MSE loss. Both training variants (eager and compiled) share this function unchanged.

[5]:
def loss_fn(params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor) -> nb.Tensor:
    diff = cnn(x, params) - y
    return nb.mean(diff * diff)
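The loss is ordinary mean squared error, MSE = mean((ŷ − y)²). The same computation in plain NumPy, on toy values, for reference:

```python
import numpy as np

pred = np.array([[0.5], [1.0], [0.0]], dtype=np.float32)
target = np.array([[0.0], [1.0], [1.0]], dtype=np.float32)

diff = pred - target
mse = float(np.mean(diff * diff))   # (0.25 + 0.0 + 1.0) / 3
print(round(mse, 6))                # -> 0.416667
```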

6. Eager Training Step (Baseline)#

value_and_grad traces and executes the computation graph on every call. No caching, no compilation overhead — but also no reuse.

Use the eager step for debugging or when you only call it once.

[6]:
def eager_train_step(
    params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor, lr: float = 3e-2
):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, x, y)
    new_params = [p - lr * g for p, g in zip(params, grads)]
    return new_params, loss
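The update inside the step is plain SGD, p ← p − lr·g. A minimal NumPy sketch of the same pattern (not Nabla code) on a 1-D quadratic with a hand-written gradient:

```python
def loss(p: float) -> float:
    """L(p) = (p - 3)^2, minimized at p = 3."""
    return (p - 3.0) ** 2

def grad(p: float) -> float:
    """dL/dp = 2 (p - 3), written by hand instead of value_and_grad."""
    return 2.0 * (p - 3.0)

p, lr = 0.0, 0.1
for _ in range(100):
    p = p - lr * grad(p)   # same update rule as the training step above
print(round(p, 4))          # -> 3.0
```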

7. Compiled Training Step#

Decorating the exact same logic with @nb.compile:

  • First call — Nabla traces the Python function and compiles it to a MAX graph.

  • All later calls with the same input shapes/dtypes hit the cache and skip Python dispatch entirely.

The compiled and eager versions produce identical numerical results.
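To make the caching behavior concrete, here is a toy shape/dtype-keyed cache in pure Python. This is an illustration of the general idea behind shape-specialized JIT caches, not Nabla's actual implementation (which compiles a MAX graph rather than storing the Python function):

```python
import numpy as np
from functools import wraps

def toy_compile(fn):
    """Illustration only: cache a 'compiled' plan keyed on the input
    shapes and dtypes. Matching calls are cache hits; a new signature
    is a miss that would trigger a fresh trace + compile."""
    cache = {}
    stats = {"hits": 0, "misses": 0}

    @wraps(fn)
    def wrapper(*arrays):
        key = tuple((a.shape, a.dtype.str) for a in arrays)
        if key not in cache:
            stats["misses"] += 1
            cache[key] = fn   # a real system would store a compiled graph
        else:
            stats["hits"] += 1
        return cache[key](*arrays)

    wrapper.stats = stats
    return wrapper

@toy_compile
def f(x, y):
    return x @ y

a = np.ones((2, 3), dtype=np.float32)
b = np.ones((3, 2), dtype=np.float32)
f(a, b); f(a, b); f(a, b)
print(f.stats)   # -> {'hits': 2, 'misses': 1}
```

The first call misses (and would pay the compile cost); the next two hit, which is exactly the pattern the benchmark below exploits by excluding the warmup step.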

[7]:
@nb.compile
def compiled_train_step(
    params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor, lr: float = 3e-2
):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, x, y)
    new_params = [p - lr * g for p, g in zip(params, grads)]
    return new_params, loss

8. Head-to-Head: Eager vs. Compiled#

We train two identical models from the same seed for 60 steps. The first step is a warmup (for the compiled version this triggers trace + compile) and is excluded from the timing measurement.

What to observe:

  • Loss curves should be numerically identical.

  • Compiled should report a lower average ms/step.

[8]:
def run(step_fn, label: str, steps: int = 60, lr: float = 3e-2, seed: int = 0):
    nb._clear_caches()
    x, y = make_dataset(seed=seed)
    params = init_params(seed=seed + 1)

    # Warmup (for compiled: triggers trace + compile)
    params, loss_warmup = step_fn(params, x, y, lr)
    nb.realize_all(loss_warmup, *params)

    print(f"\n{label}")
    print(f"{'Step':<8} {'Loss':<12}")
    print("-" * 22)

    losses = []
    t0 = time.perf_counter()
    for step in range(steps):
        params, loss = step_fn(params, x, y, lr)
        nb.realize_all(loss, *params)
        loss_value = float(loss.item())
        losses.append(loss_value)
        if (step + 1) % 10 == 0:
            print(f"{step + 1:<8} {loss_value:<12.6f}")

    avg_ms = (time.perf_counter() - t0) / steps * 1000.0
    print(f"Avg step: {avg_ms:.1f} ms/step")
    return {
        "avg_ms": avg_ms,
        "initial_loss": losses[0],
        "final_loss": losses[-1],
        "losses": losses,
    }

9. Execute Benchmark#

Run both variants with identical data and initialization. This gives an apples-to-apples performance comparison.

[9]:
eager_result = run(eager_train_step, "Eager baseline", steps=60, lr=3e-2, seed=0)
compiled_result = run(compiled_train_step, "Compiled (@nb.compile)", steps=60, lr=3e-2, seed=0)

speedup = eager_result["avg_ms"] / max(compiled_result["avg_ms"], 1e-9)
print("\nSummary")
print("-" * 40)
print(f"Eager avg step:    {eager_result['avg_ms']:.2f} ms")
print(f"Compiled avg step: {compiled_result['avg_ms']:.2f} ms")
print(f"Speedup:           {speedup:.2f}x")
print(
    f"Loss check (eager final / compiled final): "
    f"{eager_result.get('final_loss', float('nan')):.6f} / "
    f"{compiled_result.get('final_loss', float('nan')):.6f}"
)
print(f"Compiled cache stats: {compiled_train_step.stats}")

Eager baseline
Step     Loss
----------------------
10       0.029387
20       0.029257
30       0.029142
40       0.029040
50       0.028946
60       0.028860
Avg step: 207.1 ms/step

Compiled (@nb.compile)
Step     Loss
----------------------
10       0.029387
20       0.029257
30       0.029142
40       0.029040
50       0.028946
60       0.028860
Avg step: 28.5 ms/step

Summary
----------------------------------------
Eager avg step:    207.10 ms
Compiled avg step: 28.51 ms
Speedup:           7.26x
Loss check (eager final / compiled final): 0.028860 / 0.028860
Compiled cache stats: CompilationStats(hits=60, misses=1, fallbacks=0, hit_rate=98.4%)
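The reported numbers are internally consistent and can be re-derived by hand: 60 hits against 1 miss gives the printed hit rate, and the two average step times give the printed speedup.

```python
hits, misses = 60, 1
hit_rate = hits / (hits + misses) * 100
print(f"{hit_rate:.1f}%")                  # -> 98.4%

eager_ms, compiled_ms = 207.10, 28.51
print(f"{eager_ms / compiled_ms:.2f}x")    # -> 7.26x
```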