Example 4a: MLP Training — PyTorch-Style (Imperative)#

Nabla supports two distinct training paradigms:

| Paradigm | Gradient API | Optimizer API |
| --- | --- | --- |
| PyTorch-style (this notebook) | loss.backward() + .grad | AdamW(model) + optimizer.step() |
| JAX-style (4b) | nb.value_and_grad(fn)(args) | adamw_init + adamw_update |

Here we demonstrate the PyTorch-style imperative API end-to-end. The training loop mirrors PyTorch exactly: zero_grad → forward → backward → step.

[1]:
import time
import numpy as np

import nabla as nb

print("Nabla MLP Training — PyTorch-style")
Nabla MLP Training — PyTorch-style

1. Define the Model#

Subclass nb.nn.Module and define layers in __init__. The forward() method specifies the computation. Parameters (from nb.nn.Linear, etc.) are automatically registered and tracked.

[2]:
class MLP(nb.nn.Module):
    """Two-layer MLP with ReLU activation."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.fc1 = nb.nn.Linear(in_dim, hidden_dim)
        self.fc2 = nb.nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = nb.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = MLP(4, 32, 1)
print(f"Model: fc1 {model.fc1.weight.shape}, fc2 {model.fc2.weight.shape}")
print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters())}")
Model: fc1 [Dim(4), Dim(32)], fc2 [Dim(32), Dim(1)]
Total trainable parameters: 193

2. Create Synthetic Data#

We’ll create a regression dataset: predict y = sin(x0) + cos(x1) + 0.5*x2 - x3.

[3]:
np.random.seed(42)
n_samples = 200
X_np = np.random.randn(n_samples, 4).astype(np.float32)
y_np = (
    np.sin(X_np[:, 0])
    + np.cos(X_np[:, 1])
    + 0.5 * X_np[:, 2]
    - X_np[:, 3]
).reshape(-1, 1).astype(np.float32)

X = nb.Tensor.from_dlpack(X_np)
y = nb.Tensor.from_dlpack(y_np)
print(f"Dataset: X {X.shape}, y {y.shape}")
Dataset: X [Dim(200), Dim(4)], y [Dim(200), Dim(1)]

3. Set Up the Stateful Optimizer#

nb.nn.optim.AdamW is a stateful optimizer — it holds references to the model parameters and maintains its own moment estimates (m, v). This is Nabla’s counterpart to torch.optim.AdamW.

JAX-style note: Nabla’s functional optimizer (nb.nn.optim.adamw_init + nb.nn.optim.adamw_update) takes params and optimizer state as explicit arguments and returns new values — no internal state at all. See 4b.
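
To make the optimizer’s state concrete, here is a minimal NumPy sketch of the decoupled-weight-decay update that AdamW applies to each parameter. This is the textbook rule (Loshchilov & Hutter), not Nabla’s internal implementation; all names are illustrative.

import numpy as np

def adamw_step(param, grad, m, v, t, lr=1e-2, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    # Update biased first- and second-moment estimates (the optimizer's state)
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    # Bias-correct the estimates for the early steps (t starts at 1)
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    # Decoupled weight decay: applied to the parameter directly,
    # not folded into the gradient as in classic L2 regularization
    param = param - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param)
    return param, m, v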

[4]:
optimizer = nb.nn.optim.AdamW(model, lr=1e-2)
print(f"Optimizer: AdamW (lr={optimizer.lr}, β=({optimizer.beta1}, {optimizer.beta2}))")
Optimizer: AdamW (lr=0.01, β=(0.9, 0.999))

4. PyTorch-Style Training Loop#

For comparability across notebooks, we use 60 training steps here too.

The four-step rhythm is identical to PyTorch:

  1. model.zero_grad() — clear accumulated .grad tensors from the previous iteration

  2. Forward pass — build the lazy computation graph

  3. loss.backward() — propagate gradients; every parameter with requires_grad=True gets its .grad populated and batch-realized

  4. optimizer.step() — read .grad from each parameter, apply the AdamW update, and return the updated model

We also record timing:

  • total loop wall time

  • average milliseconds per training step

Lazy execution note: Because Nabla cannot mutate tensor data in-place without breaking the lazy graph, optimizer.step() returns the new model. Assign it back to model each iteration.
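
For contrast, here is the same four-step rhythm in plain PyTorch (a self-contained toy snippet; the torch names below are for illustration only, not part of this notebook). Every step mutates state in place, so there is no reassignment:

import torch

t_model = torch.nn.Linear(4, 1)
t_opt = torch.optim.AdamW(t_model.parameters(), lr=1e-2)
t_X, t_y = torch.randn(8, 4), torch.randn(8, 1)

t_opt.zero_grad()                                          # 1. clear grads (in place)
t_loss = torch.nn.functional.mse_loss(t_model(t_X), t_y)   # 2. forward
t_loss.backward()                                          # 3. populate .grad (in place)
t_opt.step()                                               # 4. in-place update; no `t_model = ...`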

[5]:
num_epochs = 60
print(f"\n{'Epoch':<8} {'Loss':<12}")
print("-" * 22)

train_start = time.perf_counter()

for epoch in range(num_epochs):
    # 1. Clear gradients from the previous iteration
    model.zero_grad()

    # 2. Forward pass
    predictions = model(X)
    loss = nb.nn.functional.mse_loss(predictions, y)

    # 3. Backward pass — gradients stored in p.grad for each parameter
    loss.backward()

    # 4. Optimizer step — reads .grad, applies AdamW update
    model = optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"{epoch + 1:<8} {loss.item():<12.6f}")

train_elapsed = time.perf_counter() - train_start
train_step_ms = (train_elapsed / max(1, num_epochs)) * 1000.0
print(f"\nTotal training time: {train_elapsed:.4f} s")
print(f"Average step time:   {train_step_ms:.3f} ms/step")

Epoch    Loss
----------------------
10       0.625402
20       0.242037
30       0.159796
40       0.118573
50       0.082741
60       0.064866

Total training time: 2.0407 s
Average step time:   34.012 ms/step

5. Inspecting Gradients via .grad#

After loss.backward(), every trainable parameter exposes its gradient via .grad — exactly like PyTorch. The gradients are already realized (not lazy) when .backward() returns.

[6]:
# One more backward pass to show .grad access
model.train()
model.zero_grad()
b_loss = nb.nn.functional.mse_loss(model(X), y)
b_loss.backward()

import numpy as _np
print("Parameter gradients after backward():")
for name, param in model.named_parameters():
    g = param.grad
    if g is not None:
        g_np = _np.from_dlpack(g)
        print(f"  {name:30s}  shape {str(g.shape):<16}  |grad|={float(_np.linalg.norm(g_np)):.4f}")
Parameter gradients after backward():
  fc1.bias                        shape [Dim(1), Dim(32)]  |grad|=0.0387
  fc1.weight                      shape [Dim(4), Dim(32)]  |grad|=0.0944
  fc2.bias                        shape [Dim(1), Dim(1)]  |grad|=0.0257
  fc2.weight                      shape [Dim(32), Dim(1)]  |grad|=0.0916

6. Evaluate the Trained Model#

[7]:
model.eval()
final_loss = nb.nn.functional.mse_loss(model(X), y)
print(f"\nFinal MSE loss: {final_loss.item():.6f}")

predictions = model(X)
print(f"\n{'Prediction':>12}  {'Target':>12}")
print("-" * 28)
for i in range(5):
    pred_i = nb.gather(predictions, nb.constant(np.array([i], dtype=np.int64)), axis=0)
    true_i = nb.gather(y, nb.constant(np.array([i], dtype=np.int64)), axis=0)
    print(f"{pred_i.item():>12.4f}  {true_i.item():>12.4f}")

Final MSE loss: 0.063610

  Prediction        Target
----------------------------
      0.0546        0.2678
      0.6805        0.7629
      0.6607        0.6380
     -0.2755       -0.3964
      1.3211        1.0610

7. Contrast: JAX-Style API (for reference)#

The JAX-style equivalent of the same training step — note the absence of .backward(), .grad, and stateful optimizer mutations:

# JAX-style (functional) — see 04b_mlp_training_jax
def loss_fn(model, X, y):
    return nb.nn.functional.mse_loss(model(X), y)

# Single call computes both the loss value and all gradients
loss, grads = nb.value_and_grad(loss_fn, argnums=0)(model, X, y)

# Functional optimizer — returns new model + new state (no mutation);
# opt_state is created once up front via nb.nn.optim.adamw_init (see 4b)
model, opt_state = nb.nn.optim.adamw_update(model, grads, opt_state, lr=1e-2)

Both paradigms are fully supported in Nabla:

  • PyTorch-style: familiar to PyTorch users, great for interactive debugging and stateful training loops

  • JAX-style: composable with nb.vmap, @nb.compile, nb.jacrev, etc.; required when nesting transforms or writing pure-functional pipelines (a minimal compiled-step sketch follows)
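
Because the functional step is a pure function of (model, opt_state, X, y), it can be wrapped in transforms. A minimal sketch, assuming @nb.compile decorates a pure function as the bullet above suggests (see 4b for the actual API):

@nb.compile
def train_step(model, opt_state, X, y):
    # Pure function: everything it needs comes in, everything it changes goes out
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(model, X, y)
    model, opt_state = nb.nn.optim.adamw_update(model, grads, opt_state, lr=1e-2)
    return model, opt_state, loss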

[8]:
print("\n✅ Example 04a completed!")

✅ Example 04a completed!

Summary#

PyTorch-Style API (this notebook)#

| Concept | API |
| --- | --- |
| Define model | class MyModel(nb.nn.Module) |
| Linear layer | nb.nn.Linear(in_dim, out_dim) |
| Loss functions | nb.nn.functional.mse_loss, cross_entropy_loss |
| Clear gradients | model.zero_grad() |
| Compute gradients | loss.backward() |
| Inspect gradients | param.grad |
| Create optimizer | optimizer = nb.nn.optim.AdamW(model, lr=...) |
| Update parameters | model = optimizer.step() |

JAX-Style API (see 4b)#

Concept

API

Compute loss + grads

loss, grads = nb.value_and_grad(fn, argnums=0)(model, ...)

Optimizer init

opt_state = nb.nn.optim.adamw_init(params)

Optimizer update

model, opt_state = nb.nn.optim.adamw_update(...)