Example 5b: Transformer Training (JAX-Style / Functional)#

This example builds the same sequence-classification Transformer as 05a, but without nn.Module: the whole model is written as pure functions operating on parameter dicts (pytrees).

This style is closer to JAX/Flax and shows Nabla’s functional flexibility.
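
If the pattern is new, here it is in miniature, sketched in plain numpy so nothing Nabla-specific is assumed: parameters live in nested dicts, and every update is a function that takes the old tree (plus a matching gradient tree) and returns a new one.

import numpy as np

# Toy parameter tree and a matching gradient tree (same nesting, same shapes).
params = {"linear": {"weight": np.ones((2, 3)), "bias": np.zeros(3)}}
grads = {"linear": {"weight": np.full((2, 3), 0.1), "bias": np.full(3, 0.1)}}

def sgd_update(p, g, lr=0.01):
    """Walk both trees in lockstep and return a freshly built, updated tree."""
    if isinstance(p, dict):
        return {k: sgd_update(p[k], g[k], lr) for k in p}
    if isinstance(p, list):
        return [sgd_update(pi, gi, lr) for pi, gi in zip(p, g)]
    return p - lr * g  # leaf: plain array arithmetic, no in-place mutation

params = sgd_update(params, grads)
print(params["linear"]["weight"][0, 0])  # 1.0 - 0.01 * 0.1 = 0.999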

[1]:
import numpy as np

import nabla as nb
import nabla.nn.functional as F

print("Nabla Transformer Training — JAX-style (functional)")
Nabla Transformer Training — JAX-style (functional)

1. Parameter Initialization#

In the functional style, each layer is a dict of tensors (a pytree). We start with simple building blocks and compose them into a full model.

Primitive Layers#

[2]:
def init_linear(in_dim: int, out_dim: int) -> dict:
    """Initialize a linear layer: {weight, bias}."""
    return {
        "weight": F.xavier_normal((in_dim, out_dim)),
        "bias": nb.zeros((1, out_dim)),
    }


def init_layer_norm(dim: int) -> dict:
    """Initialize layer norm: {weight, bias}."""
    return {"weight": nb.ones((dim,)), "bias": nb.zeros((dim,))}

Composite Layers#

Multi-head attention and encoder layers are just nested dicts of the primitives above:

[3]:
def init_mha(d_model: int) -> dict:
    """Initialize multi-head attention: Q, K, V, and output projections."""
    return {
        "q_proj": init_linear(d_model, d_model),
        "k_proj": init_linear(d_model, d_model),
        "v_proj": init_linear(d_model, d_model),
        "out_proj": init_linear(d_model, d_model),
    }


def init_encoder_layer(d_model: int, dim_ff: int) -> dict:
    """Initialize one Transformer encoder layer."""
    return {
        "attn": init_mha(d_model),
        "norm1": init_layer_norm(d_model),
        "norm2": init_layer_norm(d_model),
        "ff1": init_linear(d_model, dim_ff),
        "ff2": init_linear(dim_ff, d_model),
    }

Full Model Initialization#

The top-level init_transformer returns the complete parameter tree. Note that num_heads is passed as a function argument to transformer_forward (below) rather than stored in the params dict, so the pytree contains only tensors:

[4]:
def init_transformer(vocab_size, d_model, num_layers, num_heads,
                     num_classes, max_len, dim_ff):
    """Initialize all transformer parameters as a nested dict."""
    # Embedding
    emb_weight = F.xavier_normal((vocab_size, d_model))

    # Sinusoidal positional encoding (fixed, not learned)
    pe = np.zeros((max_len, d_model), dtype=np.float32)
    pos = np.arange(0, max_len, dtype=np.float32)[:, np.newaxis]
    div = np.exp(np.arange(0, d_model, 2, dtype=np.float32) * -(np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(pos * div)
    pe[:, 1::2] = np.cos(pos * div)

    return {
        "embedding": emb_weight,
        "pe": nb.Tensor.from_dlpack(pe),  # non-differentiable constant
        "layers": [init_encoder_layer(d_model, dim_ff) for _ in range(num_layers)],
        "classifier": init_linear(d_model, num_classes),
    }
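
To get a feel for the resulting pytree, you can flatten it with Nabla's tree_leaves utility (the same helper used for parameter counting in section 4). An illustrative sketch with deliberately tiny dimensions:

from nabla import tree_leaves

tiny = init_transformer(vocab_size=10, d_model=8, num_layers=1,
                        num_heads=2, num_classes=3, max_len=4, dim_ff=16)
print(list(tiny.keys()))                   # top level: embedding, pe, layers, classifier
print(len(tree_leaves(tiny)))              # how many tensors the tree contains
print(tiny["classifier"]["weight"].shape)  # every leaf is an ordinary tensor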

2. Pure Function Layers#

Each layer is a pure function: output = layer(params, input). No hidden state, no mutation — just inputs in, outputs out.

Basic Building Blocks#

[5]:
def linear(params: dict, x):
    """Functional linear layer: y = xW + b."""
    return x @ params["weight"] + params["bias"]


def layer_norm(params: dict, x, eps: float = 1e-5):
    """Functional layer normalization."""
    return F.layer_norm(x, weight=params["weight"], bias=params["bias"], eps=eps)
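
F.layer_norm is used as a black box here. For reference, standard layer normalization (which it is assumed to implement) normalizes each feature vector over its last axis, then applies a learned scale and shift; a numpy sketch:

import numpy as np

def layer_norm_ref(x, weight, bias, eps=1e-5):
    """Normalize over the last axis, then scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias

x = np.random.randn(2, 4, 8).astype(np.float32)
y = layer_norm_ref(x, np.ones(8, np.float32), np.zeros(8, np.float32))
print(y.mean(axis=-1))  # approximately zero for every position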

Multi-Head Self-Attention#

The attention function projects input to Q, K, V, splits into heads, computes scaled dot-product attention, and concatenates the results:

[6]:
def multi_head_attention(params: dict, x, num_heads: int):
    """Functional multi-head self-attention."""
    batch_size, seq_len, d_model = x.shape[0], x.shape[1], x.shape[2]
    head_dim = d_model // num_heads

    # Project to Q, K, V  →  (batch, seq, d_model)
    q = linear(params["q_proj"], x)
    k = linear(params["k_proj"], x)
    v = linear(params["v_proj"], x)

    # Reshape to (batch, heads, seq, head_dim)
    def reshape_heads(t):
        t = nb.reshape(t, (batch_size, seq_len, num_heads, head_dim))
        return nb.permute(t, (0, 2, 1, 3))

    q, k, v = reshape_heads(q), reshape_heads(k), reshape_heads(v)

    # Scaled dot-product attention
    attn_out = F.scaled_dot_product_attention(q, k, v, training=False)

    # Concatenate heads → (batch, seq, d_model)
    attn_out = nb.permute(attn_out, (0, 2, 1, 3))
    attn_out = nb.reshape(attn_out, (batch_size, seq_len, d_model))

    return linear(params["out_proj"], attn_out)
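
Likewise, F.scaled_dot_product_attention is treated as a primitive above. The computation it is assumed to perform is softmax(q @ k^T / sqrt(head_dim)) @ v, applied independently per batch element and head; a numpy reference sketch:

import numpy as np

def sdpa_ref(q, k, v):
    """Reference attention: softmax(q @ k^T / sqrt(head_dim)) @ v."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])  # (..., seq, seq)
    scores -= scores.max(axis=-1, keepdims=True)                # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                          # (..., seq, head_dim)

q = k = v = np.random.randn(2, 4, 8, 16).astype(np.float32)     # (batch, heads, seq, head_dim)
print(sdpa_ref(q, k, v).shape)                                  # (2, 4, 8, 16)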

Encoder Layer and Full Forward Pass#

An encoder layer combines attention + feed-forward with residual connections and layer normalization (pre-norm variant). The full forward pass chains embedding → positional encoding → encoder stack → mean pooling → classifier:

[7]:
def encoder_layer(params: dict, x, num_heads: int):
    """Functional Transformer encoder layer (pre-norm)."""
    # Self-attention + residual
    normed = layer_norm(params["norm1"], x)
    x = x + multi_head_attention(params["attn"], normed, num_heads)

    # Feed-forward + residual
    normed = layer_norm(params["norm2"], x)
    x = x + linear(params["ff2"], nb.gelu(linear(params["ff1"], normed)))
    return x


def transformer_forward(params: dict, token_ids, num_heads: int):
    """Full transformer forward pass: tokens → logits."""
    # Embed + positional encoding
    x = F.embedding(token_ids, params["embedding"])
    seq_len = token_ids.shape[-1]
    d_model = int(x.shape[-1])
    pe = nb.slice_tensor(params["pe"], start=(0, 0), size=(seq_len, d_model))
    x = x + pe

    # Encoder layers
    for layer_params in params["layers"]:
        x = encoder_layer(layer_params, x, num_heads)

    # Mean pooling + classify
    return linear(params["classifier"], nb.mean(x, axis=-2))
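
Before training, a quick shape check confirms the whole pipeline: random tokens in, one row of logits per sequence out. A hypothetical snippet with arbitrarily chosen small dimensions:

demo_params = init_transformer(vocab_size=20, d_model=16, num_layers=1,
                               num_heads=2, num_classes=3, max_len=8, dim_ff=32)
demo_tokens = nb.Tensor.from_dlpack(np.random.randint(0, 20, (4, 8)).astype(np.int64))
demo_logits = transformer_forward(demo_params, demo_tokens, num_heads=2)
print(demo_logits.shape)  # expected: (4, 3)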

3. Create Data#

[8]:
np.random.seed(42)

vocab_size = 20
seq_len = 8
num_classes = 3
n_samples = 150
d_model = 32
num_heads = 4
num_layers = 2
dim_ff = 64

# Random token sequences, labels = (sum of tokens) mod num_classes
token_ids_np = np.random.randint(0, vocab_size, (n_samples, seq_len)).astype(np.int64)
labels_np = (token_ids_np.sum(axis=1) % num_classes).astype(np.int64)
labels_onehot_np = np.zeros((n_samples, num_classes), dtype=np.float32)
labels_onehot_np[np.arange(n_samples), labels_np] = 1.0

token_ids = nb.Tensor.from_dlpack(token_ids_np)
labels = nb.Tensor.from_dlpack(labels_onehot_np)

print(f"Dataset: {n_samples} sequences of length {seq_len}")
print(f"Vocab: {vocab_size}, Classes: {num_classes}")
Dataset: 150 sequences of length 8
Vocab: 20, Classes: 3

4. Initialize Model and Optimizer#

[9]:
params = init_transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    num_classes=num_classes,
    max_len=seq_len,
    dim_ff=dim_ff,
)

opt_state = nb.nn.optim.adamw_init(params)

# Count parameters (note: the total also includes the fixed positional encoding)
from nabla import tree_leaves, Tensor
n_params = sum(p.numel() for p in tree_leaves(params) if isinstance(p, Tensor))
print(f"Model: {num_layers} layers, d_model={d_model}, heads={num_heads}")
print(f"Total parameters: {n_params}")
Model: 2 layers, d_model=32, heads=4
Total parameters: 18083
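
A per-component breakdown uses the same counting trick (illustrative snippet; the embedding and pe entries are single tensors, so they are wrapped by hand rather than passed to tree_leaves):

for name, sub in params.items():
    leaves = tree_leaves(sub) if isinstance(sub, (dict, list)) else [sub]
    n = sum(p.numel() for p in leaves if isinstance(p, Tensor))
    print(f"{name:<12} {n:>6}")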

5. Training Loop#

[10]:
def loss_fn(params, tokens, targets):
    logits = transformer_forward(params, tokens, num_heads=num_heads)
    return nb.nn.functional.cross_entropy_loss(logits, targets)
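
Because loss_fn is pure, a single call to value_and_grad already gives the full picture: the returned grads is a pytree with the same nesting as params, which is exactly what adamw_update expects. A quick, optional sanity check:

loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, token_ids, labels)
print(f"initial loss: {loss.item():.4f}")
print(grads["classifier"]["weight"].shape)  # mirrors params["classifier"]["weight"]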
[11]:
lr = 1e-3
num_epochs = 60

print(f"{'Epoch':<8} {'Loss':<12} {'Accuracy':<10}")
print("-" * 32)

for epoch in range(num_epochs):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, token_ids, labels)
    params, opt_state = nb.nn.optim.adamw_update(params, grads, opt_state, lr=lr)

    if (epoch + 1) % 10 == 0:
        logits = transformer_forward(params, token_ids, num_heads=num_heads)
        pred_classes = nb.argmax(logits, axis=-1)
        target_classes = nb.Tensor.from_dlpack(labels_np.astype(np.int64))
        correct = nb.equal(pred_classes, target_classes)
        accuracy = nb.mean(nb.cast(correct, nb.DType.float32)).item()
        print(f"{epoch + 1:<8} {loss.item():<12.4f} {accuracy:<10.2%}")
Epoch    Loss         Accuracy
--------------------------------
10       1.0926       44.67%
20       1.0410       48.67%
30       1.0012       49.33%
40       0.9542       53.33%
50       0.8927       62.00%
60       0.8154       64.67%
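
Once trained, inference is just one more call to the pure forward function. A hypothetical example on a fresh random sequence (assuming .item() also works on a one-element tensor, as it does on scalars above):

new_tokens_np = np.random.randint(0, vocab_size, (1, seq_len)).astype(np.int64)
new_logits = transformer_forward(params, nb.Tensor.from_dlpack(new_tokens_np),
                                 num_heads=num_heads)
pred = nb.argmax(new_logits, axis=-1)  # shape (1,)
print("predicted:", pred.item(), "| label:", int(new_tokens_np.sum() % num_classes))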

6. Compiled Training (Bonus)#

[12]:
@nb.compile
def compiled_step(params, opt_state, tokens, targets):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, tokens, targets)
    params, opt_state = nb.nn.optim.adamw_update(
        params, grads, opt_state, lr=1e-3
    )
    return params, opt_state, loss
[13]:
# Fresh parameters for compiled training
params2 = init_transformer(
    vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
    num_heads=num_heads, num_classes=num_classes, max_len=seq_len, dim_ff=dim_ff,
)
opt_state2 = nb.nn.optim.adamw_init(params2)

print(f"Compiled training:")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)

for step in range(30):
    params2, opt_state2, loss = compiled_step(params2, opt_state2, token_ids, labels)
    if (step + 1) % 10 == 0:
        print(f"{step + 1:<8} {loss.item():<12.4f}")
Compiled training:
Step     Loss
----------------------
10       1.1183
20       1.0890
30       1.0441
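
To see what compilation actually buys, a rough wall-clock comparison against an uncompiled step is straightforward (hypothetical snippet; absolute numbers depend on hardware and on when compilation happens, so treat them as ballpark figures):

import time

def eager_step(params, opt_state, tokens, targets):
    """Same body as compiled_step, just without @nb.compile."""
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, tokens, targets)
    params, opt_state = nb.nn.optim.adamw_update(params, grads, opt_state, lr=1e-3)
    return params, opt_state, loss

def avg_ms(step_fn, n=10):
    """Average wall-clock milliseconds per call over n calls (very rough)."""
    start = time.perf_counter()
    for _ in range(n):
        step_fn(params2, opt_state2, token_ids, labels)
    return (time.perf_counter() - start) / n * 1e3

print(f"eager:    {avg_ms(eager_step):.1f} ms/step")
print(f"compiled: {avg_ms(compiled_step):.1f} ms/step")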

Summary#

The functional style decomposes the Transformer into pure functions:

Function                                        Role
----------------------------------------------  -----------------------------------
init_transformer(...)                           Creates the parameter pytree
transformer_forward(params, tokens, num_heads)  Pure forward pass
loss_fn(params, tokens, targets)                Computes scalar loss
value_and_grad(loss_fn, argnums=0)              Returns (loss, gradient pytree)
adamw_update(params, grads, ...)                Returns (new_params, new_opt_state)

No mutation, no hidden state — everything flows through function arguments.