Example 3b: MLP Training (JAX-Style / Functional)#
In this style, the model is a pure function that takes parameters explicitly. Parameters are stored in nested dicts (pytrees). This is the same approach used by JAX and Flax.
This example trains the same 2-layer MLP from Example 3a, but in a purely functional style.
[ ]:
import numpy as np
import nabla as nb
from nabla.nn.functional import xavier_normal
print("Nabla MLP Training — JAX-style (functional)")
1. Initialize Parameters#
Instead of a class, we create a nested dict of parameter tensors. Each tensor gets requires_grad=True so autodiff can track gradients through it.
[ ]:
def init_mlp_params(in_dim: int, hidden_dim: int, out_dim: int) -> dict:
"""Initialize MLP parameters as a pytree (nested dict)."""
params = {
"fc1": {
"weight": xavier_normal((in_dim, hidden_dim)),
"bias": nb.zeros((1, hidden_dim)),
},
"fc2": {
"weight": xavier_normal((hidden_dim, out_dim)),
"bias": nb.zeros((1, out_dim)),
},
}
# Mark all params as differentiable
for layer in params.values():
for p in layer.values():
p.requires_grad = True
return params
params = init_mlp_params(4, 32, 1)
print("Parameter shapes:")
for name, layer in params.items():
for pname, p in layer.items():
print(f" {name}.{pname}: {p.shape}")
2. Define the Forward Pass#
The model is a pure function: it takes parameters and input, returns output. No side effects, no mutation.
[ ]:
def mlp_forward(params: dict, x):
"""Pure-function MLP forward pass."""
x = x @ params["fc1"]["weight"] + params["fc1"]["bias"]
x = nb.relu(x)
x = x @ params["fc2"]["weight"] + params["fc2"]["bias"]
return x
# Quick test
x_test = nb.uniform((3, 4))
y_test = mlp_forward(params, x_test)
print(f"Forward pass test: input {x_test.shape} → output {y_test.shape}")
3. Create Data & Define Loss#
Same synthetic dataset as Example 3a: y = sin(x0) + cos(x1) + 0.5*x2 - x3.
[ ]:
np.random.seed(42)
n_samples = 200
X_np = np.random.randn(n_samples, 4).astype(np.float32)
y_np = (
np.sin(X_np[:, 0])
+ np.cos(X_np[:, 1])
+ 0.5 * X_np[:, 2]
- X_np[:, 3]
).reshape(-1, 1).astype(np.float32)
X = nb.Tensor.from_dlpack(X_np)
y = nb.Tensor.from_dlpack(y_np)
print(f"Dataset: X {X.shape}, y {y.shape}")
def loss_fn(params, X, y):
"""MSE loss as a pure function of params."""
predictions = mlp_forward(params, X)
diff = predictions - y
return nb.mean(diff * diff)
initial_loss = loss_fn(params, X, y)
print(f"Initial loss: {initial_loss}")
4. Training Loop#
The key insight: value_and_grad(loss_fn, argnums=0) differentiates w.r.t. the first argument (params), which is a dict. It returns gradients with the same pytree structure as params.
[ ]:
opt_state = nb.nn.optim.adamw_init(params)
lr = 1e-2
num_epochs = 100
print(f"\n{'Epoch':<8} {'Loss':<12}")
print("-" * 22)
for epoch in range(num_epochs):
loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, X, y)
# grads has the same structure as params:
# grads["fc1"]["weight"], grads["fc1"]["bias"], etc.
params, opt_state = nb.nn.optim.adamw_update(
params, grads, opt_state, lr=lr
)
if (epoch + 1) % 10 == 0:
print(f"{epoch + 1:<8} {loss.item():<12.6f}")
5. Evaluation#
[ ]:
final_loss = loss_fn(params, X, y)
print(f"\nFinal loss: {final_loss}")
predictions = mlp_forward(params, X)
print(f"\nSample predictions vs targets:")
print(f"{'Prediction':<14} {'Target':<14}")
print("-" * 28)
for i in range(5):
pred_i = nb.gather(predictions, nb.constant(np.array([i], dtype=np.int64)), axis=0)
true_i = nb.gather(y, nb.constant(np.array([i], dtype=np.int64)), axis=0)
print(f"{pred_i.item():<14.4f} {true_i.item():<14.4f}")
6. Manual SGD (No Optimizer)#
The functional style makes it trivial to implement gradient descent manually using tree_map:
[ ]:
params_sgd = init_mlp_params(4, 32, 1)
sgd_lr = 0.05
print(f"\nManual SGD training:")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)
for step in range(100):
loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params_sgd, X, y)
# Manual SGD: params = params - lr * grads
params_sgd = nb.tree_map(
lambda p, g: p - sgd_lr * g, params_sgd, grads
)
if (step + 1) % 20 == 0:
print(f"{step + 1:<8} {loss.item():<12.6f}")
PyTorch-Style vs JAX-Style: Comparison#
| Aspect | PyTorch-style (03a) | JAX-style (03b) |
|---|---|---|
| Model | Class that owns its layers | Pure function over a params dict |
| Params | Auto-tracked by Module | Explicit dict (pytree) |
| State | Mutable attributes | Immutable, returned from functions |
| Optimizer | Can be stateful or functional | Typically functional |
| Transforms (e.g. value_and_grad) | Works with both | Works with both |
Both styles are fully supported in Nabla. Choose the one that fits your mental model!
Next: 04_transforms_and_compile — Advanced transforms (vmap, jacrev, jacfwd) and @nb.compile.
[ ]:
print("\n✅ Example 03b completed!")