Example 10: LoRA & QLoRA Fine-Tuning#

LoRA (Low-Rank Adaptation) trains a small adapter instead of full weights:

\[W_{\text{effective}} = W_{\text{frozen}} + \frac{\alpha}{r} \cdot A \cdot B\]

where \(A \in \mathbb{R}^{d_{\text{in}} \times r}\) and \(B \in \mathbb{R}^{r \times d_{\text{out}}}\) are the trainable low-rank factors, with \(r \ll \min(d_{\text{in}}, d_{\text{out}})\).
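
To make the shapes and the parameter savings concrete, here is a minimal pure-NumPy sketch of the update (independent of Nabla's API), using the same dimensions as the cells below:

import numpy as np

# Dimensions matching this example: in=64, out=32, rank=8, alpha=16
d_in, d_out, r, alpha = 64, 32, 8, 16.0

rng = np.random.default_rng(0)
W = rng.normal(size=(d_in, d_out)).astype(np.float32)     # frozen base weight
A = rng.normal(size=(d_in, r)).astype(np.float32) * 0.01  # trainable down-projection
B = np.zeros((r, d_out), dtype=np.float32)                # trainable up-projection

# Effective weight: the update A @ B has rank at most r
W_eff = W + (alpha / r) * (A @ B)

# Trainable parameters: r * (d_in + d_out) = 768 vs. the full d_in * d_out = 2048
print(A.size + B.size, W.size)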

QLoRA goes further: the frozen weight \(W\) is quantized to NF4 (4 bits per parameter), cutting base-weight memory by roughly 75% relative to 16-bit storage while the adapters stay in full precision.
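
A rough back-of-the-envelope for that memory claim, assuming a block size of 64 with one float32 scale per block (QLoRA additionally compresses the scales via double quantization, so the real overhead is smaller):

# Illustrative 4096 x 4096 weight: 16-bit storage vs. NF4 (4 bits + per-block scales)
n, block = 4096 * 4096, 64
fp16_bytes = n * 2
nf4_bytes = n // 2 + (n // block) * 4
print(fp16_bytes / 2**20, nf4_bytes / 2**20)  # 32.0 MiB vs. 9.0 MiB, roughly a 4x reduction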

In this example we’ll:

  1. Train a LoRA adapter on a synthetic regression task

  2. Save and reload a fine-tuning checkpoint

  3. Quantize weights to NF4 and train QLoRA adapters

[1]:
from __future__ import annotations

import shutil
from pathlib import Path

import numpy as np

import nabla as nb

print("Nabla LoRA/QLoRA Fine-Tuning example")
Nabla LoRA/QLoRA Fine-Tuning example

1. Synthetic Regression Data#

We create a dataset whose true weight is \(W_{\text{base}} + \Delta\), with \(\Delta\) a rank-4 perturbation; the adapter should learn to approximate \(\Delta\):

[2]:
def make_regression_data(n_samples, in_dim, out_dim, seed=123, delta_scale=0.35):
    """Generate synthetic linear regression data with a low-rank perturbation."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_samples, in_dim)).astype(np.float32)
    w_base = rng.normal(size=(in_dim, out_dim)).astype(np.float32) * 0.5
    # Low-rank perturbation (rank 4)
    u = rng.normal(size=(in_dim, 4)).astype(np.float32)
    v = rng.normal(size=(4, out_dim)).astype(np.float32)
    delta = delta_scale * (u @ v)
    y = x @ (w_base + delta)
    return x, y.astype(np.float32), w_base.astype(np.float32)

# Shared hyperparameters
IN_DIM, OUT_DIM = 64, 32
RANK = 8
ALPHA = 16.0
LR = 2e-2
STEPS = 120

x_np, y_np, w_base_np = make_regression_data(512, IN_DIM, OUT_DIM)
x = nb.Tensor.from_dlpack(x_np)
y = nb.Tensor.from_dlpack(y_np)
frozen_weight = nb.Tensor.from_dlpack(w_base_np)

print(f"Data: {x_np.shape[0]} samples, in={IN_DIM}, out={OUT_DIM}")
print(f"LoRA rank={RANK}, alpha={ALPHA}")
Data: 512 samples, in=64, out=32
LoRA rank=8, alpha=16.0

2. LoRA Training#

Initialize a LoRA adapter and train with AdamW. nb.nn.finetune.lora_linear computes \(xW + \frac{\alpha}{r} \cdot x A B\) in a single call:
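
Note that the update can be applied in activation space: \(x\) is projected down to \(r\) dimensions by \(A\) and back up by \(B\), so the full \(d_{\text{in}} \times d_{\text{out}}\) matrix \(A B\) never needs to be materialized. A quick pure-NumPy check of that identity (not Nabla code):

rng = np.random.default_rng(1)
x_demo = rng.normal(size=(4, IN_DIM)).astype(np.float32)
w_demo = rng.normal(size=(IN_DIM, OUT_DIM)).astype(np.float32)
a_demo = rng.normal(size=(IN_DIM, RANK)).astype(np.float32)
b_demo = rng.normal(size=(RANK, OUT_DIM)).astype(np.float32)
scale = ALPHA / RANK

merged = x_demo @ (w_demo + scale * (a_demo @ b_demo))            # materializes the update
factored = x_demo @ w_demo + scale * (x_demo @ a_demo) @ b_demo   # activation-space form
print(np.allclose(merged, factored, atol=1e-4))  # True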

[3]:
lora_params = nb.nn.finetune.init_lora_adapter(frozen_weight, rank=RANK, init_std=0.01)
opt_state = nb.nn.optim.adamw_init(lora_params)

def lora_loss_fn(adapter, batch_x, batch_y):
    """MSE loss with LoRA-adapted linear layer."""
    pred = nb.nn.finetune.lora_linear(batch_x, frozen_weight, adapter, alpha=ALPHA)
    diff = pred - batch_y
    return nb.mean(diff * diff)

def train_step(loss_fn, adapter, optimizer_state, batch_x, batch_y):
    """One training step: forward + backward + AdamW update."""
    loss, grads = nb.value_and_grad(loss_fn, argnums=0, realize=False)(
        adapter, batch_x, batch_y
    )
    new_adapter, new_state = nb.nn.optim.adamw_update(
        adapter, grads, optimizer_state, lr=LR, weight_decay=0.0
    )
    # Batch-realize all tensors
    to_realize = [loss]
    to_realize.extend(t for t in nb.tree_leaves(grads) if isinstance(t, nb.Tensor))
    to_realize.extend(t for t in nb.tree_leaves(new_adapter) if isinstance(t, nb.Tensor))
    to_realize.extend(t for t in nb.tree_leaves(new_state) if isinstance(t, nb.Tensor))
    nb.realize_all(*to_realize)
    return loss, new_adapter, new_state

initial_loss = float(lora_loss_fn(lora_params, x, y).to_numpy())
print(f"Initial loss: {initial_loss:.6f}")

for step in range(STEPS):
    loss, lora_params, opt_state = train_step(lora_loss_fn, lora_params, opt_state, x, y)
    if (step + 1) % 50 == 0:
        print(f"Step {step + 1:>3d}: loss = {float(loss.to_numpy()):.6f}")

final_loss = float(lora_loss_fn(lora_params, x, y).to_numpy())
print(f"Final loss:   {final_loss:.6f}")
assert final_loss < initial_loss, "LoRA training did not reduce loss"
print("✅ LoRA adapter trained successfully")
Initial loss: 32.337105
Step  50: loss = 0.199841
Step 100: loss = 0.002403
Final loss:   0.000900
✅ LoRA adapter trained successfully

3. Save and Load Checkpoint#

Nabla provides checkpoint utilities for LoRA adapters. We save the trained adapter + optimizer state, reload them, and verify predictions match:

[4]:
ckpt_dir = Path(".tmp_lora_ckpt")
if ckpt_dir.exists():
    shutil.rmtree(ckpt_dir)

# Save checkpoint
nb.nn.finetune.save_finetune_checkpoint(
    ckpt_dir,
    lora_params=lora_params,
    optimizer_state=opt_state,
    metadata={"alpha": ALPHA, "rank": RANK},
)

# Reload from checkpoint
lora_template = nb.nn.finetune.init_lora_adapter(frozen_weight, rank=RANK, init_std=0.01)
opt_template = nb.nn.optim.adamw_init(lora_template)
loaded_lora, loaded_opt, meta = nb.nn.finetune.load_finetune_checkpoint(
    ckpt_dir, lora_template=lora_template, optimizer_template=opt_template,
)

# Verify predictions match
original_pred = nb.nn.finetune.lora_linear(x, frozen_weight, lora_params, alpha=ALPHA)
loaded_pred = nb.nn.finetune.lora_linear(x, frozen_weight, loaded_lora, alpha=ALPHA)
max_diff = np.max(np.abs(original_pred.to_numpy() - loaded_pred.to_numpy()))

print(f"Checkpoint max prediction diff: {max_diff:.8f}")
assert max_diff < 1e-5, f"Checkpoint mismatch: {max_diff}"
print("✅ Checkpoint roundtrip verified")

# Cleanup
shutil.rmtree(ckpt_dir, ignore_errors=True)
Checkpoint max prediction diff: 0.00000000
✅ Checkpoint roundtrip verified

4. QLoRA: Quantized Base Weights#

QLoRA quantizes the frozen weight \(W\) to NF4 (4-bit NormalFloat). During the forward pass, \(W\) is dequantized on the fly and combined with the LoRA adapter. This cuts base-weight memory by roughly 75% relative to 16-bit storage while the adapter remains in full precision.

First, let’s check quantization quality:

[5]:
qweight = nb.nn.finetune.quantize_nf4(frozen_weight, block_size=64)
dense_recon = nb.nn.finetune.dequantize_nf4(qweight)
quant_err = float(
    np.linalg.norm(dense_recon.to_numpy() - frozen_weight.to_numpy())
    / (np.linalg.norm(frozen_weight.to_numpy()) + 1e-8)
)
print(f"NF4 relative reconstruction error: {quant_err:.4f}")
print(f"Quantized weight type: {type(qweight).__name__}")
NF4 relative reconstruction error: 0.0912
Quantized weight type: dict
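
Conceptually, block-wise NF4 quantization normalizes each block of weights by its absolute maximum and snaps every normalized value to the nearest of 16 code values placed at the quantiles of a standard normal distribution. Here is a rough pure-NumPy sketch of the idea (not Nabla's implementation), using the rounded NF4 code values from the QLoRA paper:

# Rounded NF4 code values (normal quantiles scaled to [-1, 1]; Dettmers et al., 2023)
NF4_CODES = np.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
], dtype=np.float32)

def nf4_roundtrip(w, block_size=64):
    """Quantize to 4-bit code indices + one absmax scale per block, then dequantize."""
    flat = w.reshape(-1, block_size)                         # assumes size % block_size == 0
    scales = np.abs(flat).max(axis=1, keepdims=True) + 1e-12  # per-block absmax
    idx = np.abs((flat / scales)[..., None] - NF4_CODES).argmin(axis=-1)  # nearest code
    return (NF4_CODES[idx] * scales).reshape(w.shape)         # dequantized weights

recon = nf4_roundtrip(w_base_np)
rel_err = np.linalg.norm(recon - w_base_np) / np.linalg.norm(w_base_np)
print(f"Sketch NF4 relative error: {rel_err:.4f}")

The result should land in the same ballpark as the 0.0912 reported by quantize_nf4 above.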

QLoRA Training#

Training is identical to LoRA except we use qlora_linear instead of lora_linear. The quantized weight is dequantized during the forward pass:

[6]:
# Fresh adapter for QLoRA
qlora_params = nb.nn.finetune.init_lora_adapter(frozen_weight, rank=RANK, init_std=0.01)
qopt_state = nb.nn.optim.adamw_init(qlora_params)

def qlora_loss_fn(adapter, batch_x, batch_y):
    """MSE loss with QLoRA-adapted linear layer."""
    pred = nb.nn.finetune.qlora_linear(
        batch_x, qweight, adapter, alpha=ALPHA, compute_dtype=nb.DType.float32
    )
    diff = pred - batch_y
    return nb.mean(diff * diff)

q_initial_loss = float(qlora_loss_fn(qlora_params, x, y).to_numpy())
print(f"QLoRA initial loss: {q_initial_loss:.6f}")

for step in range(STEPS):
    loss, qlora_params, qopt_state = train_step(qlora_loss_fn, qlora_params, qopt_state, x, y)
    if (step + 1) % 50 == 0:
        print(f"Step {step + 1:>3d}: loss = {float(loss.to_numpy()):.6f}")

q_final_loss = float(qlora_loss_fn(qlora_params, x, y).to_numpy())
print(f"QLoRA final loss:   {q_final_loss:.6f}")
assert q_final_loss < q_initial_loss, "QLoRA training did not reduce loss"
print("✅ QLoRA adapter trained successfully")
QLoRA initial loss: 32.562389
Step  50: loss = 0.301735
Step 100: loss = 0.095010
QLoRA final loss:   0.090136
✅ QLoRA adapter trained successfully

Key takeaways:

  • LoRA trains only \(r \times (d_{\text{in}} + d_{\text{out}})\) parameters instead of the full \(d_{\text{in}} \times d_{\text{out}}\)

  • QLoRA adds NF4 quantization for ~4x memory reduction on frozen weights

  • Both methods converge on our synthetic task; QLoRA's final loss is higher because the rank-8 adapter cannot fully compensate the NF4 quantization error of the base weight

  • Checkpointing supports adapter + optimizer state + custom metadata