Example 9: Compile vs Eager vs JAX#
This benchmark-style example compares the following training modes:

- Nabla compiled training (@nb.compile)
- Nabla eager training (deferred execution, plus a variant with EAGER_MAX_GRAPH enabled)
- JAX @jit training (when JAX is installed)
[ ]:
import time
import numpy as np
import nabla as nb
# Try to import JAX
try:
import jax
import jax.numpy as jnp
from jax import grad, jit
HAS_JAX = True
except ImportError:
HAS_JAX = False
1. Setup: Dataset, Parameters, and Training Steps#
[ ]:
np.random.seed(42)
n_samples = 500
X_np = np.linspace(0, 1, n_samples).reshape(-1, 1).astype(np.float32)
y_np = (np.sin(8 * np.pi * X_np) + 1) / 2.0
X = nb.Tensor.from_dlpack(X_np)
y = nb.Tensor.from_dlpack(y_np)
print("=" * 70)
print("MLP Training: Fitting Complex Sine Curve")
print("=" * 70)
print(f"Dataset: {n_samples} samples, fitting (sin(8π*x) + 1)/2")
layers = [1, 16, 32, 64, 64, 64, 64, 32, 16, 1]
params = {}
for i in range(len(layers) - 1):
in_dim, out_dim = layers[i], layers[i + 1]
limit = np.sqrt(6.0 / (in_dim + out_dim))
w_np = np.random.uniform(-limit, limit, (in_dim, out_dim)).astype(np.float32)
b_np = np.zeros((out_dim,), dtype=np.float32)
params[f"layer{i + 1}"] = {
"w": nb.Tensor.from_dlpack(w_np),
"b": nb.Tensor.from_dlpack(b_np),
}
total_params = sum(
(layers[i] * layers[i + 1] + layers[i + 1]) for i in range(len(layers) - 1)
)
print(f"Architecture: {' -> '.join(map(str, layers))} ({total_params} params)\n")
def mlp_forward(params, x):
h = x
for i in range(1, len(layers)):
h = h @ params[f"layer{i}"]["w"] + params[f"layer{i}"]["b"]
if i < len(layers) - 1:
h = nb.relu(h)
return h
def loss_fn(params, x, y):
pred = mlp_forward(params, x)
diff = pred - y
return nb.mean(diff * diff)
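# One plain-SGD step (lr = 0.01): value_and_grad returns the loss and a gradient
# pytree with the same structure as params. Wrapping the step in @nb.compile lets
# Nabla cache the traced graph and reuse it on later calls (see the cache stats below).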
@nb.compile
def train_step_compiled(params, x, y):
loss, grads = nb.value_and_grad(loss_fn)(params, x, y)
lr = 0.01
new_params = {}
for layer_name in params:
new_params[layer_name] = {
"w": params[layer_name]["w"] - grads[layer_name]["w"] * lr,
"b": params[layer_name]["b"] - grads[layer_name]["b"] * lr,
}
return loss, new_params
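# Same SGD step without @nb.compile: gradients stay lazy (realize=False) and
# all outputs are materialized together in a single nb.realize_all call below.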
def train_step_eager(params, x, y):
loss, grads = nb.value_and_grad(loss_fn, realize=False)(params, x, y)
lr = 0.01
new_params = {}
for layer_name in params:
new_params[layer_name] = {
"w": params[layer_name]["w"] - grads[layer_name]["w"] * lr,
"b": params[layer_name]["b"] - grads[layer_name]["b"] * lr,
}
# Batch realize all outputs
all_outputs = [loss]
for layer_params in new_params.values():
all_outputs.extend(layer_params.values())
nb.realize_all(*all_outputs)
return loss, new_params
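Before benchmarking, a quick sanity check of the untrained network. This is a small sketch using only the functions defined above; it assumes to_numpy() forces evaluation of the deferred result, as it does in the training loops below.
[ ]:
# One forward pass and loss with the freshly initialized parameters
pred = mlp_forward(params, X)
initial_loss = loss_fn(params, X, y)
print(f"Prediction shape: {pred.to_numpy().shape}")  # expected: (500, 1)
print(f"Initial MSE loss: {float(initial_loss.to_numpy()):.6f}")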
2. Nabla Benchmarks (Compiled vs Eager)#
[ ]:
print("=" * 70)
print("TEST 1: Compiled (@nb.compile)")
print("=" * 70)
params_compiled = params
n_steps = 200
loss, params_compiled = train_step_compiled(params_compiled, X, y)
print(f"Warmup: loss = {loss.to_numpy():.6f}")
start = time.perf_counter()
losses_compiled = []
for i in range(n_steps):
loss, params_compiled = train_step_compiled(params_compiled, X, y)
losses_compiled.append(float(loss.to_numpy()))
if (i + 1) % 50 == 0:
print(f" Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_compiled = time.perf_counter() - start
print(f"\nTime: {elapsed_compiled:.4f}s ({n_steps / elapsed_compiled:.1f} steps/sec)")
print(f"Loss: {losses_compiled[0]:.6f} -> {losses_compiled[-1]:.6f}")
print(f"Compile stats: {train_step_compiled.stats}")
print("\n" + "=" * 70)
print("TEST 2: Eager (no compile)")
print("=" * 70)
params_eager = params
loss, params_eager = train_step_eager(params_eager, X, y)
print(f"Warmup: loss = {loss.to_numpy():.6f}")
start = time.perf_counter()
losses_eager = []
for i in range(n_steps):
loss, params_eager = train_step_eager(params_eager, X, y)
losses_eager.append(float(loss.to_numpy()))
if (i + 1) % 50 == 0:
print(f" Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_eager = time.perf_counter() - start
print(f"\nTime: {elapsed_eager:.4f}s ({n_steps / elapsed_eager:.1f} steps/sec)")
print(f"Loss: {losses_eager[0]:.6f} -> {losses_eager[-1]:.6f}")
print("\n" + "=" * 70)
print("COMPARISON")
print("=" * 70)
speedup = elapsed_eager / elapsed_compiled
print(f"Speedup: {speedup:.2f}x with compile")
print(f" Compiled: {elapsed_compiled:.4f}s")
print(f" Eager: {elapsed_eager:.4f}s")
loss_diff = abs(losses_compiled[-1] - losses_eager[-1])
print(f"\nLoss difference: {loss_diff:.8f}")
if loss_diff < 1e-4:
print("✓ Compiled and eager match!")
else:
print("⚠ Compiled and eager differ!")
# Test 2b: Eager MAX Graph (builds MAX graph eagerly each step)
print()
print("=" * 70)
print("TEST 2b: Eager MAX Graph (EAGER_MAX_GRAPH=1)")
print("=" * 70)
import nabla.config as nabla_config
orig_eager_max = nabla_config.EAGER_MAX_GRAPH
nabla_config.EAGER_MAX_GRAPH = True
params_eager_max = params
loss, params_eager_max = train_step_eager(params_eager_max, X, y)
print(f"Warmup: loss = {loss.to_numpy():.6f}")
start = time.perf_counter()
losses_eager_max = []
for i in range(n_steps):
loss, params_eager_max = train_step_eager(params_eager_max, X, y)
losses_eager_max.append(float(loss.to_numpy()))
if (i + 1) % 50 == 0:
print(f" Step {i + 1:3d}: loss = {loss.to_numpy():.6f}")
elapsed_eager_max = time.perf_counter() - start
nabla_config.EAGER_MAX_GRAPH = orig_eager_max
print(f"\nTime: {elapsed_eager_max:.4f}s ({n_steps / elapsed_eager_max:.1f} steps/sec)")
print(f"Loss: {losses_eager_max[0]:.6f} -> {losses_eager_max[-1]:.6f}")
loss_diff_max = abs(losses_compiled[-1] - losses_eager_max[-1])
print(f"Loss diff vs compiled: {loss_diff_max:.8f}")
if loss_diff_max < 1e-4:
print("✓ Eager MAX Graph and compiled match!")
else:
print("⚠ Eager MAX Graph and compiled differ!")
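For a visual check of convergence, the recorded losses can be plotted. A minimal sketch, assuming matplotlib is installed in the environment:
[ ]:
import matplotlib.pyplot as plt

# Loss curves recorded in the three Nabla runs above
plt.figure(figsize=(6, 4))
plt.plot(losses_compiled, label="compiled (@nb.compile)")
plt.plot(losses_eager, label="eager (deferred)", linestyle="--")
plt.plot(losses_eager_max, label="eager (MAX graph)", linestyle=":")
plt.xlabel("training step")
plt.ylabel("MSE loss")
plt.yscale("log")
plt.legend()
plt.title("Nabla training loss: compiled vs eager")
plt.show()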
3. JAX Comparison (Optional)#
[ ]:
if HAS_JAX:
print()
print("=" * 70)
print("TEST 3: JAX with @jit (for comparison)")
print("=" * 70)
# Convert params to JAX format (flat list for simplicity)
jax_params = []
for layer_name in sorted(params.keys()):
w_np = params[layer_name]["w"].to_numpy()
b_np = params[layer_name]["b"].to_numpy()
jax_params.append(jnp.array(w_np))
jax_params.append(jnp.array(b_np))
X_jax = jnp.array(X_np)
y_jax = jnp.array(y_np)
def jax_mlp(params_flat, x):
h = x
for i in range(0, len(params_flat) - 2, 2):
h = h @ params_flat[i] + params_flat[i + 1]
h = jax.nn.relu(h)
h = h @ params_flat[-2] + params_flat[-1]
return h
def jax_loss(params_flat, x, y):
pred = jax_mlp(params_flat, x)
return jnp.mean((pred - y) ** 2)
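    # JAX counterpart of the Nabla step: same MSE loss and plain SGD (lr = 0.01),
    # compiled once by @jit on the first call and reused afterwards.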
@jit
def jax_train_step(params_flat, x, y):
loss = jax_loss(params_flat, x, y)
grads = grad(jax_loss)(params_flat, x, y)
lr = 0.01
new_params = [p - g * lr for p, g in zip(params_flat, grads, strict=False)]
return loss, new_params
# Warmup
loss_jax, jax_params = jax_train_step(jax_params, X_jax, y_jax)
jax.block_until_ready(loss_jax)
print(f"Warmup (trace): loss = {float(loss_jax):.6f}")
# Timed training
start = time.perf_counter()
losses_jax = []
for i in range(n_steps):
loss_jax, jax_params = jax_train_step(jax_params, X_jax, y_jax)
jax.block_until_ready(loss_jax)
losses_jax.append(float(loss_jax))
if (i + 1) % 50 == 0:
print(f" Step {i + 1:3d}: loss = {float(loss_jax):.6f}")
elapsed_jax = time.perf_counter() - start
print("\nJAX JIT version:")
print(f" Time: {elapsed_jax:.4f}s ({n_steps / elapsed_jax:.1f} steps/sec)")
print(f" Final loss: {losses_jax[-1]:.6f}")
print(
f" Loss reduction: {losses_jax[0]:.6f} -> {losses_jax[-1]:.6f} ({(1 - losses_jax[-1] / losses_jax[0]) * 100:.1f}% reduction)"
)
print()
print("=" * 70)
print("FINAL COMPARISON")
print("=" * 70)
print(
f"Nabla Compiled: {elapsed_compiled:.4f}s ({n_steps / elapsed_compiled:.1f} steps/sec)"
)
print(
f"Nabla Eager (deferred):{elapsed_eager:.4f}s ({n_steps / elapsed_eager:.1f} steps/sec)"
)
print(
f"Nabla Eager (MAX): {elapsed_eager_max:.4f}s ({n_steps / elapsed_eager_max:.1f} steps/sec)"
)
if HAS_JAX:
print(
f"JAX JIT: {elapsed_jax:.4f}s ({n_steps / elapsed_jax:.1f} steps/sec)"
)
print()
speedup_vs_jax = elapsed_jax / elapsed_compiled
if speedup_vs_jax > 1:
print(f"🚀 Nabla is {speedup_vs_jax:.2f}x FASTER than JAX!")
else:
print(f"JAX is {1 / speedup_vs_jax:.2f}x faster than Nabla")
print()
print(f"Nabla speedup over eager (deferred): {speedup:.2f}x")
print(
f"Nabla speedup over eager (MAX graph): {elapsed_eager_max / elapsed_compiled:.2f}x"
)
print()
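To verify the networks actually fit the sine curve rather than just lowering the loss, the predictions can be plotted against the target. A minimal sketch, again assuming matplotlib is available and that to_numpy() evaluates the deferred prediction:
[ ]:
import matplotlib.pyplot as plt

# Predictions of the compile-trained Nabla model on the full input range
pred_nb = mlp_forward(params_compiled, X).to_numpy()

plt.figure(figsize=(6, 4))
plt.plot(X_np, y_np, label="target: (sin(8πx) + 1) / 2", color="black", linewidth=1)
plt.plot(X_np, pred_nb, label="Nabla (compiled)", linestyle="--")
if HAS_JAX:
    # Predictions of the JAX-trained flat parameter list
    pred_jax = np.asarray(jax_mlp(jax_params, X_jax))
    plt.plot(X_np, pred_jax, label="JAX (@jit)", linestyle=":")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.title("Fitted curve after training")
plt.show()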
4. Summary#
[ ]:
print("=" * 70)
print("SUMMARY")
print("=" * 70)
print("✓ MLP training works with compile!")
print("✓ Full pytree parameters (weights + biases) work correctly")
print(
f"✓ Loss decreases properly: {losses_compiled[0]:.6f} -> {losses_compiled[-1]:.6f}"
)
print(f"✓ {speedup:.2f}x speedup from compilation")
print(f"✓ Cache hit rate: {train_step_compiled.stats.hit_rate:.1f}%")
if HAS_JAX:
print("✓ Compared against JAX JIT successfully")