Example 10: LoRA Fine-Tuning MVP#

This example shows a minimal parameter-efficient fine-tuning workflow (a short NumPy sketch of the LoRA update follows the list):

  • keep base weights frozen

  • train only LoRA adapters

  • save and reload a finetune checkpoint
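
For context, LoRA keeps the dense weight W untouched and learns a low-rank update A @ B that is scaled and added to the base projection. Below is a minimal NumPy sketch of that forward pass; the factor shapes, the zero-initialized "up" factor, and the alpha / rank scaling convention are illustrative assumptions, while the actual behaviour in this example is whatever nb.nn.finetune.lora_linear implements.

[ ]:
import numpy as np


def lora_forward_sketch(x, w_frozen, a, b, alpha):
    """y = x @ W + (alpha / rank) * x @ A @ B, with W frozen and only A, B trained."""
    rank = a.shape[1]
    return x @ w_frozen + (alpha / rank) * (x @ a) @ b


# Toy shapes mirroring the example: a 64 -> 32 projection with a rank-8 adapter.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 64)).astype(np.float32)
w = rng.normal(size=(64, 32)).astype(np.float32)
a = rng.normal(size=(64, 8)).astype(np.float32) * 0.01  # "down" factor
b = np.zeros((8, 32), dtype=np.float32)  # "up" factor, zero at init
print(lora_forward_sketch(x, w, a, b, alpha=16.0).shape)  # (4, 32)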

[ ]:
from __future__ import annotations

import shutil
from pathlib import Path

import numpy as np

import nabla as nb

1. Synthetic Data Helper#
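
The targets below come from the frozen base weight plus a fixed rank-4 perturbation, so a low-rank adapter is enough to close the gap between the frozen model and the data.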

[ ]:
def make_regression_data(
    n_samples: int, in_dim: int, out_dim: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return inputs, targets, and the frozen base weight for the regression task."""
    rng = np.random.default_rng(123)
    x = rng.normal(size=(n_samples, in_dim)).astype(np.float32)

    # Frozen base weight plus a rank-4 perturbation that the adapter has to learn.
    w_base = rng.normal(size=(in_dim, out_dim)).astype(np.float32) * 0.5
    u = rng.normal(size=(in_dim, 4)).astype(np.float32)
    v = rng.normal(size=(4, out_dim)).astype(np.float32)
    delta = 0.35 * (u @ v)

    # Targets are generated with the shifted weight; only the base weight is returned.
    y = x @ (w_base + delta)
    return x, y.astype(np.float32), w_base.astype(np.float32)
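
As an optional sanity check (not part of the example itself), the gap the adapter has to learn really is low rank: fitting the effective weight from the synthetic data and subtracting the frozen base weight should leave a matrix of rank 4, well within the rank-8 adapter used below. The tolerance passed to matrix_rank is an illustrative choice.

[ ]:
x_np, y_np, w_base_np = make_regression_data(n_samples=512, in_dim=64, out_dim=32)

# Recover the effective weight from the noiseless data and compare with the base weight.
w_eff, *_ = np.linalg.lstsq(x_np, y_np, rcond=None)
print(np.linalg.matrix_rank(w_eff - w_base_np, tol=1e-2))  # expected: 4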

2. Train Adapter and Validate Checkpoint#
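
The cell below initializes a LoRA adapter for the frozen weight, trains it with AdamW, writes a checkpoint, reloads it into fresh templates, and verifies that the reloaded adapter reproduces the trained predictions.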

[ ]:
def main() -> None:
    in_dim, out_dim = 64, 32
    rank = 8
    alpha = 16.0
    learning_rate = 2e-2
    steps = 120

    x_np, y_np, w_base_np = make_regression_data(
        n_samples=512, in_dim=in_dim, out_dim=out_dim
    )

    x = nb.Tensor.from_dlpack(x_np)
    y = nb.Tensor.from_dlpack(y_np)
    frozen_weight = nb.Tensor.from_dlpack(w_base_np)

    # Only the low-rank adapter is trainable; the base weight stays frozen throughout.
    lora_params = nb.nn.finetune.init_lora_adapter(
        frozen_weight, rank=rank, init_std=0.01
    )
    opt_state = nb.nn.optim.adamw_init(lora_params)

    def loss_fn(adapter, batch_x, batch_y):
        # Mean-squared error of the LoRA-augmented linear layer against the targets.
        pred = nb.nn.finetune.lora_linear(batch_x, frozen_weight, adapter, alpha=alpha)
        diff = pred - batch_y
        return nb.mean(diff * diff)

    def train_step(adapter, optimizer_state, batch_x, batch_y):
        # Differentiate only w.r.t. the adapter (argnums=0); the frozen weight is a
        # closure constant and never receives gradients.
        loss, grads = nb.value_and_grad(loss_fn, argnums=0, realize=False)(
            adapter, batch_x, batch_y
        )
        new_adapter, new_state = nb.nn.optim.adamw_update(
            adapter,
            grads,
            optimizer_state,
            lr=learning_rate,
            weight_decay=0.0,
        )
        # Explicitly realize every tensor produced this step (value_and_grad was
        # called with realize=False).
        to_realize = [loss]
        to_realize.extend(t for t in nb.tree_leaves(grads) if isinstance(t, nb.Tensor))
        to_realize.extend(
            t for t in nb.tree_leaves(new_adapter) if isinstance(t, nb.Tensor)
        )
        to_realize.extend(
            t for t in nb.tree_leaves(new_state) if isinstance(t, nb.Tensor)
        )
        nb.realize_all(*to_realize)
        return loss, new_adapter, new_state

    initial_loss = float(loss_fn(lora_params, x, y).to_numpy())
    print(f"Initial loss: {initial_loss:.6f}")

    for step in range(steps):
        loss, lora_params, opt_state = train_step(lora_params, opt_state, x, y)
        if (step + 1) % 50 == 0:
            print(f"Step {step + 1:>3d}: loss={float(loss.to_numpy()):.6f}")

    final_loss = float(loss_fn(lora_params, x, y).to_numpy())
    print(f"Final loss:   {final_loss:.6f}")

    ckpt_dir = Path(".tmp_lora_ckpt")
    if ckpt_dir.exists():
        shutil.rmtree(ckpt_dir)

    nb.nn.finetune.save_finetune_checkpoint(
        ckpt_dir,
        lora_params=lora_params,
        optimizer_state=opt_state,
        metadata={"alpha": alpha, "rank": rank},
    )

    # Reload into freshly initialized templates so the pytree structure and shapes match.
    lora_template = nb.nn.finetune.init_lora_adapter(
        frozen_weight, rank=rank, init_std=0.01
    )
    opt_template = nb.nn.optim.adamw_init(lora_template)

    loaded_lora, loaded_opt, meta = nb.nn.finetune.load_finetune_checkpoint(
        ckpt_dir,
        lora_template=lora_template,
        optimizer_template=opt_template,
    )

    # The reloaded adapter should reproduce the trained predictions up to float noise.
    original_pred = nb.nn.finetune.lora_linear(
        x, frozen_weight, lora_params, alpha=alpha
    )
    loaded_pred = nb.nn.finetune.lora_linear(x, frozen_weight, loaded_lora, alpha=alpha)

    max_diff = np.max(np.abs(original_pred.to_numpy() - loaded_pred.to_numpy()))
    print(f"Checkpoint step: {loaded_opt['step'] if loaded_opt else 'N/A'}")
    print(f"Checkpoint max prediction diff: {max_diff:.8f}")
    print(f"Checkpoint metadata keys: {sorted(meta.get('user_metadata', {}).keys())}")

    if final_loss >= initial_loss:
        raise RuntimeError("LoRA training did not reduce loss.")
    if max_diff > 1e-5:
        raise RuntimeError(f"Checkpoint roundtrip mismatch too large: {max_diff}")

    print("✅ LoRA MVP finished successfully.")


if __name__ == "__main__":
    main()
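
To pick the adapter back up in a later session, the same load path can be reused. The snippet below is a minimal sketch of that resume step, reusing only the calls already shown above; the checkpoint directory and hyperparameters must match the ones used in main(), and the frozen weight is rebuilt from the seeded data helper.

[ ]:
_, _, w_base_np = make_regression_data(n_samples=512, in_dim=64, out_dim=32)
frozen_weight = nb.Tensor.from_dlpack(w_base_np)

# Fresh templates give load_finetune_checkpoint the expected pytree structure.
lora_template = nb.nn.finetune.init_lora_adapter(frozen_weight, rank=8, init_std=0.01)
opt_template = nb.nn.optim.adamw_init(lora_template)

lora_params, opt_state, meta = nb.nn.finetune.load_finetune_checkpoint(
    Path(".tmp_lora_ckpt"),
    lora_template=lora_template,
    optimizer_template=opt_template,
)
# Training can now continue exactly as in train_step() above.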