Example 4b: MLP Training (JAX-Style / Functional)

In this style, the model is a pure function that receives its parameters explicitly as an argument. The parameters live in nested dicts (pytrees), the same convention used by JAX and Flax.

This example trains the same 2-layer MLP as Example 4a, but in a purely functional style.

[1]:
import time
import numpy as np

import nabla as nb
from nabla.nn.functional import xavier_normal

print("Nabla MLP Training — JAX-style (functional)")
Nabla MLP Training — JAX-style (functional)

1. Initialize Parameters

Instead of a class, we create a nested dict of parameter tensors.
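The weights in init_mlp_params come from xavier_normal, whose scale this example does not spell out. Assuming it follows the standard Glorot/Xavier normal scheme, each entry is drawn from N(0, std²) with std = sqrt(2 / (fan_in + fan_out)), where a weight used as x @ W has shape (fan_in, fan_out). The pure-Python sketch below computes that scale for both layers; xavier_normal_std and xavier_normal_list are hypothetical helpers, not Nabla's implementation.

```python
import math
import random


def xavier_normal_std(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Glorot/Xavier normal scale: gain * sqrt(2 / (fan_in + fan_out))."""
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


def xavier_normal_list(shape: tuple, seed: int = 0) -> list:
    """Hypothetical stand-in: nested list of N(0, std^2) samples for a (fan_in, fan_out) shape."""
    fan_in, fan_out = shape
    rng = random.Random(seed)
    std = xavier_normal_std(fan_in, fan_out)
    return [[rng.gauss(0.0, std) for _ in range(fan_out)] for _ in range(fan_in)]


std_fc1 = xavier_normal_std(4, 32)  # sqrt(2 / 36)
std_fc2 = xavier_normal_std(32, 1)  # sqrt(2 / 33)
w_fc1 = xavier_normal_list((4, 32))
print(f"fc1 std = {std_fc1:.4f}, fc2 std = {std_fc2:.4f}")
print(f"fc1 weight: {len(w_fc1)} x {len(w_fc1[0])}")
```

Because the formula is symmetric in fan_in and fan_out, the scale does not depend on whether a framework stores weights as (in, out) or (out, in).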

[2]:
def init_mlp_params(in_dim: int, hidden_dim: int, out_dim: int) -> dict:
    """Initialize MLP parameters as a pytree (nested dict)."""
    params = {
        "fc1": {
            "weight": xavier_normal((in_dim, hidden_dim)),
            "bias": nb.zeros((1, hidden_dim)),
        },
        "fc2": {
            "weight": xavier_normal((hidden_dim, out_dim)),
            "bias": nb.zeros((1, out_dim)),
        },
    }
    return params


params = init_mlp_params(4, 32, 1)
print("Parameter shapes:")
for name, layer in params.items():
    for pname, p in layer.items():
        print(f"  {name}.{pname}: {p.shape}")
Parameter shapes:
  fc1.weight: [Dim(4), Dim(32)]
  fc1.bias: [Dim(1), Dim(32)]
  fc2.weight: [Dim(32), Dim(1)]
  fc2.bias: [Dim(1), Dim(1)]
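The printed shapes also fix the model size: 4*32 + 32 entries in fc1 plus 32*1 + 1 in fc2, 193 parameters in total. The sketch below flattens a nested dict into (path, leaf) pairs, the basic pytree operation behind loops like the one above. Shape tuples stand in for tensors, and tree_leaves_with_paths is a hypothetical helper that only handles dicts.

```python
import math


def tree_leaves_with_paths(tree: dict, prefix: str = "") -> list:
    """Flatten a nested dict into (dotted_path, leaf) pairs in insertion order."""
    items = []
    for key, value in tree.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            items.extend(tree_leaves_with_paths(value, path))  # recurse into sub-dicts
        else:
            items.append((path, value))
    return items


# Same structure as params, with shape tuples standing in for tensors.
param_shapes = {
    "fc1": {"weight": (4, 32), "bias": (1, 32)},
    "fc2": {"weight": (32, 1), "bias": (1, 1)},
}

counts = {path: math.prod(shape) for path, shape in tree_leaves_with_paths(param_shapes)}
total = sum(counts.values())
for path, n in counts.items():
    print(f"{path}: {n}")
print(f"total: {total}")
```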

2. Define the Forward Pass

The model is a pure function: it takes the parameters and an input and returns the output, with no side effects and no mutation.

[3]:
def mlp_forward(params: dict, x):
    """Pure-function MLP forward pass."""
    x = x @ params["fc1"]["weight"] + params["fc1"]["bias"]
    x = nb.relu(x)
    x = x @ params["fc2"]["weight"] + params["fc2"]["bias"]
    return x


# Quick test
x_test = nb.uniform((3, 4))
y_test = mlp_forward(params, x_test)
print(f"Forward pass test: input {x_test.shape} → output {y_test.shape}")
Forward pass test: input [Dim(3), Dim(4)] → output [Dim(3), Dim(1)]
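The biases are stored with shape (1, hidden_dim) rather than (hidden_dim,), so x @ W + b relies on broadcasting the single bias row across every sample in the batch. The sketch below traces the forward-pass shapes in pure Python, assuming NumPy-style broadcasting for equal-rank shapes (consistent with the (3, 4) to (3, 1) result above). The helper names are hypothetical.

```python
def matmul_shape(a: tuple, b: tuple) -> tuple:
    """Shape of the 2-D matrix product a @ b."""
    if a[1] != b[0]:
        raise ValueError(f"inner dimensions differ: {a} @ {b}")
    return (a[0], b[1])


def broadcast_shape(a: tuple, b: tuple) -> tuple:
    """NumPy-style broadcast of two equal-rank shapes: each dim must match or be 1."""
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(db if da == 1 else da)
    return tuple(out)


def mlp_forward_shape(params: dict, x: tuple) -> tuple:
    """Mirror mlp_forward on shapes only; relu is elementwise, so it keeps the shape."""
    h = broadcast_shape(matmul_shape(x, params["fc1"]["weight"]), params["fc1"]["bias"])
    return broadcast_shape(matmul_shape(h, params["fc2"]["weight"]), params["fc2"]["bias"])


param_shapes = {
    "fc1": {"weight": (4, 32), "bias": (1, 32)},
    "fc2": {"weight": (32, 1), "bias": (1, 1)},
}
print(mlp_forward_shape(param_shapes, (3, 4)))    # (3, 1)
print(mlp_forward_shape(param_shapes, (200, 4)))  # (200, 1)
```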

3. Create Data & Define Loss

We use the same synthetic dataset as Example 4a: 200 samples with 4 standard-normal features and the target y = sin(x0) + cos(x1) + 0.5*x2 - x3.
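As a quick check of the target formula, evaluating it at a couple of convenient points gives exact values. The target function below is a hypothetical scalar version of the vectorized NumPy expression used to build y_np.

```python
import math


def target(x0: float, x1: float, x2: float, x3: float) -> float:
    """Scalar version of the synthetic target: sin(x0) + cos(x1) + 0.5*x2 - x3."""
    return math.sin(x0) + math.cos(x1) + 0.5 * x2 - x3


# sin(0) + cos(0) + 0.5*2 - 1 = 0 + 1 + 1 - 1 = 1
y_a = target(0.0, 0.0, 2.0, 1.0)
# sin(pi/2) + cos(pi) + 0.5*0 - 0 = 1 - 1 = 0
y_b = target(math.pi / 2, math.pi, 0.0, 0.0)
print(y_a, y_b)
```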

[4]:
np.random.seed(42)
n_samples = 200
X_np = np.random.randn(n_samples, 4).astype(np.float32)
y_np = (
    np.sin(X_np[:, 0])
    + np.cos(X_np[:, 1])
    + 0.5 * X_np[:, 2]
    - X_np[:, 3]
).reshape(-1, 1).astype(np.float32)

X = nb.Tensor.from_dlpack(X_np)
y = nb.Tensor.from_dlpack(y_np)
print(f"Dataset: X {X.shape}, y {y.shape}")


def loss_fn(params, X, y):
    """MSE loss as a pure function of params."""
    predictions = mlp_forward(params, X)
    diff = predictions - y
    return nb.mean(diff * diff)

initial_loss = loss_fn(params, X, y)
print(f"Initial loss: {initial_loss}")
Dataset: X [Dim(200), Dim(4)], y [Dim(200), Dim(1)]
Initial loss: Tensor(2.5328 : f32[])
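loss_fn is the mean squared error over all entries of predictions - y (200 here, since y has shape (200, 1)). Its gradient with respect to each prediction is 2(p_i - y_i) / n, and value_and_grad propagates this back through mlp_forward to the parameters. The pure-Python sketch below computes both quantities on a toy batch of three; mse and mse_grad_wrt_preds are illustrative helpers.

```python
def mse(preds: list, targets: list) -> float:
    """Mean squared error, the same quantity as nb.mean(diff * diff) in loss_fn."""
    n = len(preds)
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / n


def mse_grad_wrt_preds(preds: list, targets: list) -> list:
    """dL/dp_i = 2 * (p_i - t_i) / n for L = mean((p - t)^2)."""
    n = len(preds)
    return [2.0 * (p - t) / n for p, t in zip(preds, targets)]


toy_preds = [1.0, 2.0, 3.0]
toy_targets = [1.0, 1.0, 1.0]
toy_mse = mse(toy_preds, toy_targets)                   # (0 + 1 + 4) / 3
toy_grad = mse_grad_wrt_preds(toy_preds, toy_targets)   # [0, 2/3, 4/3]
print(toy_mse, toy_grad)
```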

4. Training Loop

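The key point: nb.value_and_grad(loss_fn, argnums=0) returns a function that computes the loss together with its gradient with respect to the first argument (params). Because params is a nested dict, the gradients come back as a dict with the same pytree structure, so they can be passed straight to the optimizer.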
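The contract of value_and_grad(f, argnums=0), a loss value plus a gradient tree shaped like the first argument, can be mimicked with central finite differences over a nested dict of Python floats. The sketch below is a hypothetical stand-in for illustration and gradient checking only: Nabla computes gradients by automatic differentiation, not finite differences, and numerical_value_and_grad is not part of its API.

```python
def numerical_value_and_grad(f, eps: float = 1e-6):
    """Finite-difference stand-in for value_and_grad(f, argnums=0).

    Returns g(params, *args) -> (f(params, *args), grads), where params is a
    nested dict of floats and grads is a nested dict with the same keys.
    """

    def shifted(tree, path, delta):
        # Rebuild the nested dict, shifting only the leaf at `path` by delta.
        if not path:
            return tree + delta
        head, rest = path[0], path[1:]
        return {k: shifted(v, rest, delta) if k == head else v for k, v in tree.items()}

    def grad_tree(params, args, subtree, path):
        if isinstance(subtree, dict):
            return {k: grad_tree(params, args, v, path + (k,)) for k, v in subtree.items()}
        # Central difference for this one leaf: (f(p + eps) - f(p - eps)) / (2 * eps)
        f_plus = f(shifted(params, path, eps), *args)
        f_minus = f(shifted(params, path, -eps), *args)
        return (f_plus - f_minus) / (2 * eps)

    def value_and_grad_fn(params, *args):
        return f(params, *args), grad_tree(params, args, params, ())

    return value_and_grad_fn


def toy_loss(params, x):
    # L = (w*x + b)^2, so dL/dw = 2*(w*x + b)*x and dL/db = 2*(w*x + b)
    pred = params["layer"]["w"] * x + params["layer"]["b"]
    return pred ** 2


toy_params = {"layer": {"w": 2.0, "b": 1.0}}
toy_value, toy_grads = numerical_value_and_grad(toy_loss)(toy_params, 3.0)
print(toy_value)  # (2*3 + 1)^2 = 49.0
print(toy_grads)  # approximately {'layer': {'w': 42.0, 'b': 14.0}}
```

Only the first argument is differentiated, matching argnums=0; the extra argument x is passed through unchanged.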

[5]:
opt_state = nb.nn.optim.adamw_init(params)
lr = 1e-2
num_epochs = 60

print(f"\n{'Epoch':<8} {'Loss':<12}")
print("-" * 22)

train_start = time.perf_counter()
for epoch in range(num_epochs):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, X, y)

    # grads has the same structure as params:
    # grads["fc1"]["weight"], grads["fc1"]["bias"], etc.
    params, opt_state = nb.nn.optim.adamw_update(
        params, grads, opt_state, lr=lr
    )

    if (epoch + 1) % 10 == 0:
        print(f"{epoch + 1:<8} {loss.item():<12.6f}")
train_elapsed = time.perf_counter() - train_start
print(f"\nTotal training time: {train_elapsed:.4f} s")
print(f"Average step time:   {(train_elapsed / max(1, num_epochs)) * 1000.0:.3f} ms/step")

Epoch    Loss
----------------------
10       0.625402
20       0.242037
30       0.159796
40       0.118573
50       0.082741
60       0.064866

Total training time: 1.3741 s
Average step time:   22.901 ms/step
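adamw_update hides the per-leaf arithmetic. AdamW is Adam with decoupled weight decay: each parameter keeps running averages of the gradient and of the squared gradient, and the decay is applied directly to the weights instead of being folded into the gradient. The example passes only lr, so the other hyperparameters below (beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01, which are PyTorch's AdamW defaults) are assumptions, and the layout of Nabla's opt_state is not shown. This is a sketch of the standard update for a single scalar leaf, not Nabla's code.

```python
import math


def adamw_step(p, g, m, v, t, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter at step t (t starts at 1).

    Hyperparameter defaults are assumptions (PyTorch's AdamW defaults),
    not necessarily those of nb.nn.optim.adamw_update.
    """
    p = p - lr * weight_decay * p          # decoupled weight decay
    m = beta1 * m + (1 - beta1) * g        # EMA of the gradient
    v = beta2 * v + (1 - beta2) * g * g    # EMA of the squared gradient
    m_hat = m / (1 - beta1 ** t)           # bias corrections
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# First step (t=1): bias correction gives m_hat = g and v_hat = g**2,
# so the Adam term moves p by about lr regardless of the gradient's size.
p, m, v = adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1)
print(p)  # about 1.0 - 0.01*0.01*1.0 - 0.01 = 0.9899
```

In a pytree setting the same function would be applied leaf by leaf across params, grads and the two moment trees.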

5. Evaluation

[6]:
final_loss = loss_fn(params, X, y)
print(f"\nFinal loss: {final_loss}")

predictions = mlp_forward(params, X)
print("\nSample predictions vs targets:")
print(f"{'Prediction':<14} {'Target':<14}")
print("-" * 28)
for i in range(5):
    pred_i = nb.gather(predictions, nb.constant(np.array([i], dtype=np.int64)), axis=0)
    true_i = nb.gather(y, nb.constant(np.array([i], dtype=np.int64)), axis=0)
    print(f"{pred_i.item():<14.4f} {true_i.item():<14.4f}")

Final loss: Tensor(0.0636 : f32[])

Sample predictions vs targets:
Prediction     Target
----------------------------
0.0546         0.2678
0.6805         0.7629
0.6607         0.6380
-0.2755        -0.3964
1.3211         1.0610

6. Manual SGD (No Optimizer)

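The functional style makes plain gradient descent easy to write by hand: since grads has the same structure as params, a single nb.tree_map applies p - lr * g to every pair of corresponding leaves: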

[7]:
params_sgd = init_mlp_params(4, 32, 1)
sgd_lr = 0.05

print("\nManual SGD training:")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)

sgd_steps = 60
sgd_start = time.perf_counter()
for step in range(sgd_steps):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params_sgd, X, y)

    # Manual SGD: params = params - lr * grads
    params_sgd = nb.tree_map(
        lambda p, g: p - sgd_lr * g, params_sgd, grads
    )

    if (step + 1) % 20 == 0:
        print(f"{step + 1:<8} {loss.item():<12.6f}")
sgd_elapsed = time.perf_counter() - sgd_start
print(f"\nManual SGD time: {sgd_elapsed:.4f} s")
print(f"Manual SGD avg step: {(sgd_elapsed / max(1, sgd_steps)) * 1000.0:.3f} ms/step")

Manual SGD training:
Step     Loss
----------------------
20       0.204635
40       0.145309
60       0.117855

Manual SGD time: 1.7942 s
Manual SGD avg step: 29.903 ms/step
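nb.tree_map walks trees with identical structure in lockstep and applies a function to each tuple of corresponding leaves, which is why one lambda updates every weight and bias at once. A hypothetical dict-only version in pure Python is sketched below; real pytree utilities (presumably including Nabla's) also traverse lists, tuples and other containers.

```python
def tree_map_dicts(fn, tree, *rest):
    """Apply fn leaf-wise across nested dicts that share the same keys.

    Hypothetical sketch of the idea behind nb.tree_map; handles dicts only.
    """
    if isinstance(tree, dict):
        for other in rest:
            if not isinstance(other, dict) or other.keys() != tree.keys():
                raise ValueError("trees must have the same structure")
        return {k: tree_map_dicts(fn, tree[k], *(other[k] for other in rest)) for k in tree}
    return fn(tree, *rest)  # leaf: call fn on the matching leaves of all trees


toy_params = {"fc1": {"weight": 1.0, "bias": 0.5}}
toy_grads = {"fc1": {"weight": 2.0, "bias": -1.0}}
toy_lr = 0.05
# weight: 1.0 - 0.05*2.0, bias: 0.5 - 0.05*(-1.0)
new_params = tree_map_dicts(lambda p, g: p - toy_lr * g, toy_params, toy_grads)
print(new_params)
```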

PyTorch-Style vs JAX-Style: Comparison

Aspect       PyTorch-style (4a)              JAX-style (4b)
------------ ------------------------------- ----------------------------------
Model        class MLP(nn.Module)            def mlp_forward(params, x)
Params       Auto-tracked by Module          Explicit dict (pytree)
State        Mutable attributes              Immutable, returned from functions
Optimizer    Can be stateful or functional   Typically functional
@nb.compile  Works with both                 Works with both

Both styles are fully supported in Nabla. Choose the one that fits your mental model!