MLP Training (GPU)#
In this tutorial, we’ll walk through how to use Nabla with GPU acceleration to train a neural network that learns a complex sin function. We’ll cover installation, device setup, and a training loop that uses jitting for GPU acceleration.
Installation and Setup#
[4]:
# Installation
import sys

IN_COLAB = "google.colab" in sys.modules

try:
    import nabla as nb
except ImportError:
    import subprocess

    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "modular",
            "--extra-index-url",
            "https://download.pytorch.org/whl/cpu",
            "--index-url",
            "https://dl.modular.com/public/nightly/python/simple/",
        ],
        check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "nabla-ml", "--upgrade"], check=True
    )
    import nabla as nb

# Import other required libraries
import time

import numpy as np

print(
    f"🎉 Nabla is ready! Running on Python {sys.version_info.major}.{sys.version_info.minor}"
)
🎉 Nabla is ready! Running on Python 3.10
Introduction to Nabla and GPU Acceleration#
Nabla is a deep learning library that leverages the Modular MLIR compiler for high-performance computation. One of its key features is GPU acceleration, achieved through jitting (Just-In-Time compilation): functions decorated with @nb.jit are compiled into optimized GPU code by the Modular compiler.
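As a minimal sketch of what jitting looks like in practice (assuming only the @nb.jit decorator and the basic ops used later in this tutorial):

@nb.jit
def scaled_dot(a: nb.Array, b: nb.Array) -> nb.Array:
    # The first call traces and compiles the function; subsequent calls reuse the compiled code.
    return nb.sum(a * b) / 2.0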
Why to(device)?#
In Nabla, tensors must be explicitly moved to the desired device (CPU or GPU) using the to(device) method. This is because Nabla’s GPU mode is only accessible within jitted functions: the device is determined at runtime, and tensors must be on the correct device for operations to execute efficiently.
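For example (a sketch using the same helpers that appear in the cells below):

# Pick the GPU if one is available, otherwise fall back to the CPU
device = nb.cpu() if nb.accelerator_count() == 0 else nb.accelerator()
# Create a tensor and move it onto that device before using it in jitted code
x = nb.rand((4, 1), lower=0.0, upper=1.0, dtype=nb.DType.float32).to(device)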
Key Concepts:#
Jitting: Functions decorated with @nb.jit are compiled and optimized for GPU execution.
Device Placement: Tensors must be moved to the appropriate device using to(device).
Training Loop: The training loop involves creating datasets, computing gradients, and updating parameters using an optimizer.
Let’s dive into the implementation!
[5]:
# Configuration
BATCH_SIZE = 4
LAYERS = [1, 64, 128, 256, 128, 64, 1]
LEARNING_RATE = 0.001
NUM_EPOCHS = 1000
PRINT_INTERVAL = 100
SIN_PERIODS = 8
device = nb.cpu() if nb.accelerator_count() == 0 else nb.accelerator()
print(f"Using {device} device")
Using Device(type=gpu,id=0) device
[6]:
def mlp_forward(x: nb.Array, params: list[nb.Array]) -> nb.Array:
    """MLP forward pass through all layers."""
    output = x
    for i in range(0, len(params) - 1, 2):
        w, b = params[i], params[i + 1]
        output = nb.matmul(output, w) + b
        # Apply ReLU to all layers except the last
        if i < len(params) - 2:
            output = nb.relu(output)
    return output


def mean_squared_error(predictions: nb.Array, targets: nb.Array) -> nb.Array:
    """Compute mean squared error loss."""
    diff = predictions - targets
    squared_errors = diff * diff
    batch_size = nb.array(predictions.shape[0], dtype=nb.DType.float32).to(device)
    loss = nb.sum(squared_errors) / batch_size
    return loss


def mlp_forward_and_loss(inputs: list[nb.Array]) -> nb.Array:
    """Combined forward pass and loss computation (for use with VJP-style APIs)."""
    x, targets, *params = inputs
    predictions = mlp_forward(x, params)
    loss = mean_squared_error(predictions, targets)
    return loss
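As a quick sanity check of the forward pass (a sketch with a hypothetical two-layer stack; the real initializer is defined in the next cell), the MLP maps a (batch, 1) input to a (batch, 1) output:

w0 = nb.he_normal((1, 4), seed=0).to(device)
b0 = nb.zeros((4,)).to(device)
w1 = nb.he_normal((4, 1), seed=0).to(device)
b1 = nb.zeros((1,)).to(device)
x_demo = nb.rand((2, 1), lower=0.0, upper=1.0, dtype=nb.DType.float32).to(device)
print(mlp_forward(x_demo, [w0, b0, w1, b1]).shape)  # expected: (2, 1)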
[7]:
def create_sin_dataset(batch_size: int = 256) -> tuple[nb.Array, nb.Array]:
    """Create the COMPLEX 8-period sin dataset."""
    x = nb.rand((batch_size, 1), lower=0.0, upper=1.0, dtype=nb.DType.float32).to(
        device
    )
    targets = nb.sin(SIN_PERIODS * 2.0 * np.pi * x) / 2.0 + 0.5
    return x, targets


def initialize_for_complex_function(
    layers: list[int], seed: int = 42
) -> list[nb.Array]:
    """Initialize specifically for learning complex high-frequency functions."""
    np.random.seed(seed)
    params = []
    for i in range(len(layers) - 1):
        fan_in, fan_out = layers[i], layers[i + 1]
        w = nb.he_normal((fan_in, fan_out), seed=seed).to(device)
        b = nb.zeros((fan_out,)).to(device)
        params.append(w)
        params.append(b)
    return params
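Since sin(·) lies in [-1, 1], dividing by 2 and adding 0.5 rescales the targets into [0, 1]; a quick check (sketch):

x_chk, y_chk = create_sin_dataset(256)
y_np = y_chk.to_numpy()
print(y_np.min() >= 0.0 and y_np.max() <= 1.0)  # expected: True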
[8]:
def adamw_step(
    params: list[nb.Array],
    gradients: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array]]:
    """AdamW optimizer step with decoupled weight decay."""
    updated_params = []
    updated_m = []
    updated_v = []
    for param, grad, m, v in zip(params, gradients, m_states, v_states, strict=False):
        # Update first and second moment estimates
        new_m = beta1 * m + (1.0 - beta1) * grad
        new_v = beta2 * v + (1.0 - beta2) * (grad * grad)
        # Bias correction
        bias_correction1 = 1.0 - beta1**step
        bias_correction2 = 1.0 - beta2**step
        # Corrected moments
        m_corrected = new_m / bias_correction1
        v_corrected = new_v / bias_correction2
        # Parameter update with decoupled weight decay
        new_param = param - learning_rate * (
            m_corrected / (v_corrected**0.5 + eps) + weight_decay * param
        )
        # Append updated values
        updated_params.append(new_param)
        updated_m.append(new_m)
        updated_v.append(new_v)
    return updated_params, updated_m, updated_v


def init_adamw_state(params: list[nb.Array]) -> tuple[list[nb.Array], list[nb.Array]]:
    """Initialize AdamW moment states to zeros matching each parameter."""
    m_states = []
    v_states = []
    for param in params:
        # Zero-initialize both moment buffers with the parameter's shape and dtype
        param_np = param.to_numpy()
        m_states.append(nb.Array.from_numpy(np.zeros_like(param_np)).to(device))
        v_states.append(nb.Array.from_numpy(np.zeros_like(param_np)).to(device))
    return m_states, v_states
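To see what a single update does, here is a tiny pure-Python check (a sketch, independent of Nabla): at step 1 the bias-corrected moments equal the raw gradient and its square, so the Adam term reduces to roughly sign(grad), scaled by the learning rate.

param, grad = 1.0, 0.5
m = 0.9 * 0.0 + 0.1 * grad             # new_m = 0.05
v = 0.999 * 0.0 + 0.001 * grad**2      # new_v = 0.00025
m_hat = m / (1.0 - 0.9**1)             # = 0.5, the raw gradient
v_hat = v / (1.0 - 0.999**1)           # = 0.25, the squared gradient
update = 0.001 * (m_hat / (v_hat**0.5 + 1e-8) + 0.01 * param)
print(update)  # ~0.00101: lr * sign(grad) plus the weight-decay term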
[9]:
def learning_rate_schedule(
    epoch: int,
    initial_lr: float = 0.001,
    decay_factor: float = 0.95,
    decay_every: int = 1000,
) -> float:
    """Step-decay learning rate schedule for complex function learning."""
    return initial_lr * (decay_factor ** (epoch // decay_every))
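With the defaults above, the rate holds at 0.001 for epochs 0–999 and then decays by 5% every 1000 epochs, so with NUM_EPOCHS = 1000 the decay only kicks in at the very last epoch:

for epoch in (0, 500, 999, 1000, 2000):
    print(epoch, learning_rate_schedule(epoch))
# 0 0.001, 500 0.001, 999 0.001, 1000 0.00095, 2000 0.0009025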
[10]:
@nb.jit(show_graph=False)
def train_step(
    x: nb.Array,
    targets: nb.Array,
    params: list[nb.Array],
    m_states: list[nb.Array],
    v_states: list[nb.Array],
    step: int,
    learning_rate: float,
) -> tuple[list[nb.Array], list[nb.Array], list[nb.Array], nb.Array]:
    """JIT-compiled training step combining gradient computation and optimizer update."""

    # Define a loss function that takes the parameters as separate arguments (JAX style)
    def loss_fn(*inner_params):
        predictions = mlp_forward(x, inner_params)
        loss = mean_squared_error(predictions, targets)
        return loss

    loss_value, param_gradients = nb.value_and_grad(
        loss_fn, argnums=list(range(len(params)))
    )(*params)
    # AdamW optimizer update
    updated_params, updated_m, updated_v = adamw_step(
        params, param_gradients, m_states, v_states, step, learning_rate
    )
    return updated_params, updated_m, updated_v, loss_value
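Note that nb.value_and_grad follows the JAX convention the comment alludes to: it returns the scalar loss together with one gradient per argument position listed in argnums. A minimal sketch (assuming that convention; square_plus is a made-up toy function):

a0 = nb.Array.from_numpy(np.array([2.0], dtype=np.float32)).to(device)
b0 = nb.Array.from_numpy(np.array([1.0], dtype=np.float32)).to(device)

def square_plus(a: nb.Array, b: nb.Array) -> nb.Array:
    return nb.sum(a * a + b)

val, grads = nb.value_and_grad(square_plus, argnums=[0, 1])(a0, b0)
# val = 5.0; grads[0] = 2*a = [4.0]; grads[1] = [1.0]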
[11]:
@nb.jit
def compute_predictions_and_loss(
    x_test: nb.Array, targets_test: nb.Array, params: list[nb.Array]
) -> tuple[nb.Array, nb.Array]:
    """JIT-compiled function to compute predictions and loss."""
    predictions_test = mlp_forward(x_test, params)
    test_loss = mean_squared_error(predictions_test, targets_test)
    return predictions_test, test_loss
[12]:
def test_nabla_complex_sin():
    """Test Nabla implementation with JIT for complex sin learning."""
    print("=== Learning COMPLEX 8-Period Sin Function with Nabla JIT ===")
    print(f"Architecture: {LAYERS}")
    print(f"Initial learning rate: {LEARNING_RATE}")
    print(f"Sin periods: {SIN_PERIODS}")
    print(f"Batch size: {BATCH_SIZE}")

    # Initialize for complex function learning
    params = initialize_for_complex_function(LAYERS)
    m_states, v_states = init_adamw_state(params)

    # Initial analysis
    x_init, targets_init = create_sin_dataset(BATCH_SIZE)
    predictions_init = mlp_forward(x_init, params)
    initial_loss = mean_squared_error(predictions_init, targets_init)
    pred_init_np = predictions_init.to_numpy()
    target_init_np = targets_init.to_numpy()
    print(f"Initial loss: {initial_loss.to_numpy().item():.6f}")
    print(
        f"Initial predictions range: [{pred_init_np.min():.3f}, {pred_init_np.max():.3f}]"
    )
    print(f"Targets range: [{target_init_np.min():.3f}, {target_init_np.max():.3f}]")
    print("\nStarting training...")

    # Training loop
    avg_loss = 0.0
    avg_time = 0.0
    avg_data_time = 0.0
    avg_vjp_time = 0.0
    avg_adamw_time = 0.0
    for epoch in range(1, NUM_EPOCHS + 1):
        epoch_start_time = time.time()
        # Learning rate schedule
        current_lr = learning_rate_schedule(epoch, LEARNING_RATE)
        # Create fresh batch
        data_start = time.time()
        x, targets = create_sin_dataset(BATCH_SIZE)
        data_time = time.time() - data_start
        # JIT-compiled training step (combines gradient computation and optimizer update)
        vjp_start = time.time()
        updated_params, updated_m, updated_v, loss_values = train_step(
            x, targets, params, m_states, v_states, epoch, current_lr
        )
        vjp_time = time.time() - vjp_start
        # Adopt the updated state (no separate AdamW step needed)
        params, m_states, v_states = updated_params, updated_m, updated_v
        adamw_time = 0.0  # Already included in the JIT step
        # Loss extraction and conversion
        loss_value = loss_values.to_numpy().item()
        epoch_time = time.time() - epoch_start_time
        avg_loss += loss_value
        avg_time += epoch_time
        avg_data_time += data_time
        avg_vjp_time += vjp_time
        avg_adamw_time += adamw_time
        if epoch % PRINT_INTERVAL == 0:
            print(f"\n{'=' * 60}")
            print(
                f"Epoch {epoch:3d} | Loss: {avg_loss / PRINT_INTERVAL:.6f} | Time: {avg_time / PRINT_INTERVAL:.4f}s"
            )
            print(f"{'=' * 60}")
            print(
                f" ├─ Data Gen: {avg_data_time / PRINT_INTERVAL:.4f}s ({avg_data_time / avg_time * 100:.1f}%)"
            )
            print(
                f" └─ JIT Step: {avg_vjp_time / PRINT_INTERVAL:.4f}s ({avg_vjp_time / avg_time * 100:.1f}%)"
            )
            avg_loss = 0.0
            avg_time = 0.0
            avg_data_time = 0.0
            avg_vjp_time = 0.0
            avg_adamw_time = 0.0

    print("\nNabla JIT training completed!")

    # Final evaluation
    print("\n=== Final Evaluation ===")
    x_test_np = np.linspace(0, 1, 1000).reshape(-1, 1).astype(np.float32)
    targets_test_np = (
        np.sin(SIN_PERIODS * 2.0 * np.pi * x_test_np) / 2.0 + 0.5
    ).astype(np.float32)
    x_test = nb.Array.from_numpy(x_test_np).to(device)
    targets_test = nb.Array.from_numpy(targets_test_np).to(device)
    # Use the JIT-compiled function for evaluation
    predictions_test, test_loss = compute_predictions_and_loss(
        x_test, targets_test, params
    )
    pred_final_np = predictions_test.to_numpy()
    final_test_loss = test_loss.to_numpy().item()
    print(f"Final test loss: {final_test_loss:.6f}")
    print(
        f"Final predictions range: [{pred_final_np.min():.3f}, {pred_final_np.max():.3f}]"
    )
    print(f"Target range: [{targets_test_np.min():.3f}, {targets_test_np.max():.3f}]")
    # Calculate correlation
    correlation = np.corrcoef(pred_final_np.flatten(), targets_test_np.flatten())[0, 1]
    print(f"Prediction-target correlation: {correlation:.4f}")
    return final_test_loss, correlation


if __name__ == "__main__":
    final_loss, correlation = test_nabla_complex_sin()
    print("\n=== Nabla JIT Summary ===")
    print(f"Final test loss: {final_loss:.6f}")
    print(f"Correlation with true function: {correlation:.4f}")
    if correlation > 0.95:
        print("SUCCESS: Nabla JIT learned the complex function very well! 🎉")
    elif correlation > 0.8:
        print("GOOD: Nabla JIT learned the general shape well! 👍")
    elif correlation > 0.5:
        print("PARTIAL: Some learning but needs improvement 🤔")
    else:
        print("POOR: Nabla JIT failed to learn the complex function 😞")
=== Learning COMPLEX 8-Period Sin Function with Nabla JIT ===
Architecture: [1, 64, 128, 256, 128, 64, 1]
Initial learning rate: 0.001
Sin periods: 8
Batch size: 4
Initial loss: 2.015263
Initial predictions range: [-1.115, -0.850]
Targets range: [0.008, 0.887]
Starting training...
Summary#
In this tutorial, we covered:
Installation: Setting up Nabla with GPU support in Google Colab.
Device Setup: Understanding and using to(device) for GPU acceleration.
Training Loop: Implementing a neural network to learn a complex sin function with Nabla’s jitting for GPU acceleration.
By following this tutorial, you should now have a good understanding of how to use Nabla for GPU-accelerated deep learning tasks.
Note
💡 Want to run this yourself?
🚀 Google Colab: No setup required, runs in your browser
📥 Local Jupyter: Download and run with your own Python environment