Example 5b: Transformer Training (JAX-Style / Functional)#
This example builds the same sequence classification Transformer as 05a, but without nn.Module — everything is pure functions operating on parameter dicts (pytrees).
This style is closer to JAX/Flax and shows Nabla’s functional flexibility.
[1]:
import numpy as np
import nabla as nb
import nabla.nn.functional as F
print("Nabla Transformer Training — JAX-style (functional)")
Nabla Transformer Training — JAX-style (functional)
1. Parameter Initialization#
In the functional style, each layer is a dict of tensors (a pytree). We start with simple building blocks and compose them into a full model.
Primitive Layers#
[2]:
def init_linear(in_dim: int, out_dim: int) -> dict:
    """Initialize a linear layer: {weight, bias}."""
    return {
        "weight": F.xavier_normal((in_dim, out_dim)),
        "bias": nb.zeros((1, out_dim)),
    }

def init_layer_norm(dim: int) -> dict:
    """Initialize layer norm: {weight, bias}."""
    return {"weight": nb.ones((dim,)), "bias": nb.zeros((dim,))}
Composite Layers#
Multi-head attention and encoder layers are just nested dicts of the primitives above:
[3]:
def init_mha(d_model: int) -> dict:
    """Initialize multi-head attention: Q, K, V, and output projections."""
    return {
        "q_proj": init_linear(d_model, d_model),
        "k_proj": init_linear(d_model, d_model),
        "v_proj": init_linear(d_model, d_model),
        "out_proj": init_linear(d_model, d_model),
    }

def init_encoder_layer(d_model: int, dim_ff: int) -> dict:
    """Initialize one Transformer encoder layer."""
    return {
        "attn": init_mha(d_model),
        "norm1": init_layer_norm(d_model),
        "norm2": init_layer_norm(d_model),
        "ff1": init_linear(d_model, dim_ff),
        "ff2": init_linear(dim_ff, d_model),
    }
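Because everything is a nested dict, generic pytree utilities apply directly. As a small sketch (using `tree_leaves` the same way the parameter count below does), one encoder layer with `d_model=32` and `dim_ff=64` flattens to 16 tensors holding 8,544 parameters:

```python
# Flatten one encoder layer and count its tensors / scalar parameters.
from nabla import tree_leaves, Tensor

layer = init_encoder_layer(32, 64)
leaves = [p for p in tree_leaves(layer) if isinstance(p, Tensor)]
print(len(leaves), sum(p.numel() for p in leaves))
# expected: 16 tensors, 8544 parameters
```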
Full Model Initialization#
The top-level init_transformer returns the complete parameter tree. Notice how num_heads is passed as a function argument to transformer_forward (below), not stored in the params dict — keeping the pytree purely numeric:
[4]:
def init_transformer(vocab_size, d_model, num_layers, num_heads,
                     num_classes, max_len, dim_ff):
    """Initialize all transformer parameters as a nested dict."""
    # Embedding
    emb_weight = F.xavier_normal((vocab_size, d_model))
    # Sinusoidal positional encoding (fixed, not learned)
    pe = np.zeros((max_len, d_model), dtype=np.float32)
    pos = np.arange(0, max_len, dtype=np.float32)[:, np.newaxis]
    div = np.exp(np.arange(0, d_model, 2, dtype=np.float32) * -(np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(pos * div)
    pe[:, 1::2] = np.cos(pos * div)
    return {
        "embedding": emb_weight,
        "pe": nb.Tensor.from_dlpack(pe),  # non-differentiable constant
        "layers": [init_encoder_layer(d_model, dim_ff) for _ in range(num_layers)],
        "classifier": init_linear(d_model, num_classes),
    }
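To illustrate the point about keeping the pytree purely numeric, here is a small sketch (with made-up toy hyperparameters): the returned tree holds only tensors, while `num_heads` lives outside it as a plain function argument:

```python
# Hyperparameters such as num_heads are not stored in the tree; its leaves are all tensors.
from nabla import tree_leaves, Tensor

demo = init_transformer(vocab_size=10, d_model=8, num_layers=1, num_heads=2,
                        num_classes=3, max_len=4, dim_ff=16)
print(list(demo.keys()))  # ['embedding', 'pe', 'layers', 'classifier']
print(all(isinstance(leaf, Tensor) for leaf in tree_leaves(demo)))  # expected: True
```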
2. Pure Function Layers#
Each layer is a pure function: output = layer(params, input). No hidden state, no mutation — just inputs in, outputs out.
Basic Building Blocks#
[5]:
def linear(params: dict, x):
    """Functional linear layer: y = xW + b."""
    return x @ params["weight"] + params["bias"]

def layer_norm(params: dict, x, eps: float = 1e-5):
    """Functional layer normalization."""
    return F.layer_norm(x, weight=params["weight"], bias=params["bias"], eps=eps)
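Using a layer means applying the function to `(params, input)`; a minimal sketch with random data:

```python
# Apply the functional linear layer to a random batch; no object state is involved.
lin = init_linear(4, 2)
x = nb.Tensor.from_dlpack(np.random.randn(3, 4).astype(np.float32))
print(linear(lin, x).shape)  # expected: (3, 2)
```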
Multi-Head Self-Attention#
The attention function projects input to Q, K, V, splits into heads, computes scaled dot-product attention, and concatenates the results:
[6]:
def multi_head_attention(params: dict, x, num_heads: int):
    """Functional multi-head self-attention."""
    batch_size, seq_len, d_model = x.shape[0], x.shape[1], x.shape[2]
    head_dim = d_model // num_heads
    # Project to Q, K, V → (batch, seq, d_model)
    q = linear(params["q_proj"], x)
    k = linear(params["k_proj"], x)
    v = linear(params["v_proj"], x)
    # Reshape to (batch, heads, seq, head_dim)
    def reshape_heads(t):
        t = nb.reshape(t, (batch_size, seq_len, num_heads, head_dim))
        return nb.permute(t, (0, 2, 1, 3))
    q, k, v = reshape_heads(q), reshape_heads(k), reshape_heads(v)
    # Scaled dot-product attention
    attn_out = F.scaled_dot_product_attention(q, k, v, training=False)
    # Concatenate heads → (batch, seq, d_model)
    attn_out = nb.permute(attn_out, (0, 2, 1, 3))
    attn_out = nb.reshape(attn_out, (batch_size, seq_len, d_model))
    return linear(params["out_proj"], attn_out)
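A quick shape check (a sketch with random inputs, reusing `init_mha` from above) confirms that attention preserves the `(batch, seq, d_model)` layout:

```python
# Self-attention maps (batch, seq, d_model) -> (batch, seq, d_model).
mha_params = init_mha(32)
x = nb.Tensor.from_dlpack(np.random.randn(2, 8, 32).astype(np.float32))
out = multi_head_attention(mha_params, x, num_heads=4)
print(out.shape)  # expected: (2, 8, 32)
```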
Encoder Layer and Full Forward Pass#
An encoder layer combines attention + feed-forward with residual connections and layer normalization (pre-norm variant). The full forward pass chains embedding → positional encoding → encoder stack → mean pooling → classifier:
[7]:
def encoder_layer(params: dict, x, num_heads: int):
    """Functional Transformer encoder layer (pre-norm)."""
    # Self-attention + residual
    normed = layer_norm(params["norm1"], x)
    x = x + multi_head_attention(params["attn"], normed, num_heads)
    # Feed-forward + residual
    normed = layer_norm(params["norm2"], x)
    x = x + linear(params["ff2"], nb.gelu(linear(params["ff1"], normed)))
    return x

def transformer_forward(params: dict, token_ids, num_heads: int):
    """Full transformer forward pass: tokens → logits."""
    # Embed + positional encoding
    x = F.embedding(token_ids, params["embedding"])
    seq_len = token_ids.shape[-1]
    d_model = int(x.shape[-1])
    pe = nb.slice_tensor(params["pe"], start=(0, 0), size=(seq_len, d_model))
    x = x + pe
    # Encoder layers
    for layer_params in params["layers"]:
        x = encoder_layer(layer_params, x, num_heads)
    # Mean pooling + classify
    return linear(params["classifier"], nb.mean(x, axis=-2))
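As a sketch with toy hyperparameters, the end-to-end function maps integer token ids of shape `(batch, seq)` to logits of shape `(batch, num_classes)`:

```python
# End-to-end shape check: token ids (batch, seq) -> logits (batch, num_classes).
demo_params = init_transformer(vocab_size=10, d_model=16, num_layers=1, num_heads=2,
                               num_classes=3, max_len=6, dim_ff=32)
demo_tokens = nb.Tensor.from_dlpack(np.random.randint(0, 10, (4, 6)).astype(np.int64))
print(transformer_forward(demo_params, demo_tokens, num_heads=2).shape)  # expected: (4, 3)
```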
3. Create Data#
[8]:
np.random.seed(42)
vocab_size = 20
seq_len = 8
num_classes = 3
n_samples = 150
d_model = 32
num_heads = 4
num_layers = 2
dim_ff = 64
# Random token sequences, labels = (sum of tokens) mod num_classes
token_ids_np = np.random.randint(0, vocab_size, (n_samples, seq_len)).astype(np.int64)
labels_np = (token_ids_np.sum(axis=1) % num_classes).astype(np.int64)
labels_onehot_np = np.zeros((n_samples, num_classes), dtype=np.float32)
labels_onehot_np[np.arange(n_samples), labels_np] = 1.0
token_ids = nb.Tensor.from_dlpack(token_ids_np)
labels = nb.Tensor.from_dlpack(labels_onehot_np)
print(f"Dataset: {n_samples} sequences of length {seq_len}")
print(f"Vocab: {vocab_size}, Classes: {num_classes}")
Dataset: 150 sequences of length 8
Vocab: 20, Classes: 3
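A one-line check of the labeling rule on the first sequence (pure NumPy, no Nabla involved):

```python
# Each label is the token sum modulo num_classes.
assert labels_np[0] == token_ids_np[0].sum() % num_classes
print(token_ids_np[0].sum(), "->", labels_np[0])
```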
4. Initialize Model and Optimizer#
[9]:
params = init_transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    num_classes=num_classes,
    max_len=seq_len,
    dim_ff=dim_ff,
)
opt_state = nb.nn.optim.adamw_init(params)
# Count parameters
from nabla import tree_leaves, Tensor
n_params = sum(p.numel() for p in tree_leaves(params) if isinstance(p, Tensor))
print(f"Model: {num_layers} layers, d_model={d_model}, heads={num_heads}")
print(f"Total trainable parameters: {n_params}")
Model: 2 layers, d_model=32, heads=4
Total parameters (incl. fixed positional encoding): 18083
5. Training Loop#
[10]:
def loss_fn(params, tokens, targets):
    logits = transformer_forward(params, tokens, num_heads=num_heads)
    return nb.nn.functional.cross_entropy_loss(logits, targets)
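Before training starts, the loss should sit near `ln(num_classes)`: with Xavier-initialized weights and zero biases the logits are small, so the softmax is roughly uniform. A quick check on the freshly initialized `params` (run this before the training loop below):

```python
# With a near-uniform softmax over 3 classes, cross-entropy should be close to ln(3) ≈ 1.0986.
init_loss = loss_fn(params, token_ids, labels)
print(f"initial loss: {init_loss.item():.4f} (ln 3 = {np.log(3.0):.4f})")
```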
[11]:
lr = 1e-3
num_epochs = 60
print(f"{'Epoch':<8} {'Loss':<12} {'Accuracy':<10}")
print("-" * 32)
for epoch in range(num_epochs):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, token_ids, labels)
    params, opt_state = nb.nn.optim.adamw_update(params, grads, opt_state, lr=lr)
    if (epoch + 1) % 10 == 0:
        logits = transformer_forward(params, token_ids, num_heads=num_heads)
        pred_classes = nb.argmax(logits, axis=-1)
        target_classes = nb.Tensor.from_dlpack(labels_np.astype(np.int64))
        correct = nb.equal(pred_classes, target_classes)
        accuracy = nb.mean(nb.cast(correct, nb.DType.float32)).item()
        print(f"{epoch + 1:<8} {loss.item():<12.4f} {accuracy:<10.2%}")
Epoch Loss Accuracy
--------------------------------
10 1.0926 44.67%
20 1.0410 48.67%
30 1.0012 49.33%
40 0.9542 53.33%
50 0.8927 62.00%
60 0.8154 64.67%
6. Compiled Training (Bonus)#
[12]:
@nb.compile
def compiled_step(params, opt_state, tokens, targets):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, tokens, targets)
    params, opt_state = nb.nn.optim.adamw_update(
        params, grads, opt_state, lr=1e-3
    )
    return params, opt_state, loss
[13]:
# Fresh parameters for compiled training
params2 = init_transformer(
    vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
    num_heads=num_heads, num_classes=num_classes, max_len=seq_len, dim_ff=dim_ff,
)
opt_state2 = nb.nn.optim.adamw_init(params2)
print("Compiled training:")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)
for step in range(30):
    params2, opt_state2, loss = compiled_step(params2, opt_state2, token_ids, labels)
    if (step + 1) % 10 == 0:
        print(f"{step + 1:<8} {loss.item():<12.4f}")
Compiled training:
Step Loss
----------------------
10 1.1183
20 1.0890
30 1.0441
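To see what compilation buys, a rough timing sketch is shown below. The helper `eager_step` and the measurement loop are illustrative additions, not part of Nabla's API; numbers depend on hardware, and the first compiled call is excluded because it includes tracing:

```python
# Rough per-step timing: eager value_and_grad + update vs. the compiled step.
import time

def eager_step(params, opt_state, tokens, targets):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, tokens, targets)
    params, opt_state = nb.nn.optim.adamw_update(params, grads, opt_state, lr=1e-3)
    return params, opt_state, loss

def avg_ms(step_fn, params, opt_state, n=10):
    start = time.perf_counter()
    for _ in range(n):
        params, opt_state, loss = step_fn(params, opt_state, token_ids, labels)
        loss.item()  # pull the scalar so the work is done before the clock stops
    return (time.perf_counter() - start) / n * 1e3

compiled_step(params2, opt_state2, token_ids, labels)  # warm-up (tracing/compilation)
print(f"eager:    {avg_ms(eager_step, params2, opt_state2):.1f} ms/step")
print(f"compiled: {avg_ms(compiled_step, params2, opt_state2):.1f} ms/step")
```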
Summary#
The functional style decomposes the Transformer into pure functions:
| Function | Role |
|---|---|
| `init_transformer` | Creates the parameter pytree |
| `transformer_forward` | Pure forward pass |
| `loss_fn` | Computes scalar loss |
| `nb.value_and_grad(loss_fn)` | Returns (loss, gradient pytree) |
| `nb.nn.optim.adamw_update` | Returns (new_params, new_opt_state) |
No mutation, no hidden state — everything flows through function arguments.