Example 5b: Transformer Training (JAX-Style / Functional)#
This example builds the same sequence-classification Transformer as 05a, but without nn.Module: the model is expressed entirely as pure functions operating on parameter dicts (pytrees).
This style is closer to JAX/Flax and shows Nabla’s functional flexibility.
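Before diving in, here is the core idea in miniature (a toy sketch using plain Python floats, no Nabla calls): parameters live in a plain dict, and every update returns a new dict instead of mutating the old one.
[ ]:
def sgd_update(params: dict, grads: dict, lr: float) -> dict:
    """Return a *new* parameter dict; the input dict is never mutated."""
    return {k: params[k] - lr * grads[k] for k in params}

toy_params = {"w": 2.0, "b": 0.5}
toy_grads = {"w": 0.1, "b": -0.2}
toy_params = sgd_update(toy_params, toy_grads, lr=0.1)
print(toy_params)  # e.g. {'w': 1.99, 'b': 0.52} (up to float rounding)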
[ ]:
import numpy as np
import nabla as nb
import nabla.nn.functional as F
print("Nabla Transformer Training — JAX-style (functional)")
1. Parameter Initialization#
We initialize parameters as nested dicts. Each tensor that needs gradients has requires_grad = True.
[ ]:
def init_linear(in_dim: int, out_dim: int) -> dict:
"""Initialize a linear layer: {weight, bias}."""
params = {
"weight": F.xavier_normal((in_dim, out_dim)),
"bias": nb.zeros((1, out_dim)),
}
for p in params.values():
p.requires_grad = True
return params
def init_layer_norm(dim: int) -> dict:
"""Initialize layer norm: {weight, bias}."""
params = {
"weight": nb.ones((dim,)),
"bias": nb.zeros((dim,)),
}
for p in params.values():
p.requires_grad = True
return params
def init_mha(d_model: int) -> dict:
"""Initialize multi-head attention projections."""
return {
"q_proj": init_linear(d_model, d_model),
"k_proj": init_linear(d_model, d_model),
"v_proj": init_linear(d_model, d_model),
"out_proj": init_linear(d_model, d_model),
}
def init_encoder_layer(d_model: int, dim_ff: int) -> dict:
"""Initialize one Transformer encoder layer."""
return {
"attn": init_mha(d_model),
"norm1": init_layer_norm(d_model),
"norm2": init_layer_norm(d_model),
"ff1": init_linear(d_model, dim_ff),
"ff2": init_linear(dim_ff, d_model),
}
def init_transformer(
vocab_size: int,
d_model: int,
num_layers: int,
num_heads: int,
num_classes: int,
max_len: int,
dim_ff: int,
) -> dict:
"""Initialize all transformer parameters."""
# Embedding
emb_weight = F.xavier_normal((vocab_size, d_model))
emb_weight.requires_grad = True
# Positional encoding (fixed, not learned)
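    # Sinusoidal encoding: PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    #                      PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))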
pe = np.zeros((max_len, d_model), dtype=np.float32)
pos = np.arange(0, max_len, dtype=np.float32)[:, np.newaxis]
div = np.exp(np.arange(0, d_model, 2, dtype=np.float32) * -(np.log(10000.0) / d_model))
pe[:, 0::2] = np.sin(pos * div)
pe[:, 1::2] = np.cos(pos * div)
params = {
"embedding": emb_weight,
"layers": [init_encoder_layer(d_model, dim_ff) for _ in range(num_layers)],
"classifier": init_linear(d_model, num_classes),
"_num_heads": num_heads,
}
# Store PE as a non-differentiable constant
params["pe"] = nb.Tensor.from_dlpack(pe)
params["pe"].requires_grad = False
return params
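As a quick sanity check (illustrative; it only reuses the initializers defined above), we can peek at the pytree structure one encoder layer produces:
[ ]:
from nabla import tree_leaves, Tensor

demo_layer = init_encoder_layer(d_model=8, dim_ff=16)
print(list(demo_layer.keys()))            # ['attn', 'norm1', 'norm2', 'ff1', 'ff2']
print(demo_layer["ff1"]["weight"].shape)  # first FFN projection: expect (8, 16)
print(len([p for p in tree_leaves(demo_layer) if isinstance(p, Tensor)]))  # 16 weight/bias tensors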
2. Pure Function Layers#
Each layer is a pure function: output = layer(params, input).
[ ]:
def linear(params: dict, x):
"""Functional linear layer."""
return x @ params["weight"] + params["bias"]
def layer_norm(params: dict, x, eps: float = 1e-5):
"""Functional layer normalization."""
return F.layer_norm(x, weight=params["weight"], bias=params["bias"], eps=eps)
def multi_head_attention(params: dict, x, num_heads: int):
"""Functional multi-head self-attention.
Args:
params: Dict with q_proj, k_proj, v_proj, out_proj.
x: Input tensor (batch, seq_len, d_model).
num_heads: Number of attention heads.
"""
batch_size = x.shape[0]
seq_len = x.shape[1]
d_model = x.shape[2]
head_dim = d_model // num_heads
# Project to Q, K, V
q = linear(params["q_proj"], x) # (batch, seq, d_model)
k = linear(params["k_proj"], x)
v = linear(params["v_proj"], x)
# Reshape to (batch, num_heads, seq, head_dim)
q = nb.reshape(q, (batch_size, seq_len, num_heads, head_dim))
q = nb.permute(q, (0, 2, 1, 3))
k = nb.reshape(k, (batch_size, seq_len, num_heads, head_dim))
k = nb.permute(k, (0, 2, 1, 3))
v = nb.reshape(v, (batch_size, seq_len, num_heads, head_dim))
v = nb.permute(v, (0, 2, 1, 3))
# Scaled dot-product attention
attn_out = F.scaled_dot_product_attention(q, k, v, training=False)
# Reshape back: (batch, seq, d_model)
attn_out = nb.permute(attn_out, (0, 2, 1, 3))
attn_out = nb.reshape(attn_out, (batch_size, seq_len, d_model))
# Output projection
return linear(params["out_proj"], attn_out)
def encoder_layer(params: dict, x, num_heads: int):
"""Functional Transformer encoder layer (pre-norm)."""
# Self-attention with residual
normed = layer_norm(params["norm1"], x)
attn_out = multi_head_attention(params["attn"], normed, num_heads)
x = x + attn_out
# FFN with residual
normed = layer_norm(params["norm2"], x)
ff_out = linear(params["ff2"], nb.gelu(linear(params["ff1"], normed)))
x = x + ff_out
return x
def transformer_forward(params: dict, token_ids):
"""Full transformer forward pass.
Args:
params: Nested parameter dict from init_transformer.
token_ids: Integer tensor (batch, seq_len).
Returns:
Logits of shape (batch, num_classes).
"""
num_heads = params["_num_heads"]
# Token embedding + positional encoding
x = F.embedding(token_ids, params["embedding"])
seq_len = token_ids.shape[-1]
d_model = int(x.shape[-1])
pe = nb.slice_tensor(params["pe"], start=(0, 0), size=(seq_len, d_model))
x = x + pe
# Encoder layers
for layer_params in params["layers"]:
x = encoder_layer(layer_params, x, num_heads)
# Mean pooling + classify
x = nb.mean(x, axis=-2)
return linear(params["classifier"], x)
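A quick shape check before training (illustrative; it uses a small throwaway model and random tokens, just to confirm the output dimensions):
[ ]:
probe_params = init_transformer(
    vocab_size=10, d_model=8, num_layers=1,
    num_heads=2, num_classes=3, max_len=6, dim_ff=16,
)
probe_tokens = nb.Tensor.from_dlpack(
    np.random.randint(0, 10, (2, 6)).astype(np.int64)
)
print(transformer_forward(probe_params, probe_tokens).shape)  # expect (2, 3): (batch, num_classes)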
3. Create Data#
[ ]:
np.random.seed(42)
vocab_size = 20
seq_len = 8
num_classes = 3
n_samples = 150
d_model = 32
num_heads = 4
num_layers = 2
dim_ff = 64
# Random token sequences, labels = (sum of tokens) mod num_classes
token_ids_np = np.random.randint(0, vocab_size, (n_samples, seq_len)).astype(np.int64)
labels_np = (token_ids_np.sum(axis=1) % num_classes).astype(np.int64)
labels_onehot_np = np.zeros((n_samples, num_classes), dtype=np.float32)
labels_onehot_np[np.arange(n_samples), labels_np] = 1.0
token_ids = nb.Tensor.from_dlpack(token_ids_np)
labels = nb.Tensor.from_dlpack(labels_onehot_np)
print(f"Dataset: {n_samples} sequences of length {seq_len}")
print(f"Vocab: {vocab_size}, Classes: {num_classes}")
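A quick look at the synthetic task (illustrative): each label is the token sum modulo num_classes, and the classes come out roughly balanced.
[ ]:
print("First sequence:", token_ids_np[0], "-> label", labels_np[0])
print("Class counts:", np.bincount(labels_np, minlength=num_classes))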
4. Initialize Model and Optimizer#
[ ]:
params = init_transformer(
vocab_size=vocab_size,
d_model=d_model,
num_layers=num_layers,
num_heads=num_heads,
num_classes=num_classes,
max_len=seq_len,
dim_ff=dim_ff,
)
opt_state = nb.nn.optim.adamw_init(params)
# Count parameters
from nabla import tree_leaves, Tensor
n_params = sum(p.numel() for p in tree_leaves(params) if isinstance(p, Tensor) and p.requires_grad)
print(f"Model: {num_layers} layers, d_model={d_model}, heads={num_heads}")
print(f"Total trainable parameters: {n_params}")
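For reference, a hand count of the trainable parameters (a sketch based only on the shapes defined above) should match n_params:
[ ]:
def linear_count(n_in, n_out):
    return n_in * n_out + n_out  # weight + bias

per_layer = (
    4 * linear_count(d_model, d_model)  # q/k/v/out projections
    + 2 * (2 * d_model)                 # two LayerNorms (weight + bias each)
    + linear_count(d_model, dim_ff)     # ff1
    + linear_count(dim_ff, d_model)     # ff2
)
expected = vocab_size * d_model + num_layers * per_layer + linear_count(d_model, num_classes)
print(f"Expected parameter count: {expected}")  # 17827 for the hyperparameters above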
5. Training Loop#
[ ]:
def loss_fn(params, tokens, targets):
logits = transformer_forward(params, tokens)
return nb.nn.functional.cross_entropy_loss(logits, targets)
lr = 1e-3
num_epochs = 60
print(f"\n{'Epoch':<8} {'Loss':<12} {'Accuracy':<10}")
print("-" * 32)
for epoch in range(num_epochs):
loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, token_ids, labels)
params, opt_state = nb.nn.optim.adamw_update(
params, grads, opt_state, lr=lr
)
if (epoch + 1) % 10 == 0:
logits = transformer_forward(params, token_ids)
pred_classes = nb.argmax(logits, axis=-1)
target_classes = nb.Tensor.from_dlpack(labels_np.astype(np.int64))
correct = nb.equal(pred_classes, target_classes)
accuracy = nb.mean(nb.cast(correct, nb.DType.float32)).item()
print(f"{epoch + 1:<8} {loss.item():<12.4f} {accuracy:<10.2%}")
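The trained parameter dict can be used directly for inference (illustrative; the true class follows the labeling rule from the data section):
[ ]:
new_np = np.random.randint(0, vocab_size, (1, seq_len)).astype(np.int64)
new_tokens = nb.Tensor.from_dlpack(new_np)
pred = nb.argmax(transformer_forward(params, new_tokens), axis=-1)
print("Predicted class:", pred, "| true class:", new_np.sum() % num_classes)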
6. Compiled Training (Bonus)#
[ ]:
params2 = init_transformer(
vocab_size=vocab_size,
d_model=d_model,
num_layers=num_layers,
num_heads=num_heads,
num_classes=num_classes,
max_len=seq_len,
dim_ff=dim_ff,
)
opt_state2 = nb.nn.optim.adamw_init(params2)
@nb.compile
def compiled_step(params, opt_state, tokens, targets):
loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, tokens, targets)
params, opt_state = nb.nn.optim.adamw_update(
params, grads, opt_state, lr=1e-3
)
return params, opt_state, loss
print("\nCompiled training:")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)
for step in range(30):
params2, opt_state2, loss = compiled_step(
params2, opt_state2, token_ids, labels
)
if (step + 1) % 10 == 0:
print(f"{step + 1:<8} {loss.item():<12.4f}")
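To see what compilation buys you, here is a rough timing comparison against an uncompiled step (a sketch; absolute numbers depend on hardware and backend):
[ ]:
import time

def eager_step(params, opt_state, tokens, targets):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, tokens, targets)
    params, opt_state = nb.nn.optim.adamw_update(params, grads, opt_state, lr=1e-3)
    return params, opt_state, loss

for name, step_fn in [("eager", eager_step), ("compiled", compiled_step)]:
    step_fn(params2, opt_state2, token_ids, labels)  # warm-up (lets the compiled version finish tracing)
    start = time.perf_counter()
    for _ in range(10):
        step_fn(params2, opt_state2, token_ids, labels)
    print(f"{name:<10} {(time.perf_counter() - start) / 10 * 1e3:.2f} ms/step")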
Summary#
The functional style decomposes the Transformer into pure functions:
- init_transformer(...) → parameter pytree
- transformer_forward(params, input) → logits
- loss_fn(params, ...) → scalar loss
- value_and_grad(loss_fn) → (loss, gradient pytree)
- adamw_update(params, grads, ...) → (new_params, new_opt_state)
No mutation, no hidden state — everything flows through function arguments.
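Put together, the whole recipe fits in a handful of lines (a recap of the code above, not a new API):
[ ]:
params = init_transformer(vocab_size, d_model, num_layers, num_heads,
                          num_classes, seq_len, dim_ff)
opt_state = nb.nn.optim.adamw_init(params)
for _ in range(num_epochs):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, token_ids, labels)
    params, opt_state = nb.nn.optim.adamw_update(params, grads, opt_state, lr=lr)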
Congratulations! You’ve completed all the Nabla examples. You’re now equipped to build, train, and optimize ML models with Nabla’s dual API.
[ ]:
print("\n✅ Example 05b completed!")
print("🎉 All examples complete!")