Training an MLP (CPU and GPU)#

In this notebook, we’ll train a simple Multi-Layer Perceptron (MLP) to learn a sine wave function. The primary goal is to demonstrate and compare Nabla’s different execution modes, showcasing how Just-In-Time (JIT) compilation can dramatically accelerate model training.

Note: To create a heavy computational workload and make the performance gains obvious, the MLP used here is intentionally oversized. A much smaller network would be sufficient (and probably better suited) for this task, but the large model lets us stress-test the JIT compiler and clearly see the speedup from hardware acceleration.

1. Imports and Setup#

First, we’ll import the necessary libraries: nabla for building and training the model, numpy for data generation, matplotlib for plotting, and time for benchmarking.

[1]:
import sys
import time
from typing import Any, Callable

try:
    import matplotlib.pyplot as plt
    import numpy as np

    import nabla as nb
except ImportError:
    import subprocess

    packages = ["numpy", "nabla-ml", "matplotlib"]
    subprocess.run([sys.executable, "-m", "pip", "install"] + packages, check=True)
    import matplotlib.pyplot as plt
    import numpy as np

    import nabla as nb

print(
    f"🎉 All libraries loaded successfully! Python {sys.version_info.major}.{sys.version_info.minor}"
)
🎉 All libraries loaded successfully! Python 3.10

2. Configuration#

Here we define the key hyperparameters for our model architecture and the training process.

[2]:
# Model & Task Configuration
LAYERS = [1, 256, 1024, 2048, 2048, 1024, 256, 1]  # Defines the architecture of our MLP
SIN_PERIODS = 8  # Number of sine periods over the input range [0, 1]

# Training Configuration
BATCH_SIZE = 512
LEARNING_RATE = 0.001 # Using a constant learning rate
NUM_EPOCHS = 500
PRINT_INTERVAL = 100
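
As a quick sanity check on just how oversized this architecture is, the total number of trainable parameters can be counted directly from LAYERS with plain Python (an illustrative snippet, not part of the benchmark):

n_params = sum(
    LAYERS[i] * LAYERS[i + 1] + LAYERS[i + 1]  # weight matrix + bias vector per layer
    for i in range(len(LAYERS) - 1)
)
print(f"Total trainable parameters: {n_params:,}")  # roughly 8.9 million for the sizes above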

3. Model and Loss Function#

  • mlp_forward: Defines the forward pass of our MLP.

  • mean_squared_error: Our loss function.

[3]:
def mlp_forward(x: nb.Array, params: list[nb.Array]) -> nb.Array:
    """MLP forward pass through all layers."""
    output = x
    for i in range(0, len(params) - 1, 2):
        w, b = params[i], params[i + 1]
        output = nb.matmul(output, w) + b
        if i < len(params) - 2:
            output = nb.relu(output)
    return output


def mean_squared_error(predictions: nb.Array, targets: nb.Array) -> nb.Array:
    """Compute mean squared error loss."""
    diff = predictions - targets
    squared_errors = diff * diff
    batch_size = nb.array(predictions.shape[0], dtype=nb.DType.float32)
    loss = nb.sum(squared_errors) / batch_size
    return loss
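
In equation form, for a batch of $B$ scalar predictions $\hat{y}_i$ and targets $y_i$, this computes

    \mathrm{MSE}(\hat{y}, y) = \frac{1}{B} \sum_{i=1}^{B} (\hat{y}_i - y_i)^2

Because the network’s output dimension is 1, dividing the total sum of squared errors by the batch size is exactly the mean over the batch.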

4. Data Generation and Parameter Initialization#

  • create_sin_dataset: Generates a batch of training data.

  • initialize_mlp_params: Creates and initializes the model’s parameters.

[4]:
def create_sin_dataset(batch_size: int) -> tuple[nb.Array, nb.Array]:
    """Create a batch of data for the n-period sine function."""
    x_np = np.random.rand(batch_size, 1).astype(np.float32)
    targets_np = (np.sin(SIN_PERIODS * 2.0 * np.pi * x_np) / 2.0 + 0.5).astype(
        np.float32
    )
    return nb.Array.from_numpy(x_np), nb.Array.from_numpy(targets_np)


def initialize_mlp_params(layers: list[int], seed: int = 42) -> list[nb.Array]:
    """Initialize MLP parameters using He normal initialization."""
    params = []
    for i in range(len(layers) - 1):
        w = nb.he_normal((layers[i], layers[i + 1]), seed=seed)
        b = nb.zeros((layers[i + 1],))
        params.append(w)
        params.append(b)
    return params
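
Before wiring these pieces into a training loop, a quick smoke test can confirm that shapes line up and a loss comes out. This is only an illustrative sketch using the functions defined above (with a tiny throwaway architecture, not the full LAYERS configuration):

x_demo, targets_demo = create_sin_dataset(4)      # tiny batch of 4 samples
params_demo = initialize_mlp_params([1, 8, 1])    # small throwaway MLP
preds_demo = mlp_forward(x_demo, params_demo)     # expected shape: (4, 1)
loss_demo = mean_squared_error(preds_demo, targets_demo)
print(preds_demo.shape, float(loss_demo.to_numpy()))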

5. Optimizer (AdamW)#

This is a from-scratch implementation of the AdamW optimizer. In addition to the standard update, adamw_step clips the gradients by their global norm before applying them, so a single oversized gradient cannot destabilize training.
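
For reference, with clipped gradient $g_t$ at step $t$, learning rate $\eta$, and weight decay $\lambda$, the per-parameter update implemented below is the standard AdamW rule:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
    \hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
    \theta_t = \theta_{t-1} - \eta \left( \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda\, \theta_{t-1} \right)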

[5]:
def init_adamw_state(
    params: list[nb.Array],
) -> tuple[list[nb.Array], list[nb.Array]]:
    """
    Initializes the momentum (m) and variance (v) states for the AdamW optimizer.

    Each state is initialized as a zero-tensor with the same shape as its
    corresponding parameter.
    """
    m_states = [nb.zeros(p.shape) for p in params]
    v_states = [nb.zeros(p.shape) for p in params]
    return m_states, v_states


def adamw_step(
    params: list[nb.Array],
    gradients: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array]]:
    """
    Performs a single step of the AdamW optimizer with global-norm gradient clipping.
    """

    # --- Global-norm gradient clipping ---
    # 1. Compute the total L2 norm across all gradient tensors.
    total_norm_sq = sum(nb.sum(g * g) for g in gradients)
    total_norm = nb.sqrt(total_norm_sq)

    # 2. Reduce it to a single scaling factor: 1.0 if the norm is within
    #    max_grad_norm, otherwise max_grad_norm / total_norm (the small epsilon
    #    guards against division by zero).
    clip_factor = nb.minimum(1.0, max_grad_norm / (total_norm + 1e-8))

    # 3. Scale every gradient by the same factor, which keeps the computation graph simple.
    clipped_gradients = [g * clip_factor for g in gradients]

    # --- Standard AdamW Update Logic ---
    updated_params = []
    updated_m = []
    updated_v = []

    for param, grad, m, v in zip(params, clipped_gradients, m_states, v_states):
        new_m = beta1 * m + (1.0 - beta1) * grad
        new_v = beta2 * v + (1.0 - beta2) * (grad * grad)

        m_corr = new_m / (1.0 - beta1**step)
        v_corr = new_v / (1.0 - beta2**step)

        new_param = param - learning_rate * (
            m_corr / (nb.sqrt(v_corr) + eps) + weight_decay * param
        )

        updated_params.append(new_param)
        updated_m.append(new_m)
        updated_v.append(new_v)

    return updated_params, updated_m, updated_v

6. Defining the Core Training Step#

This function encapsulates a single training step: it computes the loss and the gradients with nb.value_and_grad, then applies the AdamW update. The loss closure receives the parameters as positional arguments, and argnums=list(range(len(params))) tells value_and_grad to differentiate with respect to every one of them.
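
For intuition, here is a minimal toy sketch of the same value_and_grad pattern (the arrays and function are purely illustrative; only the API usage mirrors the training step below):

a = nb.Array.from_numpy(np.array([[1.0, 2.0]], dtype=np.float32))
b = nb.Array.from_numpy(np.array([[3.0, 4.0]], dtype=np.float32))

def toy_loss(*args):
    # f(a, b) = sum(a * b)
    return nb.sum(args[0] * args[1])

value, grads = nb.value_and_grad(toy_loss, argnums=[0, 1])(a, b)
# grads[0] is the gradient w.r.t. a (equal to b); grads[1] is the gradient w.r.t. b (equal to a)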

[6]:
def _complete_training_step(
    x: nb.Array,
    targets: nb.Array,
    params: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array], nb.Array]:
    """The core, undecorated logic for a single training step."""

    def loss_fn(*inner_params):
        predictions = mlp_forward(x, list(inner_params))
        return mean_squared_error(predictions, targets)

    loss_value, param_gradients = nb.value_and_grad(
        loss_fn, argnums=list(range(len(params)))
    )(*params)

    updated_params, updated_m, updated_v = adamw_step(
        params, param_gradients, m_states, v_states, step, LEARNING_RATE
    )

    return updated_params, updated_m, updated_v, loss_value

7. Defining the Execution Modes#

Now we create the three training-step variants we want to compare: plain eager execution, a JIT-compiled step that stays on the CPU, and (if an accelerator is available) a JIT-compiled step with automatic device placement.

[7]:
# Mode 1: Eager step (just an alias to the original function)
eager_training_step = _complete_training_step

# Mode 2: JIT-compiled step for CPU execution
jit_cpu_training_step = nb.jit(_complete_training_step, auto_device=False)

# Mode 3: (Optional) JIT-compiled step with automatic device placement
if nb.accelerator_count() > 0:
    jit_accelerator_training_step = nb.jit(_complete_training_step, auto_device=True)
    print(f"✅ Accelerator detected! '{nb.accelerator()}' will be used for the third run.")
else:
    jit_accelerator_training_step = None
    print("INFO: No accelerator detected. The accelerator-based training run will be skipped.")
✅ Accelerator detected! 'Device(type=gpu,id=0)' will be used for the third run.

8. Training and Benchmarking Loop#

This generic run_training_loop function executes the training process for a given step function, handling initialization, the main loop, and timing.

[8]:
def run_training_loop(
    mode_name: str, training_step_fn: Callable[..., Any]
) -> dict[str, Any]:
    """Generic training loop to benchmark a given training step function."""
    print("=" * 60)
    print(f"🤖 TRAINING MLP IN {mode_name.upper()} MODE")
    print("=" * 60)

    params = initialize_mlp_params(LAYERS)
    m_states, v_states = init_adamw_state(params)

    print("🔥 Warming up (3 steps)...")
    for i in range(3):
        x, targets = create_sin_dataset(BATCH_SIZE)
        params, m_states, v_states, _ = training_step_fn(
            x, targets, params, m_states, v_states, i + 1
        )
    print("✅ Warmup complete! Starting timed training...\n")

    start_time = time.time()
    loss_history = []
    time_history = [start_time]

    for epoch in range(1, NUM_EPOCHS + 1):
        x, targets = create_sin_dataset(BATCH_SIZE)
        params, m_states, v_states, loss = training_step_fn(
            x, targets, params, m_states, v_states, epoch
        )
        loss_history.append(float(loss.to_numpy()))
        time_history.append(time.time())

        if epoch % PRINT_INTERVAL == 0:
            avg_loss = np.mean(loss_history[-PRINT_INTERVAL:])
            print(f"Epoch {epoch:5d} | Avg Loss: {avg_loss:.6f} | Time: {time_history[-1] - start_time}")

    total_time = time_history[-1] - start_time
    print(f"\n{mode_name} Training complete! Total time: {total_time:.2f}s")

    return {
        "params": params,
        "loss_history": loss_history,
        "time_history": time_history,
        "total_time": total_time,
    }


# --- Run all training modes ---
results = {}

results["Eager"] = run_training_loop("Eager", eager_training_step)
results["JIT (CPU)"] = run_training_loop("JIT (CPU)", jit_cpu_training_step)

if jit_accelerator_training_step:
    results["JIT (Accelerator)"] = run_training_loop(
        "JIT (Accelerator)", jit_accelerator_training_step
    )
============================================================
🤖 TRAINING MLP IN EAGER MODE
============================================================
🔥 Warming up (3 steps)...
✅ Warmup complete! Starting timed training...

Epoch   100 | Avg Loss: 1.960438 | Time: 21.64s
Epoch   200 | Avg Loss: 0.101309 | Time: 42.54s
Epoch   300 | Avg Loss: 0.092507 | Time: 63.92s
Epoch   400 | Avg Loss: 0.083474 | Time: 84.30s
Epoch   500 | Avg Loss: 0.085233 | Time: 105.93s

✅ Eager Training complete! Total time: 105.93s
============================================================
🤖 TRAINING MLP IN JIT (CPU) MODE
============================================================
🔥 Warming up (3 steps)...
✅ Warmup complete! Starting timed training...

Epoch   100 | Avg Loss: 1.970259 | Time: 4.29s
Epoch   200 | Avg Loss: 0.100773 | Time: 8.33s
Epoch   300 | Avg Loss: 0.090923 | Time: 12.34s
Epoch   400 | Avg Loss: 0.086043 | Time: 16.24s
Epoch   500 | Avg Loss: 0.083217 | Time: 20.31s

✅ JIT (CPU) Training complete! Total time: 20.31s
============================================================
🤖 TRAINING MLP IN JIT (ACCELERATOR) MODE
============================================================
🔥 Warming up (3 steps)...
✅ Warmup complete! Starting timed training...

Epoch   100 | Avg Loss: 1.780899 | Time: 0.31s
Epoch   200 | Avg Loss: 0.100820 | Time: 0.57s
Epoch   300 | Avg Loss: 0.090581 | Time: 0.84s
Epoch   400 | Avg Loss: 0.084966 | Time: 1.10s
Epoch   500 | Avg Loss: 0.083767 | Time: 1.37s

✅ JIT (Accelerator) Training complete! Total time: 1.37s

9. Performance Comparison#

Let’s visualize the results. The ‘Loss vs. Time’ plot is the clearest way to see the raw speed advantage of JIT compilation and hardware acceleration.

[9]:
def plot_performance_comparison(results: dict[str, Any]):
    """Plot the loss curves for all completed training runs."""
    if not results:
        print("No results to plot.")
        return

    plt.figure(figsize=(16, 7))

    # Plot loss vs epochs
    plt.subplot(1, 2, 1)
    for mode, data in results.items():
        plt.plot(data["loss_history"], label=f'{mode} (Total: {data["total_time"]:.2f}s)')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Loss vs. Epochs")
    plt.yscale("log")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)

    # Plot loss vs time
    plt.subplot(1, 2, 2)
    for mode, data in results.items():
        relative_times = [t - data["time_history"][0] for t in data["time_history"]]
        plt.plot(relative_times[1:], data["loss_history"], label=mode)

    if 'JIT (CPU)' in results:
        xlim_max = results['JIT (CPU)']['total_time'] * 1.05
        plt.xlim(0, xlim_max)

    plt.xlabel("Time (seconds)")
    plt.ylabel("Loss")
    plt.title("Loss vs. Time (Zoomed In)")
    plt.yscale("log")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.show()


plot_performance_comparison(results)
../_images/tutorials_mlp_training_18_0.png

10. Final Evaluation & Visualization#

Finally, we take the model from the fastest training run and see how well it learned the sine wave. We plot its predictions against the true function for a clear visual evaluation.

[10]:
def evaluate_and_plot_model(params: list[nb.Array]):
    """Evaluates the final model and plots its predictions against the ground truth."""
    print("\n" + "=" * 60)
    print("🧪 FINAL MODEL EVALUATION")
    print("=" * 60)

    # Create a high-resolution test set for a smooth plot
    x_test_np = np.linspace(0, 1, 1000).reshape(-1, 1).astype(np.float32)
    targets_test_np = (np.sin(SIN_PERIODS * 2.0 * np.pi * x_test_np) / 2.0 + 0.5).astype(
        np.float32
    )
    x_test = nb.Array.from_numpy(x_test_np)

    # Get predictions from the model
    predictions_test = mlp_forward(x_test, params)
    pred_final_np = predictions_test.to_numpy()

    # Calculate final loss and correlation
    final_loss = np.mean((pred_final_np - targets_test_np) ** 2)
    correlation = np.corrcoef(pred_final_np.flatten(), targets_test_np.flatten())[0, 1]

    print(f"Final Test Loss: {final_loss:.6f}")
    print(f"Prediction-Target Correlation: {correlation:.4f}")

    # Plot the results
    plt.figure(figsize=(12, 6))
    plt.plot(x_test_np, targets_test_np, label='Ground Truth', color='blue', linestyle='--')
    plt.plot(x_test_np, pred_final_np, label='Model Prediction', color='red', alpha=0.8)
    plt.title('Final Model Prediction vs. Ground Truth')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()


# Find the best parameters from the fastest run and evaluate
if results:
    fastest_mode = min(results, key=lambda mode: results[mode]["total_time"])
    print(f"\nEvaluating model from the fastest run: '{fastest_mode}'")
    best_params = results[fastest_mode]['params']
    evaluate_and_plot_model(best_params)
else:
    print("\nSkipping evaluation as no training runs were completed.")

Evaluating model from the fastest run: 'JIT (Accelerator)'

============================================================
🧪 FINAL MODEL EVALUATION
============================================================
Final Test Loss: 0.100042
Prediction-Target Correlation: 0.4524
../_images/tutorials_mlp_training_20_1.png

11. Conclusion#

This experiment demonstrates a key principle: for large computational workloads, hardware acceleration is a necessity. Although the training dynamics of this deliberately oversized model were not perfectly stable, the performance results are unambiguous: JIT compilation on the CPU cut training time from roughly 106s to 20s (about 5x), and running the same compiled step on the accelerator brought it down to 1.4s, nearly 15x faster still and roughly 77x faster than eager execution overall. That gap is a powerful illustration of how accelerators make training large models feasible in practice.
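
As a quick follow-up, the per-step cost and relative speedups can be derived directly from the results dictionary built above (a small sketch, assuming all three runs completed):

baseline = results["Eager"]["total_time"]
for mode, data in results.items():
    ms_per_step = data["total_time"] / NUM_EPOCHS * 1000
    print(f"{mode:>18s}: {ms_per_step:7.1f} ms/step ({baseline / data['total_time']:5.1f}x vs. Eager)")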