Example 9: Compiled vs Eager vs JAX#
This benchmark compares three Nabla training modes on the same MLP, with JAX as an optional baseline:
| Mode | Description |
|---|---|
| Compiled (`@nb.compile`) | Fused graph execution (fastest) |
| Eager (deferred) | Lazy evaluation with `realize_all` |
| Eager (MAX graph) | Builds MAX graph each step |
| JAX | Google's XLA-backed JIT (if installed) |
The task: fit \(f(x) = \frac{\sin(8\pi x) + 1}{2}\) with a 9-layer MLP.
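Concretely, every mode minimizes the mean-squared error over the \(N = 500\) training points (this is exactly what the `loss_fn` defined below computes), where \(\theta\) collects all layer weights and biases:

\[
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\bigl(\mathrm{MLP}_\theta(x_i) - y_i\bigr)^2
\]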
[1]:
import time
import numpy as np
import nabla as nb
# Try to import JAX
try:
    import jax
    import jax.numpy as jnp
    from jax import grad, jit

    HAS_JAX = True
except ImportError:
    HAS_JAX = False
1. Dataset and Parameter Initialization#
[2]:
np.random.seed(42)
n_samples = 500
n_steps = 200
X_np = np.linspace(0, 1, n_samples).reshape(-1, 1).astype(np.float32)
y_np = (np.sin(8 * np.pi * X_np) + 1) / 2.0
X = nb.Tensor.from_dlpack(X_np)
y = nb.Tensor.from_dlpack(y_np)
print(f"Dataset: {n_samples} samples, fitting (sin(8π·x) + 1)/2")
Dataset: 500 samples, fitting (sin(8π·x) + 1)/2
Model Architecture#
A 9-layer MLP with Xavier initialization. Every mode in the benchmark trains the same architecture:
[3]:
layers = [1, 16, 32, 64, 64, 64, 64, 32, 16, 1]
# Xavier initialization
params = {}
for i in range(len(layers) - 1):
    in_dim, out_dim = layers[i], layers[i + 1]
    limit = np.sqrt(6.0 / (in_dim + out_dim))
    params[f"layer{i + 1}"] = {
        "w": nb.Tensor.from_dlpack(np.random.uniform(-limit, limit, (in_dim, out_dim)).astype(np.float32)),
        "b": nb.Tensor.from_dlpack(np.zeros((out_dim,), dtype=np.float32)),
    }
total_params = sum(layers[i] * layers[i + 1] + layers[i + 1] for i in range(len(layers) - 1))
print(f"Architecture: {' → '.join(map(str, layers))} ({total_params} params)")
Architecture: 1 → 16 → 32 → 64 → 64 → 64 → 64 → 32 → 16 → 1 (17793 params)
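As a quick sanity check on the reported count (weights plus biases per linear layer):

\[
(1{\cdot}16{+}16) + (16{\cdot}32{+}32) + (32{\cdot}64{+}64) + 3\,(64{\cdot}64{+}64) + (64{\cdot}32{+}32) + (32{\cdot}16{+}16) + (16{\cdot}1{+}1) = 32 + 544 + 2112 + 12480 + 2080 + 528 + 17 = 17793
\]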
Forward, Loss, and Train Steps#
We define the forward pass, MSE loss, and two training modes:
- **Compiled** (`@nb.compile`): the entire step (forward + backward + update) is fused
- **Eager**: deferred evaluation with manual `realize_all`
[4]:
def mlp_forward(params, x):
    h = x
    for i in range(1, len(layers)):
        h = h @ params[f"layer{i}"]["w"] + params[f"layer{i}"]["b"]
        if i < len(layers) - 1:  # no activation on the output layer
            h = nb.relu(h)
    return h


def loss_fn(params, x, y):
    pred = mlp_forward(params, x)
    diff = pred - y
    return nb.mean(diff * diff)


@nb.compile
def train_step_compiled(params, x, y):
    loss, grads = nb.value_and_grad(loss_fn)(params, x, y)
    lr = 0.01
    # Plain SGD update, returning fresh parameters instead of mutating in place
    new_params = {}
    for layer_name in params:
        new_params[layer_name] = {
            "w": params[layer_name]["w"] - grads[layer_name]["w"] * lr,
            "b": params[layer_name]["b"] - grads[layer_name]["b"] * lr,
        }
    return loss, new_params


def train_step_eager(params, x, y):
    loss, grads = nb.value_and_grad(loss_fn, realize=False)(params, x, y)
    lr = 0.01
    new_params = {}
    for layer_name in params:
        new_params[layer_name] = {
            "w": params[layer_name]["w"] - grads[layer_name]["w"] * lr,
            "b": params[layer_name]["b"] - grads[layer_name]["b"] * lr,
        }
    # Batch-realize all outputs
    all_outputs = [loss]
    for lp in new_params.values():
        all_outputs.extend(lp.values())
    nb.realize_all(*all_outputs)
    return loss, new_params
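As a quick correctness check on the pytree gradients, one autodiff gradient entry can be compared against a finite difference. The sketch below is illustrative only: it assumes the API used above (`nb.value_and_grad`, `Tensor.from_dlpack`, `.to_numpy()`), uses an arbitrary forward-difference step `eps`, and only expects rough agreement in float32.

```python
# Hedged sketch: finite-difference check of one bias gradient (not part of the benchmark).
eps = 1e-3
loss0, grads = nb.value_and_grad(loss_fn)(params, X, y)

b1 = params["layer1"]["b"].to_numpy().copy()
b1[0] += eps
params_pert = {name: dict(layer) for name, layer in params.items()}  # fresh inner dicts
params_pert["layer1"]["b"] = nb.Tensor.from_dlpack(b1)

fd = (float(loss_fn(params_pert, X, y).to_numpy()) - float(loss0.to_numpy())) / eps
ad = float(grads["layer1"]["b"].to_numpy()[0])
print(f"finite-diff ≈ {fd:.5f}, autodiff = {ad:.5f}")
```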
2. Nabla Benchmarks (Compiled vs Eager)#
[5]:
params_compiled = params
# Warmup (triggers compilation)
loss, params_compiled = train_step_compiled(params_compiled, X, y)
print(f"Warmup loss: {loss.to_numpy():.6f}")
# Timed run
start = time.perf_counter()
losses_compiled = []
for i in range(n_steps):
    loss, params_compiled = train_step_compiled(params_compiled, X, y)
    losses_compiled.append(float(loss.to_numpy()))
    if (i + 1) % 50 == 0:
        print(f" Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_compiled = time.perf_counter() - start
print(f"\nCompiled: {elapsed_compiled:.4f}s ({n_steps / elapsed_compiled:.1f} steps/sec)")
print(f"Loss: {losses_compiled[0]:.6f} → {losses_compiled[-1]:.6f}")
print(f"Compile stats: {train_step_compiled.stats}")
Warmup loss: 0.364887
Step 50: loss = 0.127728
Step 100: loss = 0.126412
Step 150: loss = 0.125905
Step 200: loss = 0.125536
Compiled: 0.2810s (711.6 steps/sec)
Loss: 0.337955 → 0.125536
Compile stats: CompilationStats(hits=200, misses=1, fallbacks=0, hit_rate=99.5%)
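The stats line above indicates the step was traced once and then served from cache for all 200 steps. A small guard like the following can flag accidental recompiles (e.g., from changing shapes or dtypes); it assumes `CompilationStats` exposes `misses` and `hit_rate` as attributes, as its repr and the `stats.hit_rate` usage later in this notebook suggest.

```python
# Hedged sketch: flag unexpected recompilation of the compiled train step.
stats = train_step_compiled.stats
if stats.misses > 1:
    print(f"⚠️ recompiled {stats.misses} times, check for varying shapes/dtypes")
else:
    print(f"✅ compiled once, cache hit rate {stats.hit_rate:.1f}%")
```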
Eager (Deferred Evaluation)#
[6]:
params_eager = params
loss, params_eager = train_step_eager(params_eager, X, y)
print(f"Warmup loss: {loss.to_numpy():.6f}")
start = time.perf_counter()
losses_eager = []
for i in range(n_steps):
    loss, params_eager = train_step_eager(params_eager, X, y)
    losses_eager.append(float(loss.to_numpy()))
    if (i + 1) % 50 == 0:
        print(f" Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_eager = time.perf_counter() - start
print(f"\nEager: {elapsed_eager:.4f}s ({n_steps / elapsed_eager:.1f} steps/sec)")
print(f"Loss: {losses_eager[0]:.6f} → {losses_eager[-1]:.6f}")
Warmup loss: 0.364887
Step 50: loss = 0.127728
Step 100: loss = 0.126412
Step 150: loss = 0.125905
Step 200: loss = 0.125536
Eager: 2.4252s (82.5 steps/sec)
Loss: 0.337955 → 0.125536
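The deferred step above batches realization of every output into a single `nb.realize_all` call. For intuition, here is a hypothetical variant (a sketch, not benchmarked here) that realizes each tensor separately via `.to_numpy()`; under deferred evaluation one would expect this to be slower, since every realization triggers its own execution.

```python
# Hypothetical comparison point: realize outputs one at a time instead of batching.
def train_step_eager_unbatched(params, x, y):
    loss, grads = nb.value_and_grad(loss_fn, realize=False)(params, x, y)
    lr = 0.01
    new_params = {}
    for layer_name in params:
        w = params[layer_name]["w"] - grads[layer_name]["w"] * lr
        b = params[layer_name]["b"] - grads[layer_name]["b"] * lr
        w.to_numpy()  # realize immediately, one tensor at a time
        b.to_numpy()
        new_params[layer_name] = {"w": w, "b": b}
    float(loss.to_numpy())
    return loss, new_params
```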
Eager (MAX Graph Mode)#
Setting `EAGER_MAX_GRAPH=True` builds a MAX execution graph for every step. This is typically slower than deferred evaluation, because each step pays the graph-construction cost, but it avoids per-op Python-level dispatch overhead during execution:
[7]:
import nabla.config as nabla_config
orig_eager_max = nabla_config.EAGER_MAX_GRAPH
nabla_config.EAGER_MAX_GRAPH = True
params_eager_max = params
loss, params_eager_max = train_step_eager(params_eager_max, X, y)
print(f"Warmup loss: {loss.to_numpy():.6f}")
start = time.perf_counter()
losses_eager_max = []
for i in range(n_steps):
    loss, params_eager_max = train_step_eager(params_eager_max, X, y)
    losses_eager_max.append(float(loss.to_numpy()))
    if (i + 1) % 50 == 0:
        print(f" Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_eager_max = time.perf_counter() - start
nabla_config.EAGER_MAX_GRAPH = orig_eager_max # restore
print(f"\nEager MAX: {elapsed_eager_max:.4f}s ({n_steps / elapsed_eager_max:.1f} steps/sec)")
print(f"Loss: {losses_eager_max[0]:.6f} → {losses_eager_max[-1]:.6f}")
Warmup loss: 0.364887
Step 50: loss = 0.127728
Step 100: loss = 0.126412
Step 150: loss = 0.125905
Step 200: loss = 0.125536
Eager MAX: 10.0300s (19.9 steps/sec)
Loss: 0.337955 → 0.125536
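Cell [7] saves and restores `nabla.config.EAGER_MAX_GRAPH` by hand; if the timed loop raised, the flag would stay flipped. A small context manager (a sketch, assuming the flag is simply a module-level boolean as used above) makes the toggle exception-safe.

```python
from contextlib import contextmanager

import nabla.config as nabla_config


@contextmanager
def eager_max_graph(enabled=True):
    """Temporarily toggle EAGER_MAX_GRAPH, restoring the previous value on exit."""
    previous = nabla_config.EAGER_MAX_GRAPH
    nabla_config.EAGER_MAX_GRAPH = enabled
    try:
        yield
    finally:
        nabla_config.EAGER_MAX_GRAPH = previous


# Usage: run a few debug steps with per-step MAX graphs, then revert automatically.
with eager_max_graph(True):
    loss, _ = train_step_eager(params, X, y)
```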
3. JAX @jit Comparison (Optional)#
If JAX is installed, we run the same MLP training with @jax.jit for a direct performance comparison. The architecture and hyperparameters are identical:
[8]:
if HAS_JAX:
    # Convert params to flat list for JAX
    jax_params = []
    for layer_name in sorted(params.keys()):
        jax_params.append(jnp.array(params[layer_name]["w"].to_numpy()))
        jax_params.append(jnp.array(params[layer_name]["b"].to_numpy()))
    X_jax, y_jax = jnp.array(X_np), jnp.array(y_np)

    def jax_mlp(params_flat, x):
        h = x
        for i in range(0, len(params_flat) - 2, 2):
            h = h @ params_flat[i] + params_flat[i + 1]
            h = jax.nn.relu(h)
        return h @ params_flat[-2] + params_flat[-1]

    def jax_loss(params_flat, x, y):
        return jnp.mean((jax_mlp(params_flat, x) - y) ** 2)

    @jit
    def jax_train_step(params_flat, x, y):
        loss = jax_loss(params_flat, x, y)
        grads = grad(jax_loss)(params_flat, x, y)
        return loss, [p - g * 0.01 for p, g in zip(params_flat, grads, strict=False)]

    # Warmup
    loss_jax, jax_params = jax_train_step(jax_params, X_jax, y_jax)
    jax.block_until_ready(loss_jax)
    print(f"JAX warmup loss: {float(loss_jax):.6f}")

    start = time.perf_counter()
    losses_jax = []
    for i in range(n_steps):
        loss_jax, jax_params = jax_train_step(jax_params, X_jax, y_jax)
        jax.block_until_ready(loss_jax)
        losses_jax.append(float(loss_jax))
        if (i + 1) % 50 == 0:
            print(f" Step {i + 1:3d}: loss = {float(loss_jax):.6f}")
    elapsed_jax = time.perf_counter() - start

    print(f"\nJAX JIT: {elapsed_jax:.4f}s ({n_steps / elapsed_jax:.1f} steps/sec)")
    print(f"Loss: {losses_jax[0]:.6f} → {losses_jax[-1]:.6f}")
else:
    print("JAX not installed — skipping JAX benchmark")
JAX warmup loss: 0.364887
Step 50: loss = 0.127728
Step 100: loss = 0.126412
Step 150: loss = 0.125905
Step 200: loss = 0.125536
JAX JIT: 0.0953s (2097.8 steps/sec)
Loss: 0.337955 → 0.125536
4. Results Summary#
[9]:
print("=" * 60)
print("PERFORMANCE SUMMARY")
print("=" * 60)
print(f"Nabla @nb.compile: {elapsed_compiled:.4f}s ({n_steps / elapsed_compiled:.1f} steps/sec)")
print(f"Nabla Eager: {elapsed_eager:.4f}s ({n_steps / elapsed_eager:.1f} steps/sec)")
print(f"Nabla Eager (MAX): {elapsed_eager_max:.4f}s ({n_steps / elapsed_eager_max:.1f} steps/sec)")
if HAS_JAX:
    print(f"JAX @jit: {elapsed_jax:.4f}s ({n_steps / elapsed_jax:.1f} steps/sec)")
    speedup_vs_jax = elapsed_jax / elapsed_compiled
    if speedup_vs_jax > 1:
        print(f"\n🚀 Nabla compiled is {speedup_vs_jax:.2f}x faster than JAX JIT")
    else:
        print(f"\nJAX JIT is {1 / speedup_vs_jax:.2f}x faster than Nabla compiled")
speedup = elapsed_eager / elapsed_compiled
print(f"\nCompile speedup over eager: {speedup:.2f}x")
# Verify correctness across modes
loss_diff = abs(losses_compiled[-1] - losses_eager[-1])
print(f"Loss match (compiled vs eager): {'✅' if loss_diff < 1e-4 else '⚠️'} diff={loss_diff:.8f}")
============================================================
PERFORMANCE SUMMARY
============================================================
Nabla @nb.compile: 0.2810s (711.6 steps/sec)
Nabla Eager: 2.4252s (82.5 steps/sec)
Nabla Eager (MAX): 10.0300s (19.9 steps/sec)
JAX @jit: 0.0953s (2097.8 steps/sec)
JAX JIT is 2.95x faster than Nabla compiled
Compile speedup over eager: 8.63x
Loss match (compiled vs eager): ✅ diff=0.00000000
Key takeaways:
- `@nb.compile` fuses the entire train step into a single optimized graph
- Eager mode is slower due to per-op dispatch overhead
- `EAGER_MAX_GRAPH` mode builds a MAX graph each step, useful for debugging
- All three modes produce numerically identical results
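One caveat on the numbers: each timing above is a single pass over 200 steps, so run-to-run noise is not captured. A minimal sketch (standard library only, hypothetical `bench` helper) that reports the median over several repeats:

```python
import statistics
import time


def bench(step_fn, init_params, x, y, n_steps=200, repeats=5):
    """Time several full training runs and return the median wall-clock seconds.

    Assumes step_fn has already been warmed up (compiled/traced) beforehand.
    """
    times = []
    for _ in range(repeats):
        p = init_params
        start = time.perf_counter()
        for _ in range(n_steps):
            loss, p = step_fn(p, x, y)
        float(loss.to_numpy())  # make sure the final step is fully realized
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Example usage: median_s = bench(train_step_compiled, params, X, y)
```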
5. Summary#
[10]:
print("=" * 70)
print("SUMMARY")
print("=" * 70)
print("✓ MLP training works with compile!")
print("✓ Full pytree parameters (weights + biases) work correctly")
print(
    f"✓ Loss decreases properly: {losses_compiled[0]:.6f} -> {losses_compiled[-1]:.6f}"
)
print(f"✓ {speedup:.2f}x speedup from compilation")
print(f"✓ Cache hit rate: {train_step_compiled.stats.hit_rate:.1f}%")
if HAS_JAX:
    print("✓ Compared against JAX JIT successfully")
======================================================================
SUMMARY
======================================================================
✓ MLP training works with compile!
✓ Full pytree parameters (weights + biases) work correctly
✓ Loss decreases properly: 0.337955 -> 0.125536
✓ 8.63x speedup from compilation
✓ Cache hit rate: 99.5%
✓ Compared against JAX JIT successfully