Example 6a: Transformer Training — PyTorch-Style (Imperative)#
Nabla supports two training paradigms; this notebook demonstrates the PyTorch-style imperative API:
| Paradigm | Gradient API | Optimizer API |
|---|---|---|
| PyTorch-style (this notebook) | `loss.backward()` + `.grad` | `nb.nn.optim.AdamW(model)` + `optimizer.step()` |
| JAX-style (6b) | `nb.value_and_grad(...)` | `nb.nn.optim.adamw_init` / `adamw_update` |
We build a small Transformer encoder for a synthetic sequence classification task using nb.nn.TransformerEncoderLayer, Embedding, and MultiHeadAttention.
[1]:
import numpy as np
import nabla as nb
import time
print("Nabla Transformer Training — PyTorch-style")
Nabla Transformer Training — PyTorch-style
1. Positional Encoding#
We’ll use sinusoidal positional encoding, computed as a fixed buffer.
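For reference, this is the standard sinusoidal form the function below implements, where $pos$ is the sequence position and $i$ indexes pairs of embedding dimensions:

$$
\mathrm{PE}(pos,\,2i) = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
\qquad
\mathrm{PE}(pos,\,2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$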
[2]:
def make_positional_encoding(max_len: int, d_model: int) -> np.ndarray:
"""Sinusoidal positional encoding."""
pe = np.zeros((max_len, d_model), dtype=np.float32)
position = np.arange(0, max_len, dtype=np.float32)[:, np.newaxis]
div_term = np.exp(
np.arange(0, d_model, 2, dtype=np.float32) * -(np.log(10000.0) / d_model)
)
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)
return pe # (max_len, d_model)
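A quick spot check of the buffer (illustrative sizes only, not part of the training setup): the returned array has shape `(max_len, d_model)`, and at position 0 the sine channels are 0 while the cosine channels are 1.

```python
# Spot-check the PE buffer (illustrative sizes only).
pe = make_positional_encoding(max_len=8, d_model=32)
print(pe.shape)    # (8, 32)
print(pe[0, :4])   # [0. 1. 0. 1.] since sin(0) = 0 and cos(0) = 1
```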
2. Define the Model#
Our TransformerClassifier is an nb.nn.Module subclass with these components:
| Component | Purpose |
|---|---|
| `nb.nn.Embedding` | Maps token IDs → dense vectors |
| Sinusoidal PE | Encodes position information (fixed, not learned) |
| `nb.nn.TransformerEncoderLayer` (stacked `num_layers` times) | Self-attention + feed-forward blocks |
| `nb.nn.Linear` | Classification head |
The __init__ method creates these components; forward chains them together.
[3]:
class TransformerClassifier(nb.nn.Module):
"""Transformer encoder for sequence classification."""
def __init__(self, vocab_size, d_model, num_heads, num_layers,
num_classes, max_len=128, dim_feedforward=128):
super().__init__()
self.d_model = d_model
# --- Embeddings ---
self.embedding = nb.nn.Embedding(vocab_size, d_model)
pe_np = make_positional_encoding(max_len, d_model)
self.pe = nb.Tensor.from_dlpack(pe_np) # fixed, not learned
# --- Encoder stack ---
self.layers = []
for i in range(num_layers):
layer = nb.nn.TransformerEncoderLayer(
d_model=d_model, num_heads=num_heads,
dim_feedforward=dim_feedforward, dropout=0.0,
)
setattr(self, f"encoder_{i}", layer)
self.layers.append(layer)
# --- Classifier ---
self.classifier = nb.nn.Linear(d_model, num_classes)
def forward(self, token_ids):
# Embed + positional encoding
x = self.embedding(token_ids)
seq_len = token_ids.shape[-1]
pe = nb.slice_tensor(self.pe, start=(0, 0), size=(seq_len, self.d_model))
x = x + pe
# Encoder layers
for layer in self.layers:
x = layer(x)
# Mean pool + classify
return self.classifier(nb.mean(x, axis=-2))
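A minimal shape-flow sketch (the `demo_*` names are illustrative, and the hyperparameters simply mirror the ones used later in this notebook): token IDs of shape `(batch, seq_len)` become `(batch, seq_len, d_model)` after embedding and encoding, are mean-pooled to `(batch, d_model)`, and exit the classifier as `(batch, num_classes)` logits.

```python
# Shape-flow sanity check on an untrained model (illustrative only).
demo_model = TransformerClassifier(
    vocab_size=20, d_model=32, num_heads=4, num_layers=2,
    num_classes=3, max_len=8, dim_feedforward=64,
)
demo_tokens = nb.Tensor.from_dlpack(
    np.random.randint(0, 20, (4, 8)).astype(np.int64)  # (batch=4, seq_len=8)
)
demo_logits = demo_model(demo_tokens)
print(demo_logits.shape)  # expected: (4, 3)
```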
3. Create Synthetic Data#
Generate a simple classification task:
- Sequences of random token IDs
- Labels derived from a deterministic rule: class = (sum of tokens) mod num_classes

(The sample printed below sums to 90, so its label is 90 mod 3 = 0.)
[4]:
np.random.seed(42)
vocab_size = 20
seq_len = 8
num_classes = 3
n_samples = 150
d_model = 32
num_heads = 4
num_layers = 2
# Generate random token sequences
token_ids_np = np.random.randint(0, vocab_size, (n_samples, seq_len)).astype(np.int64)
# Labels: class = (sum of tokens) mod num_classes
labels_np = (token_ids_np.sum(axis=1) % num_classes).astype(np.int64)
# One-hot encode labels
labels_onehot_np = np.zeros((n_samples, num_classes), dtype=np.float32)
labels_onehot_np[np.arange(n_samples), labels_np] = 1.0
token_ids = nb.Tensor.from_dlpack(token_ids_np)
labels = nb.Tensor.from_dlpack(labels_onehot_np)
print(f"Dataset: {n_samples} sequences of length {seq_len}")
print(f"Vocab size: {vocab_size}, Classes: {num_classes}")
print(f"Sample tokens: {token_ids_np[0]}")
print(f"Sample label: {labels_np[0]}")
Dataset: 150 sequences of length 8
Vocab size: 20, Classes: 3
Sample tokens: [ 6 19 14 10 7 6 18 10]
Sample label: 0
4. Build Model and Optimizer#
Important initialization order for the stateful optimizer: Create AdamW while the model is in train mode (_training=True, the default). Nabla’s Module pytree includes _training in its metadata, so the optimizer’s internal moment tensors (m, v) are snapshot-initialized with that training mode. Calling model.eval() before creating the optimizer would bake _training=False into those snapshots, causing a pytree metadata mismatch the first time model.train()
is called inside the training loop.
Rule of thumb (sketched below):
- Stateful optimizer (`AdamW(model)`) → create it in train mode; call `model.eval()` only for eval passes.
- Functional optimizer (`adamw_init(model)`) → call `model.eval()` before `adamw_init` so every pass shares the same `_training=False` state.
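A compressed sketch of the two orderings side by side (the `m1`/`m2` names are illustrative; both paths reuse the APIs shown in the cells below):

```python
# Stateful path: create AdamW while the model is in train mode (the default),
# so the optimizer's moment snapshots carry _training=True.
m1 = TransformerClassifier(vocab_size=20, d_model=32, num_heads=4,
                           num_layers=2, num_classes=3, max_len=8)
m1.train()                               # default, shown explicitly
opt = nb.nn.optim.AdamW(m1, lr=1e-3)

# Functional path: switch to eval mode first, so adamw_init and every later
# step share the same _training=False pytree metadata.
m2 = TransformerClassifier(vocab_size=20, d_model=32, num_heads=4,
                           num_layers=2, num_classes=3, max_len=8)
m2.eval()
opt_state = nb.nn.optim.adamw_init(m2)
```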
[5]:
model = TransformerClassifier(
vocab_size=vocab_size,
d_model=d_model,
num_heads=num_heads,
num_layers=num_layers,
num_classes=num_classes,
max_len=seq_len,
dim_feedforward=64,
)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model: {num_layers} encoder layers, d_model={d_model}, heads={num_heads}")
print(f"Total trainable parameters: {n_params}")
# Create optimizer while model is in train mode (default _training=True)
model.train()
optimizer = nb.nn.optim.AdamW(model, lr=1e-3)
print(f"Optimizer: AdamW (lr={optimizer.lr})")
Model: 2 encoder layers, d_model=32, heads=4
Total trainable parameters: 17827
Optimizer: AdamW (lr=0.001)
5. PyTorch-Style Training Loop#
Imperative four-step loop: zero_grad → forward → backward → step
- `model.train()` at the top of each iteration ensures the model is in train mode before `optimizer.step()`.
- `loss.backward()` populates `.grad` on every `requires_grad=True` parameter and batch-realizes all gradients before returning.
- `optimizer.step()` (no arguments) reads `.grad`, applies the AdamW update, and returns the updated model.
- Assigning `model = optimizer.step()` is necessary because Nabla’s lazy execution cannot mutate tensor data truly in-place.
For comparability, we use 60 training steps and record:
- total eager loop time
- average milliseconds per eager step
[6]:
num_epochs = 60
print(f"\n{'Epoch':<8} {'Loss':<12} {'Accuracy':<10}")
print("-" * 32)
eager_train_start = time.perf_counter()
for epoch in range(num_epochs):
model.train() # ensures train mode before optimizer step
model.zero_grad() # clear .grad from previous iteration
# Forward pass
logits = model(token_ids)
loss = nb.nn.functional.cross_entropy_loss(logits, labels)
# Backward pass — fills .grad on all trainable parameters
loss.backward()
# Optimizer step — reads .grad, applies AdamW, returns updated model
model = optimizer.step()
if (epoch + 1) % 10 == 0:
model.eval()
logits_eval = model(token_ids)
pred_classes = nb.argmax(logits_eval, axis=-1)
target_classes = nb.Tensor.from_dlpack(labels_np.astype(np.int64))
correct = nb.equal(pred_classes, target_classes)
accuracy = nb.mean(nb.cast(correct, nb.DType.float32)).item()
print(f"{epoch + 1:<8} {loss.item():<12.4f} {accuracy:<10.2%}")
eager_train_elapsed = time.perf_counter() - eager_train_start
eager_train_step_ms = (eager_train_elapsed / max(1, num_epochs)) * 1000.0
print(f"\nEager PyTorch-style training time: {eager_train_elapsed:.4f} s")
print(f"Eager PyTorch-style avg step: {eager_train_step_ms:.3f} ms/step")
Epoch Loss Accuracy
--------------------------------
10 3.0606 30.67%
20 2.6096 30.67%
30 2.2841 30.00%
40 2.0109 32.00%
50 1.7453 32.00%
60 1.5089 32.67%
Eager PyTorch-style training time: 6.3128 s
Eager PyTorch-style avg step: 105.213 ms/step
6. Compiled Training (Bonus)#
@nb.compile runs the same training-step function with cached compiled execution when input metadata matches (shape, dtype, sharding, structure).
API note: Inside a compiled function, `value_and_grad` (the functional transform) must be used — not `loss.backward()`. The imperative `.backward()`/`.grad` path is for eager execution only.
Speedup interpretation in this notebook is simple:
- compiled cached runs remove most Python overhead
- eager runs keep Python in the step loop
[7]:
import time
# Fresh model + functional optimizer state for compiled run
model_c = TransformerClassifier(
vocab_size=vocab_size, d_model=d_model, num_heads=num_heads,
num_layers=num_layers, num_classes=num_classes,
max_len=seq_len, dim_feedforward=64,
)
# eval() BEFORE adamw_init so both share _training=False pytree structure
model_c.eval()
opt_state_c = nb.nn.optim.adamw_init(model_c)
def loss_fn_for_compile(model, tokens, targets):
logits = model(tokens)
return nb.nn.functional.cross_entropy_loss(logits, targets)
@nb.compile
def compiled_step(model, opt_state, tokens, targets):
loss, grads = nb.value_and_grad(loss_fn_for_compile, argnums=0)(
model, tokens, targets
)
model, opt_state = nb.nn.optim.adamw_update(model, grads, opt_state, lr=1e-3)
return model, opt_state, loss
def eager_step(model, opt_state, tokens, targets):
loss, grads = nb.value_and_grad(loss_fn_for_compile, argnums=0)(
model, tokens, targets
)
model, opt_state = nb.nn.optim.adamw_update(model, grads, opt_state, lr=1e-3)
return model, opt_state, loss
Run the compiled timing loop.
For fair comparison, this uses the same number of steps as eager training (num_epochs = 60).
[8]:
# Use the same step count as the eager loop for comparability
n_timed_steps = num_epochs
print(f"\nCompiled training (functional API inside @nb.compile):")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)
# 1) First compiled call includes trace+compile overhead
compile_start = time.perf_counter()
model_c, opt_state_c, loss_c = compiled_step(model_c, opt_state_c, token_ids, labels)
first_compiled_ms = (time.perf_counter() - compile_start) * 1000.0
# 2) Cached compiled execution timing (same loop length as eager)
cached_start = time.perf_counter()
for step in range(1, n_timed_steps):
model_c, opt_state_c, loss_c = compiled_step(model_c, opt_state_c, token_ids, labels)
if (step + 1) % 10 == 0:
print(f"{step + 1:<8} {loss_c.item():<12.4f}")
cached_elapsed = time.perf_counter() - cached_start
cached_step_ms = (cached_elapsed / max(1, n_timed_steps - 1)) * 1000.0
print("\nCompiled cache stats:", compiled_step.stats)
print(f"First compiled call (trace+compile): {first_compiled_ms:.2f} ms")
print(f"Cached compiled step avg: {cached_step_ms:.2f} ms")
# 3) Eager functional baseline timing (same math, no @nb.compile)
model_e = TransformerClassifier(
vocab_size=vocab_size, d_model=d_model, num_heads=num_heads,
num_layers=num_layers, num_classes=num_classes,
max_len=seq_len, dim_feedforward=64,
)
model_e.eval()
opt_state_e = nb.nn.optim.adamw_init(model_e)
# one warmup
model_e, opt_state_e, _ = eager_step(model_e, opt_state_e, token_ids, labels)
eager_start = time.perf_counter()
for _ in range(n_timed_steps - 1):
model_e, opt_state_e, loss_e = eager_step(model_e, opt_state_e, token_ids, labels)
eager_elapsed = time.perf_counter() - eager_start
eager_step_ms = (eager_elapsed / max(1, n_timed_steps - 1)) * 1000.0
speedup = eager_step_ms / max(cached_step_ms, 1e-9)
print(f"Eager functional step avg: {eager_step_ms:.2f} ms")
print(f"Compiled cached speedup vs eager: {speedup:.2f}x")
Compiled training (functional API inside @nb.compile):
Step Loss
----------------------
10 1.1420
20 1.0677
30 1.0248
40 0.9630
50 0.8735
60 0.7395
Compiled cache stats: CompilationStats(hits=59, misses=1, fallbacks=0, hit_rate=98.3%)
First compiled call (trace+compile): 1047.90 ms
Cached compiled step avg: 10.83 ms
Eager functional step avg: 157.98 ms
Compiled cached speedup vs eager: 14.59x
Summary#
PyTorch-Style (Eager) — This Notebook#
| Component | API |
|---|---|
| Token embedding | `nb.nn.Embedding` |
| Transformer layer | `nb.nn.TransformerEncoderLayer` |
| Fixed buffer | `nb.Tensor.from_dlpack(pe_np)` (not a parameter) |
| Training mode | `model.train()` / `model.eval()` |
| Clear gradients | `model.zero_grad()` |
| Compute gradients | `loss.backward()` |
| Parameter update | `model = optimizer.step()` |
JAX-Style (Functional) — See 6b#
| Concept | API |
|---|---|
| Model state | Nested dict pytree |
| Compute loss + grads | `nb.value_and_grad(loss_fn, argnums=0)` |
| Optimizer | `nb.nn.optim.adamw_init` / `adamw_update` |
| Compiled training | `@nb.compile` |