Example 11: QLoRA Fine-Tuning MVP#
This example mirrors the plain LoRA fine-tuning example, but trains the adapter on top of a quantized, frozen base weight:

- NF4 quantization of the frozen base weights (see the sketch after this list)
- LoRA adapter training against the quantized weights
- quick quality checks (loss reduction plus quantization reconstruction error)
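For intuition: NF4 quantization works blockwise, scaling each block of weights by its absolute maximum and snapping every scaled value to the nearest entry in a 16-value codebook. The sketch below is a simplified stand-in that uses a uniform 4-bit grid instead of the normal-distribution-shaped NF4 codebook, so it only illustrates the mechanism; it is not the codebook or API that `quantize_nf4` actually uses.

[ ]:
import numpy as np

def blockwise_4bit_sketch(w: np.ndarray, block_size: int = 64) -> np.ndarray:
    """Quantize then immediately dequantize, returning the reconstruction.

    Simplified: uses a uniform 16-level grid in [-1, 1]; real NF4 uses a
    codebook shaped like the quantiles of a standard normal distribution.
    """
    codebook = np.linspace(-1.0, 1.0, 16, dtype=np.float32)
    flat = w.astype(np.float32).ravel()
    pad = (-flat.size) % block_size
    blocks = np.pad(flat, (0, pad)).reshape(-1, block_size)
    absmax = np.abs(blocks).max(axis=1, keepdims=True) + 1e-12  # per-block scale
    scaled = blocks / absmax                                    # now in [-1, 1]
    idx = np.abs(scaled[..., None] - codebook).argmin(axis=-1)  # nearest code
    recon = codebook[idx] * absmax                              # dequantize
    return recon.ravel()[: flat.size].reshape(w.shape)

w = np.random.default_rng(0).normal(size=(64, 32)).astype(np.float32)
recon = blockwise_4bit_sketch(w)
rel_err = np.linalg.norm(recon - w) / np.linalg.norm(w)
print(f"uniform 4-bit relative reconstruction error: {rel_err:.4f}")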
[ ]:
from __future__ import annotations
import numpy as np
import nabla as nb
1. Synthetic Data Helper#
[ ]:
def make_regression_data(
    n_samples: int, in_dim: int, out_dim: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    rng = np.random.default_rng(321)
    x = rng.normal(size=(n_samples, in_dim)).astype(np.float32)
    # Targets come from the frozen base weight plus a rank-4 perturbation
    # that the adapter is expected to learn.
    w_base = rng.normal(size=(in_dim, out_dim)).astype(np.float32) * 0.4
    u = rng.normal(size=(in_dim, 4)).astype(np.float32)
    v = rng.normal(size=(4, out_dim)).astype(np.float32)
    delta = 0.30 * (u @ v)
    y = x @ (w_base + delta)
    return x, y.astype(np.float32), w_base.astype(np.float32)
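Because the target mapping is w_base plus the rank-4 delta = 0.30 * (u @ v), a LoRA adapter of rank 8 has more than enough capacity to capture what the frozen base weight misses. The quick numpy check below (added here for illustration, not part of the original example) fits the full mapping by least squares and inspects the spectrum of the residual the adapter has to learn; only about four singular values should stand out.

[ ]:
x_np, y_np, w_base_np = make_regression_data(n_samples=512, in_dim=64, out_dim=32)
# Least-squares fit of the full target mapping, then subtract the base weight.
w_full, *_ = np.linalg.lstsq(x_np, y_np, rcond=None)
residual = w_full - w_base_np
svals = np.linalg.svd(residual, compute_uv=False)
print("leading singular values of the residual:", np.round(svals[:6], 3))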
2. Train QLoRA Adapter#
[ ]:
def main() -> None:
    in_dim, out_dim = 64, 32
    rank = 8
    alpha = 16.0
    learning_rate = 2e-2
    steps = 120

    x_np, y_np, w_base_np = make_regression_data(
        n_samples=512, in_dim=in_dim, out_dim=out_dim
    )
    x = nb.Tensor.from_dlpack(x_np)
    y = nb.Tensor.from_dlpack(y_np)
    frozen_weight = nb.Tensor.from_dlpack(w_base_np)

    # Quantize the frozen base weight to NF4 and measure how well it reconstructs.
    qweight = nb.nn.finetune.quantize_nf4(frozen_weight, block_size=64)
    dense_recon = nb.nn.finetune.dequantize_nf4(qweight)
    quant_rel_err = float(
        np.linalg.norm(dense_recon.to_numpy() - frozen_weight.to_numpy())
        / (np.linalg.norm(frozen_weight.to_numpy()) + 1e-8)
    )
    print(f"NF4 relative reconstruction error: {quant_rel_err:.4f}")

    lora_params = nb.nn.finetune.init_lora_adapter(
        frozen_weight, rank=rank, init_std=0.01
    )
    opt_state = nb.nn.optim.adamw_init(lora_params)

    def loss_fn(adapter, batch_x, batch_y):
        # Forward pass through the quantized base weight plus the LoRA adapter.
        pred = nb.nn.finetune.qlora_linear(
            batch_x,
            qweight,
            adapter,
            alpha=alpha,
            compute_dtype=nb.DType.float32,
        )
        diff = pred - batch_y
        return nb.mean(diff * diff)
    def train_step(adapter, optimizer_state, batch_x, batch_y):
        loss, grads = nb.value_and_grad(loss_fn, argnums=0, realize=False)(
            adapter, batch_x, batch_y
        )
        new_adapter, new_state = nb.nn.optim.adamw_update(
            adapter,
            grads,
            optimizer_state,
            lr=learning_rate,
            weight_decay=0.0,
        )
        # Realize the loss, gradients, and updated adapter/optimizer state in one batch.
        to_realize = [loss]
        to_realize.extend(t for t in nb.tree_leaves(grads) if isinstance(t, nb.Tensor))
        to_realize.extend(
            t for t in nb.tree_leaves(new_adapter) if isinstance(t, nb.Tensor)
        )
        to_realize.extend(
            t for t in nb.tree_leaves(new_state) if isinstance(t, nb.Tensor)
        )
        nb.realize_all(*to_realize)
        return loss, new_adapter, new_state
    initial_loss = float(loss_fn(lora_params, x, y).to_numpy())
    print(f"Initial loss: {initial_loss:.6f}")

    for step in range(steps):
        loss, lora_params, opt_state = train_step(lora_params, opt_state, x, y)
        if (step + 1) % 50 == 0:
            print(f"Step {step + 1:>3d}: loss={float(loss.to_numpy()):.6f}")

    final_loss = float(loss_fn(lora_params, x, y).to_numpy())
    print(f"Final loss: {final_loss:.6f}")
    if final_loss >= initial_loss:
        raise RuntimeError("QLoRA training did not reduce loss.")
    print("✅ QLoRA MVP finished successfully.")
if __name__ == "__main__":
    main()
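For reference, the computation performed by a QLoRA linear layer follows the usual formulation: the frozen weight is dequantized on the fly and the low-rank update is added with the LoRA scaling alpha / rank. The sketch below restates that in plain numpy; the exact layout of the adapter inside lora_params (two matrices A of shape in_dim x rank and B of shape rank x out_dim) is an assumption for illustration, not the library's documented internals.

[ ]:
def qlora_forward_sketch(x, qweight_recon, A, B, alpha, rank):
    # qweight_recon stands in for dequantize_nf4(qweight); A and B are the
    # assumed low-rank adapter factors.
    scaling = alpha / rank
    return x @ qweight_recon + scaling * (x @ A) @ B

In the common LoRA convention one of the two factors starts at zero, so the adapter contributes nothing at step 0 and the initial loss reflects only the quantized base weight; whether init_lora_adapter follows that exact convention here is an assumption.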