Example 3b: MLP Training (JAX-Style / Functional)#
In this style, the model is a pure function that takes parameters explicitly. Parameters are stored in nested dicts (pytrees). This is the same approach used by JAX and Flax.
This example trains the same 2-layer MLP from Example 3a, but purely functionally.
[1]:
import numpy as np
import nabla as nb
from nabla.nn.functional import xavier_normal
print("Nabla MLP Training — JAX-style (functional)")
Nabla MLP Training — JAX-style (functional)
1. Initialize Parameters#
Instead of a class, we create a nested dict of parameter tensors.
[2]:
def init_mlp_params(in_dim: int, hidden_dim: int, out_dim: int) -> dict:
    """Initialize MLP parameters as a pytree (nested dict)."""
    params = {
        "fc1": {
            "weight": xavier_normal((in_dim, hidden_dim)),
            "bias": nb.zeros((1, hidden_dim)),
        },
        "fc2": {
            "weight": xavier_normal((hidden_dim, out_dim)),
            "bias": nb.zeros((1, out_dim)),
        },
    }
    return params
params = init_mlp_params(4, 32, 1)
print("Parameter shapes:")
for name, layer in params.items():
    for pname, p in layer.items():
        print(f" {name}.{pname}: {p.shape}")
Parameter shapes:
fc1.weight: [Dim(4), Dim(32)]
fc1.bias: [Dim(1), Dim(32)]
fc2.weight: [Dim(32), Dim(1)]
fc2.bias: [Dim(1), Dim(1)]
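Because `params` is just a nested dict, generic pytree utilities work on it directly. The short sketch below is not part of the original notebook cells; it reuses `nb.tree_map` (introduced in section 6) in its two-tree form, and checks the hand-computed parameter count for this 4 → 32 → 1 architecture.

```python
# Sketch (not an original cell): params is a plain pytree, so generic
# utilities apply directly. Uses nb.tree_map in the same two-tree form
# that appears in section 6 below.
zeroed = nb.tree_map(lambda p, q: p - q, params, params)
print(zeroed["fc1"]["weight"].shape)  # same nesting and shapes as params

# Hand-computed parameter count for the 4 -> 32 -> 1 MLP:
# fc1 has 4*32 weights + 32 biases, fc2 has 32*1 weights + 1 bias.
n_params = (4 * 32 + 32) + (32 * 1 + 1)
print(f"Total trainable parameters: {n_params}")  # 193
```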
2. Define the Forward Pass#
The model is a pure function: it takes the parameters and an input and returns the output. No side effects, no mutation.
[3]:
def mlp_forward(params: dict, x):
    """Pure-function MLP forward pass."""
    x = x @ params["fc1"]["weight"] + params["fc1"]["bias"]
    x = nb.relu(x)
    x = x @ params["fc2"]["weight"] + params["fc2"]["bias"]
    return x
# Quick test
x_test = nb.uniform((3, 4))
y_test = mlp_forward(params, x_test)
print(f"Forward pass test: input {x_test.shape} → output {y_test.shape}")
Forward pass test: input [Dim(3), Dim(4)] → output [Dim(3), Dim(1)]
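A pure forward pass has a property worth checking explicitly: the same parameters and the same input always produce the same output. The snippet below is a sketch rather than an original cell and uses only functions already defined above.

```python
# Purity check: identical params + identical input -> identical output.
y_a = mlp_forward(params, x_test)
y_b = mlp_forward(params, x_test)
diff = y_a - y_b
print(nb.mean(diff * diff))  # expected to be exactly 0.0
```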
3. Create Data & Define Loss#
Same synthetic dataset as Example 3a: y = sin(x0) + cos(x1) + 0.5*x2 - x3.
[4]:
np.random.seed(42)
n_samples = 200
X_np = np.random.randn(n_samples, 4).astype(np.float32)
y_np = (
    np.sin(X_np[:, 0])
    + np.cos(X_np[:, 1])
    + 0.5 * X_np[:, 2]
    - X_np[:, 3]
).reshape(-1, 1).astype(np.float32)
X = nb.Tensor.from_dlpack(X_np)
y = nb.Tensor.from_dlpack(y_np)
print(f"Dataset: X {X.shape}, y {y.shape}")
def loss_fn(params, X, y):
    """MSE loss as a pure function of params."""
    predictions = mlp_forward(params, X)
    diff = predictions - y
    return nb.mean(diff * diff)
initial_loss = loss_fn(params, X, y)
print(f"Initial loss: {initial_loss}")
Dataset: X [Dim(200), Dim(4)], y [Dim(200), Dim(1)]
Initial loss: Tensor(2.5328 : f32[])
4. Training Loop#
The key insight: value_and_grad(loss_fn, argnums=0) differentiates w.r.t. the first argument (params), which is a dict. It returns gradients with the same pytree structure as params.
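Before running the full loop, it can be reassuring to see that structure directly. This sketch (not an original cell) computes one gradient and walks it with the same nested loop used for the parameter shapes in section 1.

```python
# Sketch: the gradient pytree mirrors the parameter pytree exactly.
loss0, grads0 = nb.value_and_grad(loss_fn, argnums=0)(params, X, y)
for name, layer in grads0.items():
    for pname, g in layer.items():
        # each g.shape matches params[name][pname].shape
        print(f" grad {name}.{pname}: {g.shape}")
```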
[5]:
opt_state = nb.nn.optim.adamw_init(params)
lr = 1e-2
num_epochs = 100
print(f"\n{'Epoch':<8} {'Loss':<12}")
print("-" * 22)
for epoch in range(num_epochs):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, X, y)
    # grads has the same structure as params:
    # grads["fc1"]["weight"], grads["fc1"]["bias"], etc.
    params, opt_state = nb.nn.optim.adamw_update(
        params, grads, opt_state, lr=lr
    )
    if (epoch + 1) % 10 == 0:
        print(f"{epoch + 1:<8} {loss.item():<12.6f}")
Epoch Loss
----------------------
10 0.625402
20 0.242037
30 0.159796
40 0.118573
50 0.082741
60 0.064866
70 0.054502
80 0.047352
90 0.041529
100 0.037468
5. Evaluation#
[6]:
final_loss = loss_fn(params, X, y)
print(f"\nFinal loss: {final_loss}")
predictions = mlp_forward(params, X)
print(f"\nSample predictions vs targets:")
print(f"{'Prediction':<14} {'Target':<14}")
print("-" * 28)
for i in range(5):
    pred_i = nb.gather(predictions, nb.constant(np.array([i], dtype=np.int64)), axis=0)
    true_i = nb.gather(y, nb.constant(np.array([i], dtype=np.int64)), axis=0)
    print(f"{pred_i.item():<14.4f} {true_i.item():<14.4f}")
Final loss: Tensor(0.0371 : f32[])
Sample predictions vs targets:
Prediction Target
----------------------------
0.1029 0.2678
0.7215 0.7629
0.7124 0.6380
-0.3226 -0.3964
1.2237 1.0610
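Because `loss_fn` is a pure function of `(params, X, y)`, evaluating on held-out data is just another call. The sketch below is not an original cell; it draws fresh samples from the same synthetic target used above, and the held-out loss it prints is not a recorded result.

```python
# Sketch: rough held-out evaluation on freshly drawn data from the
# same synthetic function y = sin(x0) + cos(x1) + 0.5*x2 - x3.
X_new_np = np.random.randn(50, 4).astype(np.float32)
y_new_np = (
    np.sin(X_new_np[:, 0])
    + np.cos(X_new_np[:, 1])
    + 0.5 * X_new_np[:, 2]
    - X_new_np[:, 3]
).reshape(-1, 1).astype(np.float32)

held_out_loss = loss_fn(
    params,
    nb.Tensor.from_dlpack(X_new_np),
    nb.Tensor.from_dlpack(y_new_np),
)
print(f"Held-out loss: {held_out_loss}")
```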
6. Manual SGD (No Optimizer)#
The functional style makes it trivial to implement gradient descent manually using tree_map:
[7]:
params_sgd = init_mlp_params(4, 32, 1)
sgd_lr = 0.05
print(f"\nManual SGD training:")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)
for step in range(100):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params_sgd, X, y)
    # Manual SGD: params = params - lr * grads
    params_sgd = nb.tree_map(
        lambda p, g: p - sgd_lr * g, params_sgd, grads
    )
    if (step + 1) % 20 == 0:
        print(f"{step + 1:<8} {loss.item():<12.6f}")
Manual SGD training:
Step Loss
----------------------
20 0.204635
40 0.145309
60 0.117855
80 0.095323
100 0.077824
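The same `tree_map` pattern extends to richer update rules. As a sketch (not an original cell), SGD with momentum keeps one extra pytree of velocities and updates it with the same two-tree `nb.tree_map` form used above; the momentum coefficient and learning rate here are illustrative choices.

```python
# Sketch: SGD with momentum, built from two two-tree tree_map calls.
params_mom = init_mlp_params(4, 32, 1)
# A zero pytree with the same structure as the parameters.
velocity = nb.tree_map(lambda p, q: p - q, params_mom, params_mom)
mu, mom_lr = 0.9, 0.05  # illustrative hyperparameters

for step in range(100):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params_mom, X, y)
    # v <- mu * v - lr * g, then p <- p + v
    velocity = nb.tree_map(lambda v, g: mu * v - mom_lr * g, velocity, grads)
    params_mom = nb.tree_map(lambda p, v: p + v, params_mom, velocity)
```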
PyTorch-Style vs JAX-Style: Comparison#
| Aspect | PyTorch-style (03a) | JAX-style (03b) |
|---|---|---|
| Model | Stateful `Module` class | Pure function taking params explicitly |
| Params | Auto-tracked by Module | Explicit dict (pytree) |
| State | Mutable attributes | Immutable, returned from functions |
| Optimizer | Can be stateful or functional | Typically functional |
| Transformations (`value_and_grad`, etc.) | Works with both | Works with both |
Both styles are fully supported in Nabla. Choose the one that fits your mental model!