LoRA & QLoRA from scratch (in 200 lines of PyTorch)

LoRA (Hu et al., 2021) freezes the pretrained weights W and trains a low-rank update BA so that W' = W + BA where B ∈ ℝ^(d×r), A ∈ ℝ^(r×k), r ≪ min(d,k). Trainable parameter count drops from d·k to r·(d+k) — typically <1% of the original.
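
To make the parameter arithmetic concrete, here is a small sketch comparing d·k with r·(d+k). The helper name lora_counts and the 4096×4096 example shape are illustrative assumptions, not something taken from the paper.

```python
def lora_counts(d: int, k: int, r: int):
    """Return (full, lora, fraction) for a d x k weight with a rank-r update."""
    full = d * k            # parameters in W
    lora = r * (d + k)      # parameters in B (d x r) plus A (r x k)
    return full, lora, lora / full

# A large square projection (hypothetical shape): the adapter is well under 1%.
big = lora_counts(4096, 4096, 8)
# The qkv projection of the tiny model built below (192 x 64, r=4): about 8%.
tiny = lora_counts(192, 64, 4)

for name, (full, lora, frac) in [("4096x4096, r=8", big), ("192x64, r=4", tiny)]:
    print(f"{name:16s} full={full:>10,}  lora={lora:>7,}  ({100 * frac:.2f}%)")
```

The "<1%" figure is a property of large matrices. On a model as small as the one in this notebook the fraction is noticeably higher, because r is no longer tiny compared to d and k.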

QLoRA (Dettmers et al., 2023) goes one step further and also quantizes the frozen base weights (to 4-bit NF4 in the original paper). The LoRA A, B matrices stay in higher precision and absorb the quantization error during training.

We’ll build a 200K-parameter transformer from scratch, run a synthetic supervised task, and demo:

  1. Full finetune (baseline)
  2. LoRA finetune (frozen base + low-rank adapters)
  3. QLoRA-style (int8 quantized base + LoRA in float32)
  4. A second adapter for a second task: adapter swap at inference, plus what goes wrong when one adapter is trained on both tasks in sequence (catastrophic forgetting) and how joint training avoids it

Everything runs on CPU in under a minute.

Code
import math
import time
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
DEVICE = 'cpu'   # MPS is faster but float64 quirks hurt the int8 demo; cpu is plenty.

VOCAB, MAX_LEN, D, H, LAYERS = 16, 24, 64, 4, 4   # tiny transformer
BSZ = 64

def n_params(m, only_trainable=False):
    return sum(p.numel() for p in m.parameters() if (p.requires_grad or not only_trainable))

1. A 4-layer transformer encoder, ~200K params

Standard pre-norm transformer + a [CLS] token whose final hidden state feeds a binary classifier. Nothing exotic; we want this to be the thing LoRA wraps, not the lesson itself.

Code
class Block(nn.Module):
    def __init__(self, d, h):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.h = h

    def forward(self, x):
        B, T, D = x.shape
        z = self.ln1(x)
        q, k, v = self.qkv(z).chunk(3, dim=-1)
        q = q.view(B, T, self.h, D // self.h).transpose(1, 2)
        k = k.view(B, T, self.h, D // self.h).transpose(1, 2)
        v = v.view(B, T, self.h, D // self.h).transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).contiguous().view(B, T, D)
        x = x + self.proj(attn)
        x = x + self.mlp(self.ln2(x))
        return x

class TinyClassifier(nn.Module):
    def __init__(self, vocab=VOCAB, max_len=MAX_LEN, d=D, h=H, layers=LAYERS, n_classes=2):
        super().__init__()
        self.tok = nn.Embedding(vocab + 1, d)   # +1 for [CLS]
        self.pos = nn.Embedding(max_len + 1, d)
        self.blocks = nn.ModuleList([Block(d, h) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, n_classes)
        self.cls_id = vocab

    def forward(self, x):
        B, T = x.shape
        cls = torch.full((B, 1), self.cls_id, device=x.device, dtype=x.dtype)
        x = torch.cat([cls, x], dim=1)
        pos = torch.arange(x.size(1), device=x.device)
        h = self.tok(x) + self.pos(pos)[None]
        for blk in self.blocks:
            h = blk(h)
        return self.head(self.ln(h[:, 0]))

model = TinyClassifier().to(DEVICE)
print(f'total params: {n_params(model):,}  ({n_params(model)/1024:.1f} K)')
total params: 202,114  (197.4 K)
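
As a sanity check on that total, the count can be rebuilt by hand from the layer shapes. This is a sketch in plain Python; the helper linear and the constant names are introduced here for illustration only.

```python
VOCAB, MAX_LEN, D, LAYERS, N_CLASSES = 16, 24, 64, 4, 2

def linear(n_in, n_out, bias=True):
    """Parameter count of an nn.Linear(n_in, n_out)."""
    return n_in * n_out + (n_out if bias else 0)

layer_norm = 2 * D                           # weight + bias
block = (
    layer_norm                               # ln1
    + linear(D, 3 * D, bias=False)           # qkv
    + linear(D, D)                           # proj
    + layer_norm                             # ln2
    + linear(D, 4 * D) + linear(4 * D, D)    # mlp
)
embeddings = (VOCAB + 1) * D + (MAX_LEN + 1) * D   # tok (with [CLS]) and pos
total = embeddings + LAYERS * block + layer_norm + linear(D, N_CLASSES)
print(f"per block: {block:,}   total: {total:,}")
```

This prints per block: 49,792 and total: 202,114. Almost all of the parameters sit in the four blocks, and within each block the MLP is about twice the size of the attention projections.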

2. Two synthetic supervised tasks

  • Task A: does the sequence have more odd than even tokens (in [0, VOCAB))?
  • Task B: is the sequence's first token strictly less than its last token?

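Same model, same input distribution, different decision rule. Both tasks are roughly balanced (about 42% positive for Task A and 47% for Task B in expectation, because ties count as negative), so accuracy is a meaningful metric.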
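
A quick exact calculation shows how close to 50/50 the two rules really are. This is a sketch that assumes tokens are drawn uniformly from [0, VOCAB), with ties labelled 0 in both tasks.

```python
from math import comb

VOCAB, MAX_LEN = 16, 24

# Task A: each token is odd with probability 1/2, so the odd count is
# Binomial(MAX_LEN, 1/2). The label is 1 only on a strict majority (> 12).
p_odd_majority = sum(
    comb(MAX_LEN, k) for k in range(MAX_LEN // 2 + 1, MAX_LEN + 1)
) / 2**MAX_LEN

# Task B: of the VOCAB**2 equally likely (first, last) pairs,
# VOCAB * (VOCAB - 1) / 2 have first < last; equal pairs are labelled 0.
p_first_lt_last = (VOCAB * (VOCAB - 1) // 2) / VOCAB**2

print(f"odd_majority  expected positive rate: {p_odd_majority:.4f}")
print(f"first_lt_last expected positive rate: {p_first_lt_last:.4f}")
```

Both rates land a little under one half because the tie cases all go to the negative class. Finite samples will scatter around these values.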

Code
def make_data(rule: str, n: int, seed: int = 0):
    g = torch.Generator().manual_seed(seed)
    x = torch.randint(0, VOCAB, (n, MAX_LEN), generator=g)
    if rule == 'odd_majority':
        y = ((x % 2 == 1).sum(dim=1) > MAX_LEN / 2).long()
    elif rule == 'first_lt_last':
        # First token strictly less than last token (ties count as 0), so the
        # positive rate is a bit under 50%. The signal sits at the two ends, so
        # the model needs positional attention to solve it.
        y = (x[:, 0] < x[:, -1]).long()
    else:
        raise ValueError(rule)
    return x, y

for rule in ['odd_majority', 'first_lt_last']:
    x, y = make_data(rule, 2000)
    print(f'{rule:14s}  n={len(y)}  positive={y.float().mean():.3f}')
odd_majority    n=2000  positive=0.429
first_lt_last   n=2000  positive=0.497
Code
def train(model, train, val, epochs=8, lr=3e-3, log=True, parameters=None):
    xtr, ytr = (t.to(DEVICE) for t in train)
    xva, yva = (t.to(DEVICE) for t in val)
    params = parameters if parameters is not None else [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=lr)
    n_train_p = sum(p.numel() for p in params)
    n_total_p = n_params(model)
    if log:
        print(f'  trainable: {n_train_p:,} / {n_total_p:,}  ({100*n_train_p/n_total_p:.2f}%)')
    t0 = time.time()
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(len(xtr))
        for i in range(0, len(xtr), BSZ):
            idx = perm[i:i+BSZ]
            logits = model(xtr[idx])
            loss = F.cross_entropy(logits, ytr[idx])
            opt.zero_grad(); loss.backward(); opt.step()
        model.eval()
        with torch.no_grad():
            acc = (model(xva).argmax(-1) == yva).float().mean().item()
        if log:
            print(f'  epoch {ep+1:2d}  loss={loss.item():.3f}  val_acc={acc:.3f}')
    if log:
        print(f'  done in {time.time()-t0:.1f}s')
    return acc

3. Baseline: full finetune on Task A

All ~200K parameters trainable. This is what “finetuning” usually means.

Code
torch.manual_seed(0)
base_full = TinyClassifier().to(DEVICE)
tr = make_data('odd_majority', 1024, seed=1)
va = make_data('odd_majority', 512,  seed=2)
print('=== Full finetune on Task A ===')
acc_full_A = train(base_full, tr, va)
=== Full finetune on Task A ===
  trainable: 202,114 / 202,114  (100.00%)
  epoch  1  loss=0.570  val_acc=0.873
  epoch  2  loss=0.122  val_acc=0.941
  epoch  3  loss=0.200  val_acc=0.898
  epoch  4  loss=0.117  val_acc=0.969
  epoch  5  loss=0.030  val_acc=0.984
  epoch  6  loss=0.189  val_acc=0.979
  epoch  7  loss=0.011  val_acc=0.955
  epoch  8  loss=0.005  val_acc=0.980
  done in 2.5s

4. LoRA from scratch

We freeze every parameter, then wrap selected nn.Linears with a LoRALinear. The forward becomes

\[\;\;y = xW^\top + b + \frac{\alpha}{r}\,(xA^\top)B^\top\]

where A is initialized with Kaiming-uniform and B with zeros so the initial update is exactly zero — the model behaves identically to the frozen base on step 0, then learns from there.
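
The zero-init claim is easy to verify numerically. The sketch below uses a throwaway 16→8 layer (an arbitrary shape chosen for illustration) and also shows a side effect worth knowing: at step 0 the gradient with respect to A is exactly zero, because it is multiplied by B.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
r, alpha = 4, 8
base = nn.Linear(16, 8)
A = nn.Parameter(torch.empty(r, 16))
B = nn.Parameter(torch.zeros(8, r))          # zero init, as described above
nn.init.kaiming_uniform_(A, a=math.sqrt(5))

x = torch.randn(5, 16)
y_base = base(x)
y_lora = y_base + (x @ A.T @ B.T) * (alpha / r)

# With B == 0 the low-rank path contributes exactly nothing.
same_at_init = torch.equal(y_base, y_lora)

# Backprop any scalar loss: B receives gradient, A does not (yet).
y_lora.sum().backward()
grad_A_is_zero = bool((A.grad == 0).all())
grad_B_is_nonzero = bool((B.grad != 0).any())

print(same_at_init, grad_A_is_zero, grad_B_is_nonzero)
```

So the first optimizer step only moves B. Once B is non-zero, A starts receiving gradient as well.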

Code
class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable rank-r delta."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: int = 8):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        in_f, out_f = base.in_features, base.out_features
        self.A = nn.Parameter(torch.empty(r, in_f))
        self.B = nn.Parameter(torch.zeros(out_f, r))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale

    def merged_weight(self):
        """Returns W + (alpha/r) BA — useful for inference without the adapter overhead."""
        return self.base.weight + self.scale * (self.B @ self.A)


def add_lora(model, r=4, alpha=8, target=('qkv', 'proj')):
    """Walk the model, replace target Linear layers with LoRALinear, freeze the rest."""
    for p in model.parameters():
        p.requires_grad = False
    # Collect first to avoid mutating the module tree while iterating.
    targets = []
    for mod in model.modules():
        for child_name, child in mod.named_children():
            if isinstance(child, nn.Linear) and child_name in target:
                targets.append((mod, child_name, child))
    for mod, child_name, child in targets:
        setattr(mod, child_name, LoRALinear(child, r=r, alpha=alpha))
    # Always train the classifier head (small, task-specific).
    for p in model.head.parameters():
        p.requires_grad = True
    return model
Code
torch.manual_seed(0)
base_lora = TinyClassifier().to(DEVICE)
add_lora(base_lora, r=4, alpha=8, target=('qkv', 'proj'))
print('=== LoRA finetune on Task A ===')
acc_lora_A = train(base_lora, tr, va, epochs=8, lr=5e-3)
=== LoRA finetune on Task A ===
  trainable: 6,274 / 208,258  (3.01%)
  epoch  1  loss=0.450  val_acc=0.742
  epoch  2  loss=0.153  val_acc=0.941
  epoch  3  loss=0.095  val_acc=0.994
  epoch  4  loss=0.023  val_acc=0.967
  epoch  5  loss=0.053  val_acc=0.988
  epoch  6  loss=0.201  val_acc=0.998
  epoch  7  loss=0.063  val_acc=0.965
  epoch  8  loss=0.009  val_acc=0.984
  done in 2.6s
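
The merged_weight method above is not exercised in the training run, so here is a small standalone check of the idea behind it: folding (alpha/r)·BA into W gives the same outputs as the two-path forward. The shapes and random tensors are placeholders chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out, r, alpha = 16, 8, 4, 8
scale = alpha / r

base = nn.Linear(d_in, d_out)
A = torch.randn(r, d_in)
B = torch.randn(d_out, r)      # non-zero, as if training had already happened
x = torch.randn(5, d_in)

with torch.no_grad():
    y_adapter = base(x) + (x @ A.T @ B.T) * scale      # two-path forward
    W_merged = base.weight + scale * (B @ A)           # same formula as merged_weight()
    y_merged = F.linear(x, W_merged, base.bias)        # single matmul

max_diff = (y_adapter - y_merged).abs().max().item()
print(f"max |adapter - merged| = {max_diff:.2e}")
```

The two paths agree up to float32 rounding. Merging removes the extra matmuls at inference, at the cost of no longer being able to swap the adapter out without keeping a copy of the original W.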

5. QLoRA-style: int8-quantize the frozen base, train LoRA on top

Real QLoRA uses 4-bit nf4 quantization with double-quantization and paged optimizers. We’ll do the cleaner pedagogical version: per-tensor symmetric int8 for the frozen nn.Linear weights, dequantize on the fly during forward. The LoRA A, B stay in float32. The point is the same: the bulky frozen weights live in a low-precision format; only the tiny adapters move during training.
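
Before wrapping it in a module, it helps to see the quantize/dequantize round trip on a bare tensor. This is a minimal sketch of per-tensor symmetric int8 on a random matrix (a stand-in for a real weight); it is not the NF4 scheme used by real QLoRA.

```python
import torch

torch.manual_seed(0)
w = torch.randn(64, 192) * 0.05          # stand-in for a frozen weight matrix

scale = w.abs().max() / 127.0            # one float scale for the whole tensor
q = (w / scale).round().clamp(-127, 127).to(torch.int8)
w_hat = q.float() * scale                # dequantized weights used in forward

max_err = (w - w_hat).abs().max().item()
half_step = scale.item() / 2             # rounding moves a value by at most half a step
bytes_fp32 = w.numel() * w.element_size()
bytes_int8 = q.numel() * q.element_size() + 4   # plus one float32 scale

print(f"max abs error {max_err:.2e}  (half step {half_step:.2e})")
print(f"storage: {bytes_fp32:,} B float32 -> {bytes_int8:,} B int8")
```

The reconstruction error is bounded by half a quantization step, and storage drops by about 4x. A single outlier weight inflates the scale for the whole tensor, which is one reason production schemes quantize in small blocks instead.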

Code
class Int8Linear(nn.Module):
    """Frozen Linear whose weight is stored as int8 + a per-tensor float scale."""

    def __init__(self, base: nn.Linear):
        super().__init__()
        w = base.weight.detach()
        scale = w.abs().max() / 127.0
        self.register_buffer('q', (w / scale).round().clamp(-127, 127).to(torch.int8))
        self.register_buffer('scale', torch.tensor(scale.item(), dtype=torch.float32))
        self.bias = base.bias            # bias stays float (small, no point quantizing)
        self.in_features = base.in_features
        self.out_features = base.out_features

    @property
    def weight(self):
        return self.q.float() * self.scale

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)


def quantize_to_int8(model, target=('qkv', 'proj')):
    targets = []
    for mod in model.modules():
        for child_name, child in mod.named_children():
            if isinstance(child, nn.Linear) and child_name in target:
                targets.append((mod, child_name, child))
    for mod, child_name, child in targets:
        setattr(mod, child_name, Int8Linear(child))
    return model


class LoRAOnQuantized(nn.Module):
    """LoRA wrapper that accepts either a Linear or an Int8Linear as its frozen base."""

    def __init__(self, base, r: int = 4, alpha: int = 8):
        super().__init__()
        self.base = base
        for p in base.parameters():
            p.requires_grad = False
        in_f, out_f = base.in_features, base.out_features
        self.A = nn.Parameter(torch.empty(r, in_f))
        self.B = nn.Parameter(torch.zeros(out_f, r))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale


def add_qlora(model, r=4, alpha=8, target=('qkv', 'proj')):
    quantize_to_int8(model, target=target)
    for p in model.parameters():
        p.requires_grad = False
    targets = []
    for mod in model.modules():
        for child_name, child in mod.named_children():
            if isinstance(child, Int8Linear):
                targets.append((mod, child_name, child))
    for mod, child_name, child in targets:
        setattr(mod, child_name, LoRAOnQuantized(child, r=r, alpha=alpha))
    for p in model.head.parameters():
        p.requires_grad = True
    return model
Code
torch.manual_seed(0)
base_qlora = TinyClassifier().to(DEVICE)
# Pretend the freshly-initialised base is a 'pretrained' checkpoint and quantize it.
add_qlora(base_qlora, r=4, alpha=8, target=('qkv', 'proj'))
print('=== QLoRA-style finetune on Task A (int8 base + float LoRA) ===')
acc_qlora_A = train(base_qlora, tr, va, epochs=8, lr=5e-3)
=== QLoRA-style finetune on Task A (int8 base + float LoRA) ===
  trainable: 6,274 / 142,722  (4.40%)
  epoch  1  loss=0.450  val_acc=0.742
  epoch  2  loss=0.154  val_acc=0.936
  epoch  3  loss=0.094  val_acc=0.994
  epoch  4  loss=0.025  val_acc=0.969
  epoch  5  loss=0.052  val_acc=0.990
  epoch  6  loss=0.045  val_acc=0.994
  epoch  7  loss=0.068  val_acc=0.982
  epoch  8  loss=0.003  val_acc=0.992
  done in 2.6s

6. Side-by-side

The interesting numbers: trainable parameter count and final accuracy.

Code
def trainable(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

rows = [
    ('Full FT',  trainable(base_full),  acc_full_A),
    ('LoRA r=4', trainable(base_lora),  acc_lora_A),
    ('QLoRA r=4 (int8 base)', trainable(base_qlora), acc_qlora_A),
]
header = f'{"method":24s}  {"trainable":>10s}  {"% of full":>10s}  {"val acc":>8s}'
print(header); print('-' * len(header))
for name, n, acc in rows:
    pct = 100 * n / rows[0][1]
    print(f'{name:24s}  {n:10,d}  {pct:9.2f}%  {acc:8.3f}')
method                     trainable   % of full   val acc
----------------------------------------------------------
Full FT                      202,114     100.00%     0.980
LoRA r=4                       6,274       3.10%     0.984
QLoRA r=4 (int8 base)          6,274       3.10%     0.992
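
The table does not show what QLoRA actually buys, since the trainable count is identical to plain LoRA. A rough storage estimate fills that gap. This is back-of-the-envelope arithmetic using the parameter counts printed earlier; it ignores optimizer state and activations.

```python
TOTAL_PARAMS = 202_114                 # from the parameter count printed in section 1
QUANTIZED = 4 * (64 * 192 + 64 * 64)   # qkv + proj weights across the 4 blocks
N_INT8_LAYERS = 4 * 2                  # one float32 scale per quantized layer
ADAPTER = 6_144                        # LoRA A and B matrices at r=4

FP32, INT8 = 4, 1                      # bytes per element

base_fp32 = TOTAL_PARAMS * FP32
base_int8 = (TOTAL_PARAMS - QUANTIZED) * FP32 + QUANTIZED * INT8 + N_INT8_LAYERS * FP32
adapter_bytes = ADAPTER * FP32

print(f"float32 base : {base_fp32:,} B")
print(f"int8 base    : {base_int8:,} B  ({100 * (1 - base_int8 / base_fp32):.1f}% smaller)")
print(f"LoRA adapter : {adapter_bytes:,} B")
```

The saving is only about a quarter here because the MLP weights, which hold most of the parameters, are left in float32. Quantizing them too would be the obvious next step.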

7. Two tasks, two adapters, one base — adapter swap at inference

This is the punchline of LoRA in practice. Train a second adapter on Task B over the same frozen base. At inference, swap the adapter in/out depending on which task you want.

Code
# 1. Snapshot a single shared base; add_lora freezes it inside each copy.
torch.manual_seed(0)
shared_base = TinyClassifier().to(DEVICE)
shared_state = deepcopy(shared_base.state_dict())   # for re-instantiation

# 2. Train Adapter A on Task A.
model_A = TinyClassifier().to(DEVICE); model_A.load_state_dict(shared_state)
add_lora(model_A, r=4, alpha=8, target=('qkv', 'proj'))
trA = make_data('odd_majority', 1024, seed=1); vaA = make_data('odd_majority', 512, seed=2)
print('--- Adapter A (Task A: odd majority) ---')
acc_A_on_A = train(model_A, trA, vaA, epochs=8, lr=5e-3)

# 3. Train Adapter B on Task B from the same base.
model_B = TinyClassifier().to(DEVICE); model_B.load_state_dict(shared_state)
add_lora(model_B, r=4, alpha=8, target=('qkv', 'proj'))
trB = make_data('first_lt_last', 1024, seed=3); vaB = make_data('first_lt_last', 512, seed=4)
print('--- Adapter B (Task B: first<last) ---')
acc_B_on_B = train(model_B, trB, vaB, epochs=8, lr=5e-3)

# 4. Cross-eval — does Adapter A do anything sensible on Task B and vice versa?
def eval_on(m, val):
    xv, yv = (t.to(DEVICE) for t in val)
    with torch.no_grad():
        return (m(xv).argmax(-1) == yv).float().mean().item()

print('\n=== Cross-evaluation ===')
print(f'Adapter A on Task A: {acc_A_on_A:.3f}')
print(f'Adapter A on Task B: {eval_on(model_A, vaB):.3f}   (should be ~chance, 0.5)')
print(f'Adapter B on Task B: {acc_B_on_B:.3f}')
print(f'Adapter B on Task A: {eval_on(model_B, vaA):.3f}   (should be ~chance, 0.5)')
--- Adapter A (Task A: odd majority) ---
  trainable: 6,274 / 208,258  (3.01%)
  epoch  1  loss=0.528  val_acc=0.787
  epoch  2  loss=0.228  val_acc=0.938
  epoch  3  loss=0.015  val_acc=0.990
  epoch  4  loss=0.008  val_acc=0.963
  epoch  5  loss=0.029  val_acc=0.939
  epoch  6  loss=0.043  val_acc=0.994
  epoch  7  loss=0.033  val_acc=0.980
  epoch  8  loss=0.016  val_acc=0.961
  done in 2.7s
--- Adapter B (Task B: first<last) ---
  trainable: 6,274 / 208,258  (3.01%)
  epoch  1  loss=0.721  val_acc=0.547
  epoch  2  loss=0.699  val_acc=0.547
  epoch  3  loss=0.605  val_acc=0.645
  epoch  4  loss=0.426  val_acc=0.793
  epoch  5  loss=0.115  val_acc=0.906
  epoch  6  loss=0.168  val_acc=0.887
  epoch  7  loss=0.129  val_acc=0.941
  epoch  8  loss=0.072  val_acc=0.947
  done in 2.5s

=== Cross-evaluation ===
Adapter A on Task A: 0.961
Adapter A on Task B: 0.523   (should be ~chance, 0.5)
Adapter B on Task B: 0.947
Adapter B on Task A: 0.506   (should be ~chance, 0.5)

Each adapter only solves the task it was trained on, as expected — they’re tiny task-specific deltas over a shared base.
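
The run above keeps two full model copies for simplicity. A literal swap would keep one model and load only the small adapter tensors. Here is a sketch of that, using a toy network so it runs on its own; ToyNet and adapter_state are hypothetical helpers, and random values stand in for trained adapters.

```python
import math
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Condensed copy of the section 4 wrapper: frozen base + rank-r delta."""
    def __init__(self, base, r=4, alpha=8):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        self.A = nn.Parameter(torch.empty(r, base.in_features))
        self.B = nn.Parameter(torch.zeros(base.out_features, r))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale

class ToyNet(nn.Module):
    """Hypothetical stand-in for TinyClassifier: one adapted layer + a head."""
    def __init__(self):
        super().__init__()
        self.body = LoRALinear(nn.Linear(8, 8))
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(torch.relu(self.body(x)))

def adapter_state(model):
    """Keep only the LoRA A/B tensors and the task head (hypothetical helper)."""
    return {k: v.clone() for k, v in model.state_dict().items()
            if k.endswith('.A') or k.endswith('.B') or k.startswith('head.')}

torch.manual_seed(0)
net = ToyNet()
x = torch.randn(3, 8)
base_before = net.body.base.weight.clone()

# Stand-in for training: give each "task" its own random adapter + head values.
adapters, expected = {}, {}
for task in ('A', 'B'):
    with torch.no_grad():
        for p in net.parameters():
            if p.requires_grad:          # A, B and head; the base stays frozen
                p.normal_()
    adapters[task] = adapter_state(net)
    expected[task] = net(x).detach()

# Swap at inference: load only the small adapter dict into the shared model.
swapped = {}
for task in ('A', 'B'):
    net.load_state_dict(adapters[task], strict=False)
    swapped[task] = net(x).detach()

n_adapter = sum(v.numel() for v in adapters['A'].values())
print(f"adapter tensors: {sorted(adapters['A'])}  ({n_adapter} values)")
```

Loading with strict=False lets the partial dict through without complaining about the missing base weights. Each swap reproduces that task's outputs exactly while the base weight is never touched.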

8. “What if I train LoRA on Task A then continue on Task B?”

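Section 7 showed one way to handle a second supervised dataset: a separate adapter per task (call that Option 1). There are two other, qualitatively different ways. All three are useful; people pick based on whether they need one model that does both tasks or two specialised models.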

Option 2 — sequential training on the same adapter

Continue training Adapter A on Task B. Below: train A, eval, then continue training A on B, eval again.

Code
torch.manual_seed(0)
model_seq = TinyClassifier().to(DEVICE); model_seq.load_state_dict(shared_state)
add_lora(model_seq, r=4, alpha=8, target=('qkv', 'proj'))

print('--- Step 1: train shared adapter on Task A ---')
_ = train(model_seq, trA, vaA, epochs=6, lr=5e-3, log=False)
a1, b1 = eval_on(model_seq, vaA), eval_on(model_seq, vaB)
print(f'  after A:  acc(A)={a1:.3f}  acc(B)={b1:.3f}')

print('--- Step 2: continue training the SAME adapter on Task B ---')
_ = train(model_seq, trB, vaB, epochs=12, lr=5e-3, log=False)
a2, b2 = eval_on(model_seq, vaA), eval_on(model_seq, vaB)
print(f'  after B:  acc(A)={a2:.3f}  acc(B)={b2:.3f}   ← Task A drops: catastrophic forgetting')
--- Step 1: train shared adapter on Task A ---
  after A:  acc(A)=0.998  acc(B)=0.514
--- Step 2: continue training the SAME adapter on Task B ---
  after B:  acc(A)=0.490  acc(B)=0.930   ← Task A drops: catastrophic forgetting

The adapter that was good at Task A drifts toward Task B and forgets A. This is the same catastrophic-forgetting story you get with full finetuning; LoRA does nothing magical to prevent it.

Option 3 — joint training (multitask)

If you want one model that does both, mix A and B batches (with a task ID feature, or two heads) and train one adapter on the joint distribution. This avoids forgetting, but the adapter now has to hold both behaviours, so we give it a larger rank (r=8 instead of r=4).

Code
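# Joint training: stack the two tasks. We need a way to tell A from B at the
# input. Minimal hack: prepend a task-id column (0 = A, 1 = B); the model strips
# it off and adds a per-task embedding to the [CLS] position. Despite the class
# name below, there is still a single shared head.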

class TinyClassifier2Head(TinyClassifier):
    def __init__(self):
        super().__init__()
        self.task_emb = nn.Embedding(2, D)
    def forward(self, x_with_task):
        task = x_with_task[:, 0]
        x    = x_with_task[:, 1:]
        B, T = x.shape
        cls = torch.full((B, 1), self.cls_id, device=x.device, dtype=x.dtype)
        x = torch.cat([cls, x], dim=1)
        pos = torch.arange(x.size(1), device=x.device)
        h = self.tok(x) + self.pos(pos)[None]
        h[:, 0] = h[:, 0] + self.task_emb(task)   # tag the [CLS] with task id
        for blk in self.blocks:
            h = blk(h)
        return self.head(self.ln(h[:, 0]))

torch.manual_seed(0)
model_j = TinyClassifier2Head().to(DEVICE)
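# Note: add_lora freezes task_emb along with the rest of the base, so the task
# tag stays at its random init. It still gives a fixed per-task signal on [CLS].
add_lora(model_j, r=8, alpha=16, target=('qkv', 'proj'))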

xA, yA = trA; xB, yB = trB
x_jt = torch.cat([torch.cat([torch.zeros(len(xA), 1, dtype=torch.long), xA], 1),
                  torch.cat([torch.ones (len(xB), 1, dtype=torch.long), xB], 1)], 0)
y_jt = torch.cat([yA, yB])
perm = torch.randperm(len(x_jt))
x_jt, y_jt = x_jt[perm], y_jt[perm]

xA_v, yA_v = vaA; xB_v, yB_v = vaB
vAj = (torch.cat([torch.zeros(len(xA_v), 1, dtype=torch.long), xA_v], 1), yA_v)
vBj = (torch.cat([torch.ones (len(xB_v), 1, dtype=torch.long), xB_v], 1), yB_v)

print('--- Joint training: one adapter, both tasks (task-id prepended) ---')
_ = train(model_j, (x_jt, y_jt), vAj, epochs=10, lr=5e-3, log=False)
print(f'  joint acc on A: {eval_on(model_j, vAj):.3f}')
print(f'  joint acc on B: {eval_on(model_j, vBj):.3f}')
--- Joint training: one adapter, both tasks (task-id prepended) ---
  joint acc on A: 0.947
  joint acc on B: 0.932

Takeaways

  • LoRA: with r=4 we trained about 3% of the parameters (6,274 of 202,114) and matched the full-finetune accuracy. The frozen base does the heavy lifting; the adapter learns the task-specific delta.
  • QLoRA: quantizing the frozen base to int8 (and to 4-bit nf4 in the real recipe) shrinks the static memory cost. In this demo it did not hurt accuracy, because the trainable parts are still float and can compensate for the quantization error. The pattern is the same as LoRA; only the storage of W changes.
  • Two supervised datasets:
    • Default to separate adapters per task. Keep the base frozen, train each adapter independently. Swap at inference. Tiny disk footprint per task, no interference, fully reversible.
    • Sequential training of one adapter behaves like full finetuning — the second task overwrites the first. Catastrophic forgetting is real with LoRA.
    • Joint training (mix batches, condition on task) gives you one adapter that does both, but you need a slightly larger rank and you need a way to tell the model which task to perform.
  • Where the real win lives. With a 7 B model and r=8 LoRA on QKV+MLP, a single adapter is roughly 18M parameters, or a few tens of MB in fp16 (assuming LLaMA-7B-like shapes). You can ship 100 adapters per base instead of 100 fully-finetuned 14 GB checkpoints.
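
For the adapter-size point in the last bullet, the arithmetic is short enough to write down. The shapes below (hidden size 4096, MLP width 11008, 32 layers) are an assumption modelled on LLaMA-7B; other 7 B models differ, so treat the result as an order-of-magnitude estimate.

```python
def lora_params(shapes, r):
    """Sum r * (d_out + d_in) over the adapted weight matrices."""
    return sum(r * (d_out + d_in) for d_out, d_in in shapes)

# Assumed LLaMA-7B-like shapes (check your own model's config).
HIDDEN, MLP_HIDDEN, N_LAYERS, R = 4096, 11008, 32, 8
attn_qkv = [(HIDDEN, HIDDEN)] * 3
mlp = [(MLP_HIDDEN, HIDDEN), (MLP_HIDDEN, HIDDEN), (HIDDEN, MLP_HIDDEN)]

per_layer = lora_params(attn_qkv + mlp, R)
total = per_layer * N_LAYERS
mb_fp16 = total * 2 / 1e6                 # 2 bytes per parameter
base_gb_fp16 = 7e9 * 2 / 1e9              # the full checkpoint, for comparison

print(f"adapter params: {total:,}  (~{mb_fp16:.0f} MB in fp16)")
print(f"base checkpoint: ~{base_gb_fp16:.0f} GB in fp16")
```

Under these assumptions the adapter comes to about 18M parameters, a few hundred times smaller than the base checkpoint. Adapting only the query and value projections would shrink it further.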

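If you want to play with this on a real model, the same LoRALinear class should wrap any nn.Linear in a HuggingFace transformers model: walk model.named_modules() and replace the attention projections. In LLaMA-style models these are named q_proj, k_proj, v_proj, o_proj. Names differ across architectures, and some (GPT-2, for example) use a Conv1D layer instead of nn.Linear, which this wrapper does not handle as written.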