Example 10: LoRA & QLoRA Fine-Tuning#
LoRA (Low-Rank Adaptation) trains a small adapter instead of the full weights:
\[W' = W + \frac{\alpha}{r} \, A B\]
where \(A \in \mathbb{R}^{d \times r}\) and \(B \in \mathbb{R}^{r \times d}\) are the trainable low-rank factors and \(r \ll d\); only \(A\) and \(B\) receive gradients while \(W\) stays frozen.
QLoRA goes further: the frozen weight \(W\) is quantized to NF4 (4 bits per parameter), saving roughly 75% of the base-weight memory relative to 16-bit storage while keeping the adapters in full precision.
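Before turning to Nabla's helpers, here is a minimal NumPy sketch of the adapted forward pass. The names (w_frozen, a, b) are illustrative, not part of Nabla's API, and the zero init for \(B\) is one common LoRA convention:

import numpy as np

d_in, d_out, r, alpha = 64, 32, 8, 16.0
rng = np.random.default_rng(0)

w_frozen = rng.normal(size=(d_in, d_out))   # base weight: stays frozen during fine-tuning
a = 0.01 * rng.normal(size=(d_in, r))       # trainable low-rank factor A
b = np.zeros((r, d_out))                    # trainable low-rank factor B (zero init => adapter starts as a no-op)

x = rng.normal(size=(4, d_in))
# Same as x @ (w_frozen + (alpha / r) * a @ b); QLoRA uses the identical formula,
# except w_frozen would be dequantized from NF4 on the fly.
y_adapted = x @ w_frozen + (alpha / r) * (x @ a) @ b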
In this example we’ll:
Train a LoRA adapter on a synthetic regression task
Save and reload a finetune checkpoint
Quantize weights to NF4 and train QLoRA adapters
[1]:
from __future__ import annotations
import shutil
from pathlib import Path
import numpy as np
import nabla as nb
print("Nabla LoRA/QLoRA Fine-Tuning example")
Nabla LoRA/QLoRA Fine-Tuning example
1. Synthetic Regression Data#
We create a dataset whose true weight is \(W_{\text{base}} + \Delta\), with \(\Delta\) a rank-4 perturbation; the adapter should learn to approximate \(\Delta\):
[2]:
def make_regression_data(n_samples, in_dim, out_dim, seed=123, delta_scale=0.35):
    """Generate synthetic linear regression data with a low-rank perturbation."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_samples, in_dim)).astype(np.float32)
    w_base = rng.normal(size=(in_dim, out_dim)).astype(np.float32) * 0.5
    # Low-rank perturbation (rank 4)
    u = rng.normal(size=(in_dim, 4)).astype(np.float32)
    v = rng.normal(size=(4, out_dim)).astype(np.float32)
    delta = delta_scale * (u @ v)
    y = x @ (w_base + delta)
    return x, y.astype(np.float32), w_base.astype(np.float32)
# Shared hyperparameters
IN_DIM, OUT_DIM = 64, 32
RANK = 8
ALPHA = 16.0
LR = 2e-2
STEPS = 120
x_np, y_np, w_base_np = make_regression_data(512, IN_DIM, OUT_DIM)
x = nb.Tensor.from_dlpack(x_np)
y = nb.Tensor.from_dlpack(y_np)
frozen_weight = nb.Tensor.from_dlpack(w_base_np)
print(f"Data: {x_np.shape[0]} samples, in={IN_DIM}, out={OUT_DIM}")
print(f"LoRA rank={RANK}, alpha={ALPHA}")
Data: 512 samples, in=64, out=32
LoRA rank=8, alpha=16.0
2. LoRA Training#
Initialize a LoRA adapter and train with AdamW. nb.nn.finetune.lora_linear computes \(xW + \frac{\alpha}{r} \cdot x A B\) in a single call:
[3]:
lora_params = nb.nn.finetune.init_lora_adapter(frozen_weight, rank=RANK, init_std=0.01)
opt_state = nb.nn.optim.adamw_init(lora_params)
def lora_loss_fn(adapter, batch_x, batch_y):
    """MSE loss with LoRA-adapted linear layer."""
    pred = nb.nn.finetune.lora_linear(batch_x, frozen_weight, adapter, alpha=ALPHA)
    diff = pred - batch_y
    return nb.mean(diff * diff)

def train_step(loss_fn, adapter, optimizer_state, batch_x, batch_y):
    """One training step: forward + backward + AdamW update."""
    loss, grads = nb.value_and_grad(loss_fn, argnums=0, realize=False)(
        adapter, batch_x, batch_y
    )
    new_adapter, new_state = nb.nn.optim.adamw_update(
        adapter, grads, optimizer_state, lr=LR, weight_decay=0.0
    )
    # Batch-realize all tensors
    to_realize = [loss]
    to_realize.extend(t for t in nb.tree_leaves(grads) if isinstance(t, nb.Tensor))
    to_realize.extend(t for t in nb.tree_leaves(new_adapter) if isinstance(t, nb.Tensor))
    to_realize.extend(t for t in nb.tree_leaves(new_state) if isinstance(t, nb.Tensor))
    nb.realize_all(*to_realize)
    return loss, new_adapter, new_state
initial_loss = float(lora_loss_fn(lora_params, x, y).to_numpy())
print(f"Initial loss: {initial_loss:.6f}")
for step in range(STEPS):
    loss, lora_params, opt_state = train_step(lora_loss_fn, lora_params, opt_state, x, y)
    if (step + 1) % 50 == 0:
        print(f"Step {step + 1:>3d}: loss = {float(loss.to_numpy()):.6f}")
final_loss = float(lora_loss_fn(lora_params, x, y).to_numpy())
print(f"Final loss: {final_loss:.6f}")
assert final_loss < initial_loss, "LoRA training did not reduce loss"
print("✅ LoRA adapter trained successfully")
Initial loss: 32.337105
Step 50: loss = 0.199841
Step 100: loss = 0.002403
Final loss: 0.000900
✅ LoRA adapter trained successfully
3. Save and Load Checkpoint#
Nabla provides checkpoint utilities for LoRA adapters. We save the trained adapter + optimizer state, reload them, and verify predictions match:
[4]:
ckpt_dir = Path(".tmp_lora_ckpt")
if ckpt_dir.exists():
    shutil.rmtree(ckpt_dir)
# Save checkpoint
nb.nn.finetune.save_finetune_checkpoint(
    ckpt_dir,
    lora_params=lora_params,
    optimizer_state=opt_state,
    metadata={"alpha": ALPHA, "rank": RANK},
)
# Reload from checkpoint
lora_template = nb.nn.finetune.init_lora_adapter(frozen_weight, rank=RANK, init_std=0.01)
opt_template = nb.nn.optim.adamw_init(lora_template)
loaded_lora, loaded_opt, meta = nb.nn.finetune.load_finetune_checkpoint(
    ckpt_dir, lora_template=lora_template, optimizer_template=opt_template,
)
# Verify predictions match
original_pred = nb.nn.finetune.lora_linear(x, frozen_weight, lora_params, alpha=ALPHA)
loaded_pred = nb.nn.finetune.lora_linear(x, frozen_weight, loaded_lora, alpha=ALPHA)
max_diff = np.max(np.abs(original_pred.to_numpy() - loaded_pred.to_numpy()))
print(f"Checkpoint max prediction diff: {max_diff:.8f}")
assert max_diff < 1e-5, f"Checkpoint mismatch: {max_diff}"
print("✅ Checkpoint roundtrip verified")
# Cleanup
shutil.rmtree(ckpt_dir, ignore_errors=True)
Checkpoint max prediction diff: 0.00000000
✅ Checkpoint roundtrip verified
4. QLoRA: Quantized Base Weights#
QLoRA quantizes the frozen weight \(W\) to NF4 (4-bit Normal Float). During the forward pass, \(W\) is dequantized on-the-fly and combined with the LoRA adapter. This saves ~75% memory for the base weight while the adapter remains in full precision.
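Conceptually, NF4 is a block-wise 4-bit code: each block of the weight is rescaled by its absolute maximum and every entry is snapped to the nearest of 16 code levels. The sketch below illustrates the mechanics with a uniform 16-level codebook as a stand-in for the true NF4 levels (which are spaced for normally distributed weights); it is not Nabla's implementation:

# Illustrative block-wise 4-bit quantization -- not Nabla's quantize_nf4/dequantize_nf4.
CODEBOOK = np.linspace(-1.0, 1.0, 16)                         # stand-in for the 16 NF4 levels

def quantize_4bit_sketch(w, block_size=64):
    flat = w.reshape(-1, block_size)                          # assumes w.size % block_size == 0
    scales = np.abs(flat).max(axis=1, keepdims=True) + 1e-8   # one absmax scale per block
    codes = np.abs((flat / scales)[..., None] - CODEBOOK).argmin(axis=-1)
    return codes.astype(np.uint8), scales                     # 4-bit codes (packed 2-per-byte in practice) + scales

def dequantize_4bit_sketch(codes, scales, shape):
    return (CODEBOOK[codes] * scales).reshape(shape)          # approximate reconstruction

codes, scales = quantize_4bit_sketch(w_base_np, block_size=64)
w_recon = dequantize_4bit_sketch(codes, scales, w_base_np.shape)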
First, let’s check quantization quality:
[5]:
qweight = nb.nn.finetune.quantize_nf4(frozen_weight, block_size=64)
dense_recon = nb.nn.finetune.dequantize_nf4(qweight)
quant_err = float(
    np.linalg.norm(dense_recon.to_numpy() - frozen_weight.to_numpy())
    / (np.linalg.norm(frozen_weight.to_numpy()) + 1e-8)
)
print(f"NF4 relative reconstruction error: {quant_err:.4f}")
print(f"Quantized weight type: {type(qweight).__name__}")
NF4 relative reconstruction error: 0.0912
Quantized weight type: dict
QLoRA Training#
Training is identical to LoRA except we use qlora_linear instead of lora_linear. The quantized weight is dequantized during the forward pass:
[6]:
# Fresh adapter for QLoRA
qlora_params = nb.nn.finetune.init_lora_adapter(frozen_weight, rank=RANK, init_std=0.01)
qopt_state = nb.nn.optim.adamw_init(qlora_params)
def qlora_loss_fn(adapter, batch_x, batch_y):
    """MSE loss with QLoRA-adapted linear layer."""
    pred = nb.nn.finetune.qlora_linear(
        batch_x, qweight, adapter, alpha=ALPHA, compute_dtype=nb.DType.float32
    )
    diff = pred - batch_y
    return nb.mean(diff * diff)
q_initial_loss = float(qlora_loss_fn(qlora_params, x, y).to_numpy())
print(f"QLoRA initial loss: {q_initial_loss:.6f}")
for step in range(STEPS):
    loss, qlora_params, qopt_state = train_step(qlora_loss_fn, qlora_params, qopt_state, x, y)
    if (step + 1) % 50 == 0:
        print(f"Step {step + 1:>3d}: loss = {float(loss.to_numpy()):.6f}")
q_final_loss = float(qlora_loss_fn(qlora_params, x, y).to_numpy())
print(f"QLoRA final loss: {q_final_loss:.6f}")
assert q_final_loss < q_initial_loss, "QLoRA training did not reduce loss"
print("✅ QLoRA adapter trained successfully")
QLoRA initial loss: 32.562389
Step 50: loss = 0.301735
Step 100: loss = 0.095010
QLoRA final loss: 0.090136
✅ QLoRA adapter trained successfully
Key takeaways:
LoRA trains only \(2 \times r \times d\) parameters instead of \(d^2\) (a concrete count for this example follows below)
QLoRA adds NF4 quantization for ~4x memory reduction on frozen weights
Both methods converge on our synthetic task with negligible quality loss
Checkpointing supports adapter + optimizer state + custom metadata
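To put numbers on the first two takeaways for this example's shapes (pure arithmetic; the NF4 figure assumes a 16-bit baseline for the frozen weight, which is where the ~75% comes from):

full_params = IN_DIM * OUT_DIM                  # 64 * 32 = 2048 weights for full fine-tuning
adapter_params = RANK * (IN_DIM + OUT_DIM)      # 8 * (64 + 32) = 768 trainable LoRA parameters

bf16_bytes = full_params * 2                    # frozen weight stored in 16-bit: 4096 bytes
nf4_bytes = full_params // 2                    # 4 bits per parameter: 1024 bytes (+ small per-block scales)

print(f"Trainable params: full={full_params}, LoRA={adapter_params}")
print(f"Frozen weight:    16-bit={bf16_bytes} B, NF4~={nf4_bytes} B (~4x smaller)")

At these small dimensions the adapter saving is modest; it grows with \(d\), since the adapter costs \(O(r \cdot d)\) parameters while the full weight costs \(O(d^2)\).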