Example 9: Compiled vs Eager vs JAX#

This benchmark compares three Nabla training modes on the same MLP, plus an optional JAX baseline:

Mode                 Description
@nb.compile          Fused graph execution (fastest)
Eager (deferred)     Lazy evaluation with realize_all
Eager (MAX graph)    Builds a MAX graph each step
JAX @jit             Google's XLA-backed JIT (if installed)

The task: fit \(f(x) = \frac{\sin(8\pi x) + 1}{2}\) with a 9-layer MLP.

[1]:
import time

import numpy as np

import nabla as nb

# Try to import JAX
try:
    import jax
    import jax.numpy as jnp
    from jax import grad, jit

    HAS_JAX = True
except ImportError:
    HAS_JAX = False

1. Dataset and Parameter Initialization#

[2]:
np.random.seed(42)
n_samples = 500
n_steps = 200

X_np = np.linspace(0, 1, n_samples).reshape(-1, 1).astype(np.float32)
y_np = (np.sin(8 * np.pi * X_np) + 1) / 2.0

X = nb.Tensor.from_dlpack(X_np)
y = nb.Tensor.from_dlpack(y_np)

print(f"Dataset: {n_samples} samples, fitting (sin(8π·x) + 1)/2")
Dataset: 500 samples, fitting (sin(8π·x) + 1)/2

Model Architecture#

A 9-layer MLP with Xavier initialization. Every mode (including the optional JAX baseline) trains the same architecture:
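
For reference, Xavier (Glorot) uniform initialization draws each weight from a symmetric uniform distribution whose half-width depends on the layer's fan-in and fan-out, \(W_{ij} \sim \mathcal{U}(-a, a)\) with \(a = \sqrt{6 / (n_\text{in} + n_\text{out})}\). This is exactly the limit computed in the cell below; biases start at zero.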

[3]:
layers = [1, 16, 32, 64, 64, 64, 64, 32, 16, 1]

# Xavier initialization
params = {}
for i in range(len(layers) - 1):
    in_dim, out_dim = layers[i], layers[i + 1]
    limit = np.sqrt(6.0 / (in_dim + out_dim))
    params[f"layer{i + 1}"] = {
        "w": nb.Tensor.from_dlpack(np.random.uniform(-limit, limit, (in_dim, out_dim)).astype(np.float32)),
        "b": nb.Tensor.from_dlpack(np.zeros((out_dim,), dtype=np.float32)),
    }

total_params = sum(layers[i] * layers[i + 1] + layers[i + 1] for i in range(len(layers) - 1))
print(f"Architecture: {' → '.join(map(str, layers))} ({total_params} params)")
Architecture: 1 → 16 → 32 → 64 → 64 → 64 → 64 → 32 → 16 → 1 (17793 params)

Forward, Loss, and Train Steps#

We define the forward pass, MSE loss, and two training modes:

  • Compiled (@nb.compile): the entire step (forward + backward + update) is fused

  • Eager: deferred evaluation with manual realize_all

[4]:
def mlp_forward(params, x):
    h = x
    for i in range(1, len(layers)):
        h = h @ params[f"layer{i}"]["w"] + params[f"layer{i}"]["b"]
        if i < len(layers) - 1:
            h = nb.relu(h)
    return h


def loss_fn(params, x, y):
    pred = mlp_forward(params, x)
    diff = pred - y
    return nb.mean(diff * diff)


@nb.compile
def train_step_compiled(params, x, y):
    loss, grads = nb.value_and_grad(loss_fn)(params, x, y)
    lr = 0.01
    new_params = {}
    for layer_name in params:
        new_params[layer_name] = {
            "w": params[layer_name]["w"] - grads[layer_name]["w"] * lr,
            "b": params[layer_name]["b"] - grads[layer_name]["b"] * lr,
        }
    return loss, new_params


def train_step_eager(params, x, y):
    loss, grads = nb.value_and_grad(loss_fn, realize=False)(params, x, y)
    lr = 0.01
    new_params = {}
    for layer_name in params:
        new_params[layer_name] = {
            "w": params[layer_name]["w"] - grads[layer_name]["w"] * lr,
            "b": params[layer_name]["b"] - grads[layer_name]["b"] * lr,
        }
    # Batch-realize all outputs
    all_outputs = [loss]
    for lp in new_params.values():
        all_outputs.extend(lp.values())
    nb.realize_all(*all_outputs)
    return loss, new_params
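
The three Nabla benchmark cells below share the same timing loop (one warmup step, then n_steps timed iterations). As a convenience, that loop could be factored into a small helper along the following lines. This is only a sketch assuming the (params, x, y) → (loss, new_params) convention of the step functions above; the cells below keep the loops inline so each run is self-contained:

def benchmark_steps(step_fn, params, x, y, n_steps, label):
    """Time n_steps of step_fn and return (elapsed_seconds, losses, final_params).

    Hypothetical helper: step_fn must follow the (params, x, y) -> (loss, new_params)
    convention used by train_step_compiled / train_step_eager above.
    """
    losses = []
    start = time.perf_counter()
    for i in range(n_steps):
        loss, params = step_fn(params, x, y)
        losses.append(float(loss.to_numpy()))  # forces the loss to be materialized
        if (i + 1) % 50 == 0:
            print(f"  Step {i + 1:3d}: loss = {losses[-1]:.6f}")
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed:.4f}s ({n_steps / elapsed:.1f} steps/sec)")
    return elapsed, losses, params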

2. Nabla Benchmarks (Compiled vs Eager)#

[5]:
params_compiled = params

# Warmup (triggers compilation)
loss, params_compiled = train_step_compiled(params_compiled, X, y)
print(f"Warmup loss: {loss.to_numpy():.6f}")

# Timed run
start = time.perf_counter()
losses_compiled = []
for i in range(n_steps):
    loss, params_compiled = train_step_compiled(params_compiled, X, y)
    losses_compiled.append(float(loss.to_numpy()))
    if (i + 1) % 50 == 0:
        print(f"  Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_compiled = time.perf_counter() - start

print(f"\nCompiled: {elapsed_compiled:.4f}s ({n_steps / elapsed_compiled:.1f} steps/sec)")
print(f"Loss: {losses_compiled[0]:.6f}{losses_compiled[-1]:.6f}")
print(f"Compile stats: {train_step_compiled.stats}")
Warmup loss: 0.364887
  Step  50: loss = 0.127728
  Step 100: loss = 0.126412
  Step 150: loss = 0.125905
  Step 200: loss = 0.125536

Compiled: 0.2810s (711.6 steps/sec)
Loss: 0.337955 → 0.125536
Compile stats: CompilationStats(hits=200, misses=1, fallbacks=0, hit_rate=99.5%)

Eager (Deferred Evaluation)#

[6]:
params_eager = params

loss, params_eager = train_step_eager(params_eager, X, y)
print(f"Warmup loss: {loss.to_numpy():.6f}")

start = time.perf_counter()
losses_eager = []
for i in range(n_steps):
    loss, params_eager = train_step_eager(params_eager, X, y)
    losses_eager.append(float(loss.to_numpy()))
    if (i + 1) % 50 == 0:
        print(f"  Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_eager = time.perf_counter() - start

print(f"\nEager: {elapsed_eager:.4f}s ({n_steps / elapsed_eager:.1f} steps/sec)")
print(f"Loss: {losses_eager[0]:.6f}{losses_eager[-1]:.6f}")
Warmup loss: 0.364887
  Step  50: loss = 0.127728
  Step 100: loss = 0.126412
  Step 150: loss = 0.125905
  Step 200: loss = 0.125536

Eager: 2.4252s (82.5 steps/sec)
Loss: 0.337955 → 0.125536

Eager (MAX Graph Mode)#

Setting EAGER_MAX_GRAPH = True builds a MAX execution graph for every step instead of dispatching ops one at a time. This trades Python-level per-op dispatch for per-step graph construction, which in this benchmark is slower than deferred evaluation, but the mode is useful for debugging:
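
The cell below flips the flag manually and restores it afterwards. If you prefer, the save/set/restore pattern can be wrapped in a small context manager so the flag is restored even if the benchmark raises. This is a hypothetical convenience wrapper around the existing config attribute, not part of nabla.config:

from contextlib import contextmanager

import nabla.config as nabla_config


@contextmanager
def eager_max_graph(enabled=True):
    """Temporarily set EAGER_MAX_GRAPH, restoring the previous value on exit.

    Hypothetical wrapper around the existing config flag; not a nabla API.
    """
    previous = nabla_config.EAGER_MAX_GRAPH
    nabla_config.EAGER_MAX_GRAPH = enabled
    try:
        yield
    finally:
        nabla_config.EAGER_MAX_GRAPH = previous


# Usage sketch:
# with eager_max_graph(True):
#     loss, params_eager_max = train_step_eager(params_eager_max, X, y)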

[7]:
import nabla.config as nabla_config
orig_eager_max = nabla_config.EAGER_MAX_GRAPH
nabla_config.EAGER_MAX_GRAPH = True

params_eager_max = params

loss, params_eager_max = train_step_eager(params_eager_max, X, y)
print(f"Warmup loss: {loss.to_numpy():.6f}")

start = time.perf_counter()
losses_eager_max = []
for i in range(n_steps):
    loss, params_eager_max = train_step_eager(params_eager_max, X, y)
    losses_eager_max.append(float(loss.to_numpy()))
    if (i + 1) % 50 == 0:
        print(f"  Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_eager_max = time.perf_counter() - start

nabla_config.EAGER_MAX_GRAPH = orig_eager_max  # restore
print(f"\nEager MAX: {elapsed_eager_max:.4f}s ({n_steps / elapsed_eager_max:.1f} steps/sec)")
print(f"Loss: {losses_eager_max[0]:.6f}{losses_eager_max[-1]:.6f}")
Warmup loss: 0.364887
  Step  50: loss = 0.127728
  Step 100: loss = 0.126412
  Step 150: loss = 0.125905
  Step 200: loss = 0.125536

Eager MAX: 10.0300s (19.9 steps/sec)
Loss: 0.337955 → 0.125536

3. JAX @jit Comparison (Optional)#

If JAX is installed, we run the same MLP training with @jax.jit for a direct performance comparison. The architecture and hyperparameters are identical:

[8]:
if HAS_JAX:
    # Convert params to flat list for JAX
    jax_params = []
    for layer_name in sorted(params.keys()):
        jax_params.append(jnp.array(params[layer_name]["w"].to_numpy()))
        jax_params.append(jnp.array(params[layer_name]["b"].to_numpy()))
    X_jax, y_jax = jnp.array(X_np), jnp.array(y_np)

    def jax_mlp(params_flat, x):
        h = x
        for i in range(0, len(params_flat) - 2, 2):
            h = h @ params_flat[i] + params_flat[i + 1]
            h = jax.nn.relu(h)
        return h @ params_flat[-2] + params_flat[-1]

    def jax_loss(params_flat, x, y):
        return jnp.mean((jax_mlp(params_flat, x) - y) ** 2)

    @jit
    def jax_train_step(params_flat, x, y):
        loss = jax_loss(params_flat, x, y)
        grads = grad(jax_loss)(params_flat, x, y)
        return loss, [p - g * 0.01 for p, g in zip(params_flat, grads, strict=False)]

    # Warmup
    loss_jax, jax_params = jax_train_step(jax_params, X_jax, y_jax)
    jax.block_until_ready(loss_jax)
    print(f"JAX warmup loss: {float(loss_jax):.6f}")

    start = time.perf_counter()
    losses_jax = []
    for i in range(n_steps):
        loss_jax, jax_params = jax_train_step(jax_params, X_jax, y_jax)
        jax.block_until_ready(loss_jax)
        losses_jax.append(float(loss_jax))
        if (i + 1) % 50 == 0:
            print(f"  Step {i + 1:3d}: loss = {float(loss_jax):.6f}")
    elapsed_jax = time.perf_counter() - start

    print(f"\nJAX JIT: {elapsed_jax:.4f}s ({n_steps / elapsed_jax:.1f} steps/sec)")
    print(f"Loss: {losses_jax[0]:.6f}{losses_jax[-1]:.6f}")
else:
    print("JAX not installed — skipping JAX benchmark")
JAX warmup loss: 0.364887
  Step  50: loss = 0.127728
  Step 100: loss = 0.126412
  Step 150: loss = 0.125905
  Step 200: loss = 0.125536

JAX JIT: 0.0953s (2097.8 steps/sec)
Loss: 0.337955 → 0.125536
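
The step above computes the loss and the gradients with separate calls to jax_loss and grad(jax_loss). The more idiomatic JAX pattern is jax.value_and_grad, which returns both from a single call. An equivalent step, reusing jax_loss and jit from the cell above, would look like this (a sketch, not included in the timings):

if HAS_JAX:
    @jit
    def jax_train_step_vg(params_flat, x, y):
        # Return the loss and its gradients from one call instead of two.
        loss, grads = jax.value_and_grad(jax_loss)(params_flat, x, y)
        return loss, [p - 0.01 * g for p, g in zip(params_flat, grads)]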

4. Results Summary#

[9]:
print("=" * 60)
print("PERFORMANCE SUMMARY")
print("=" * 60)
print(f"Nabla @nb.compile:  {elapsed_compiled:.4f}s  ({n_steps / elapsed_compiled:.1f} steps/sec)")
print(f"Nabla Eager:        {elapsed_eager:.4f}s  ({n_steps / elapsed_eager:.1f} steps/sec)")
print(f"Nabla Eager (MAX):  {elapsed_eager_max:.4f}s  ({n_steps / elapsed_eager_max:.1f} steps/sec)")

if HAS_JAX:
    print(f"JAX @jit:           {elapsed_jax:.4f}s  ({n_steps / elapsed_jax:.1f} steps/sec)")
    speedup_vs_jax = elapsed_jax / elapsed_compiled
    if speedup_vs_jax > 1:
        print(f"\n🚀 Nabla compiled is {speedup_vs_jax:.2f}x faster than JAX JIT")
    else:
        print(f"\nJAX JIT is {1 / speedup_vs_jax:.2f}x faster than Nabla compiled")

speedup = elapsed_eager / elapsed_compiled
print(f"\nCompile speedup over eager: {speedup:.2f}x")

# Verify correctness across modes
loss_diff = abs(losses_compiled[-1] - losses_eager[-1])
print(f"Loss match (compiled vs eager): {'✅' if loss_diff < 1e-4 else '⚠️'} diff={loss_diff:.8f}")
============================================================
PERFORMANCE SUMMARY
============================================================
Nabla @nb.compile:  0.2810s  (711.6 steps/sec)
Nabla Eager:        2.4252s  (82.5 steps/sec)
Nabla Eager (MAX):  10.0300s  (19.9 steps/sec)
JAX @jit:           0.0953s  (2097.8 steps/sec)

JAX JIT is 2.95x faster than Nabla compiled

Compile speedup over eager: 8.63x
Loss match (compiled vs eager): ✅ diff=0.00000000

Key takeaways:

  • @nb.compile fuses the entire train step into a single optimized graph

  • Eager mode is slower due to per-op dispatch overhead

  • EAGER_MAX_GRAPH mode builds a MAX graph each step — useful for debugging

  • All modes converge to numerically identical losses, as the check below confirms
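
A quick way to back up that last claim is to compare the full loss trajectories rather than only the final values (the summary cell above checks just the last compiled vs eager loss). A sketch, reusing numpy (np) and the losses_* lists produced by the benchmark cells above:

# Compare every recorded loss across modes, not just the final value.
trajectories = {
    "compiled": losses_compiled,
    "eager": losses_eager,
    "eager_max": losses_eager_max,
}
if HAS_JAX:
    trajectories["jax"] = losses_jax

reference = np.asarray(trajectories["compiled"])
for name, traj in trajectories.items():
    match = np.allclose(reference, np.asarray(traj), atol=1e-4)
    print(f"{name:10s} matches compiled trajectory: {match}")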

5. Summary#

[10]:
print("=" * 70)
print("SUMMARY")
print("=" * 70)
print("✓ MLP training works with compile!")
print("✓ Full pytree parameters (weights + biases) work correctly")
print(
    f"✓ Loss decreases properly: {losses_compiled[0]:.6f} -> {losses_compiled[-1]:.6f}"
)
print(f"✓ {speedup:.2f}x speedup from compilation")
print(f"✓ Cache hit rate: {train_step_compiled.stats.hit_rate:.1f}%")
if HAS_JAX:
    print("✓ Compared against JAX JIT successfully")
======================================================================
SUMMARY
======================================================================
✓ MLP training works with compile!
✓ Full pytree parameters (weights + biases) work correctly
✓ Loss decreases properly: 0.337955 -> 0.125536
✓ 8.63x speedup from compilation
✓ Cache hit rate: 99.5%
✓ Compared against JAX JIT successfully