LoRA (Hu et al., 2021) freezes the pretrained weights W and trains a low-rank update BA so that W' = W + BA where B ∈ ℝ^(d×r), A ∈ ℝ^(r×k), r ≪ min(d,k). Trainable parameter count drops from d·k to r·(d+k) — typically <1% of the original.
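A quick sanity check on that arithmetic (plain Python, no dependencies; the 64×64 shape matches the tiny model built below, the 4096×4096 layer is an illustrative LLM-scale shape):

```python
# Parameter count of a dense d x k layer vs. its rank-r LoRA factors A, B.
def counts(d, k, r):
    return d * k, r * (d + k)

# At this post's tiny scale the ratio is not yet impressive...
full, lora = counts(64, 64, 4)
print(full, lora, f'{100 * lora / full:.2f}%')   # 4096 512 12.50%

# ...but for an LLM-scale 4096x4096 projection it drops below 1%.
full, lora = counts(4096, 4096, 4)
print(full, lora, f'{100 * lora / full:.2f}%')   # 16777216 32768 0.20%
```

The ratio r·(d+k)/(d·k) shrinks as d and k grow, which is why the "<1%" claim only kicks in at real-model widths.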
QLoRA (Dettmers et al., 2023) goes one step further: the frozen base weights are also quantized (originally to the 4-bit NormalFloat, NF4, data type). The LoRA A, B matrices stay in higher precision and absorb the quantization error during training.
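To make "absorb the quantization error" concrete, here is a small standalone sketch of a per-tensor symmetric int8 round trip, the same scheme the QLoRA-style section below implements. The reconstruction error it prints is exactly what the float adapters have to soak up:

```python
import torch
torch.manual_seed(0)

# Per-tensor symmetric int8: one float scale for the whole weight tensor.
w = torch.randn(64, 64)
scale = w.abs().max() / 127.0
q = (w / scale).round().clamp(-127, 127).to(torch.int8)
w_hat = q.float() * scale                 # dequantized reconstruction

# Rounding to the nearest int8 step bounds the error by scale/2 per weight.
err = (w - w_hat).abs()
print(f'max |error| = {err.max():.4f}  (bound scale/2 = {scale.item() / 2:.4f})')
```

During QLoRA training the base stays at `w_hat`; the trainable A, B delta is free to compensate for the `w - w_hat` residual wherever it matters for the task.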
We’ll build a 200K-parameter transformer from scratch, run a synthetic supervised task, and demo:
Full finetune (baseline)
LoRA finetune (frozen base + low-rank adapters)
QLoRA-style (int8 quantized base + LoRA in float32)
A second adapter for a second task — adapter swap at inference, plus the sequential-vs-forgetting story
Everything runs on CPU in under a minute.
Code
import math
import time
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
DEVICE = 'cpu'  # MPS is faster but float64 quirks hurt the int8 demo; cpu is plenty.
VOCAB, MAX_LEN, D, H, LAYERS = 16, 24, 64, 4, 4  # tiny transformer
BSZ = 64

def n_params(m, only_trainable=False):
    return sum(p.numel() for p in m.parameters() if (p.requires_grad or not only_trainable))
1. A 4-layer transformer encoder, ~200K params
Standard pre-norm transformer + a [CLS] token whose final hidden state feeds a binary classifier. Nothing exotic; we want this to be the thing LoRA wraps, not the lesson itself.
Code
class Block(nn.Module):
    def __init__(self, d, h):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.h = h

    def forward(self, x):
        B, T, D = x.shape
        z = self.ln1(x)
        q, k, v = self.qkv(z).chunk(3, dim=-1)
        q = q.view(B, T, self.h, D // self.h).transpose(1, 2)
        k = k.view(B, T, self.h, D // self.h).transpose(1, 2)
        v = v.view(B, T, self.h, D // self.h).transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).contiguous().view(B, T, D)
        x = x + self.proj(attn)
        x = x + self.mlp(self.ln2(x))
        return x

class TinyClassifier(nn.Module):
    def __init__(self, vocab=VOCAB, max_len=MAX_LEN, d=D, h=H, layers=LAYERS, n_classes=2):
        super().__init__()
        self.tok = nn.Embedding(vocab + 1, d)   # +1 for [CLS]
        self.pos = nn.Embedding(max_len + 1, d)
        self.blocks = nn.ModuleList([Block(d, h) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, n_classes)
        self.cls_id = vocab

    def forward(self, x):
        B, T = x.shape
        cls = torch.full((B, 1), self.cls_id, device=x.device, dtype=x.dtype)
        x = torch.cat([cls, x], dim=1)
        pos = torch.arange(x.size(1), device=x.device)
        h = self.tok(x) + self.pos(pos)[None]
        for blk in self.blocks:
            h = blk(h)
        return self.head(self.ln(h[:, 0]))

model = TinyClassifier().to(DEVICE)
print(f'total params: {n_params(model):,} ({n_params(model)/1024:.1f} K)')
total params: 202,114 (197.4 K)
2. Two synthetic supervised tasks
Task A: does the sequence have more odd than even tokens (in [0, VOCAB))?
Task B: is the first token less than the last token?
Same model, same input distribution — different decision rule. Both have ~50/50 class balance so accuracy is meaningful.
Code
def make_data(rule: str, n: int, seed: int = 0):
    g = torch.Generator().manual_seed(seed)
    x = torch.randint(0, VOCAB, (n, MAX_LEN), generator=g)
    if rule == 'odd_majority':
        y = ((x % 2 == 1).sum(dim=1) > MAX_LEN / 2).long()
    elif rule == 'first_lt_last':
        # First token strictly less than last token. Clean ~50/50 signal at
        # the two ends — the model needs positional attention to solve it.
        y = (x[:, 0] < x[:, -1]).long()
    else:
        raise ValueError(rule)
    return x, y

for rule in ['odd_majority', 'first_lt_last']:
    x, y = make_data(rule, 2000)
    print(f'{rule:14s} n={len(y)} positive={y.float().mean():.3f}')
We freeze every parameter, then wrap selected nn.Linears with a LoRALinear. The forward becomes
\[\;\;y = xW^\top + b + \frac{\alpha}{r}\,(xA^\top)B^\top\]
where A is initialized with Kaiming-uniform and B with zeros so the initial update is exactly zero — the model behaves identically to the frozen base on step 0, then learns from there.
Code
class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable rank-r delta."""
    def __init__(self, base: nn.Linear, r: int = 4, alpha: int = 8):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        in_f, out_f = base.in_features, base.out_features
        self.A = nn.Parameter(torch.empty(r, in_f))
        self.B = nn.Parameter(torch.zeros(out_f, r))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale

    def merged_weight(self):
        """Returns W + (alpha/r) BA — useful for inference without the adapter overhead."""
        return self.base.weight + self.scale * (self.B @ self.A)

def add_lora(model, r=4, alpha=8, target=('qkv', 'proj')):
    """Walk the model, replace target Linear layers with LoRALinear, freeze the rest."""
    for p in model.parameters():
        p.requires_grad = False
    # Collect first to avoid mutating the module tree while iterating.
    targets = []
    for mod in model.modules():
        for child_name, child in mod.named_children():
            if isinstance(child, nn.Linear) and child_name in target:
                targets.append((mod, child_name, child))
    for mod, child_name, child in targets:
        setattr(mod, child_name, LoRALinear(child, r=r, alpha=alpha))
    # Always train the classifier head (small, task-specific).
    for p in model.head.parameters():
        p.requires_grad = True
    return model
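A standalone sanity check of the zero-init claim, re-declaring the A/B initialisation outside the class so it runs on its own:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# With B all zeros, the delta (alpha/r) * x A^T B^T is exactly the zero
# tensor, so the wrapped layer matches its frozen base on step 0.
base = nn.Linear(64, 64, bias=False)
r, alpha = 4, 8
A = torch.empty(r, 64)
nn.init.kaiming_uniform_(A, a=math.sqrt(5))
B = torch.zeros(64, r)

x = torch.randn(8, 64)
y_lora = base(x) + (x @ A.T @ B.T) * (alpha / r)
assert torch.equal(y_lora, base(x))   # bit-identical, not just approximately close
```

Only gradient steps on B (driven by the task loss) move the output away from the pretrained behaviour, which is why LoRA training is stable from the first step.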
Code
torch.manual_seed(0)
base_lora = TinyClassifier().to(DEVICE)
add_lora(base_lora, r=4, alpha=8, target=('qkv', 'proj'))

print('=== LoRA finetune on Task A ===')
acc_lora_A = train(base_lora, tr, va, epochs=8, lr=5e-3)
5. QLoRA-style: int8-quantize the frozen base, train LoRA on top
Real QLoRA uses 4-bit nf4 quantization with double-quantization and paged optimizers. We’ll do the cleaner pedagogical version: per-tensor symmetric int8 for the frozen nn.Linear weights, dequantize on the fly during forward. The LoRA A, B stay in float32. The point is the same: the bulky frozen weights live in a low-precision format; only the tiny adapters move during training.
Code
class Int8Linear(nn.Module):
    """Frozen Linear whose weight is stored as int8 + a per-tensor float scale."""
    def __init__(self, base: nn.Linear):
        super().__init__()
        w = base.weight.detach()
        scale = w.abs().max() / 127.0
        self.register_buffer('q', (w / scale).round().clamp(-127, 127).to(torch.int8))
        self.register_buffer('scale', torch.tensor(scale.item(), dtype=torch.float32))
        self.bias = base.bias  # bias stays float (small, no point quantizing)
        self.in_features = base.in_features
        self.out_features = base.out_features

    @property
    def weight(self):
        return self.q.float() * self.scale

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

def quantize_to_int8(model, target=('qkv', 'proj')):
    targets = []
    for mod in model.modules():
        for child_name, child in mod.named_children():
            if isinstance(child, nn.Linear) and child_name in target:
                targets.append((mod, child_name, child))
    for mod, child_name, child in targets:
        setattr(mod, child_name, Int8Linear(child))
    return model

class LoRAOnQuantized(nn.Module):
    """LoRA wrapper that accepts either a Linear or an Int8Linear as its frozen base."""
    def __init__(self, base, r: int = 4, alpha: int = 8):
        super().__init__()
        self.base = base
        for p in base.parameters():
            p.requires_grad = False
        in_f, out_f = base.in_features, base.out_features
        self.A = nn.Parameter(torch.empty(r, in_f))
        self.B = nn.Parameter(torch.zeros(out_f, r))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale

def add_qlora(model, r=4, alpha=8, target=('qkv', 'proj')):
    quantize_to_int8(model, target=target)
    for p in model.parameters():
        p.requires_grad = False
    targets = []
    for mod in model.modules():
        for child_name, child in mod.named_children():
            if isinstance(child, Int8Linear):
                targets.append((mod, child_name, child))
    for mod, child_name, child in targets:
        setattr(mod, child_name, LoRAOnQuantized(child, r=r, alpha=alpha))
    for p in model.head.parameters():
        p.requires_grad = True
    return model
Code
torch.manual_seed(0)
base_qlora = TinyClassifier().to(DEVICE)
# Pretend the freshly-initialised base is a 'pretrained' checkpoint and quantize it.
add_qlora(base_qlora, r=4, alpha=8, target=('qkv', 'proj'))

print('=== QLoRA-style finetune on Task A (int8 base + float LoRA) ===')
acc_qlora_A = train(base_qlora, tr, va, epochs=8, lr=5e-3)
The interesting numbers: trainable parameter count and final accuracy.
Code
def trainable(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

rows = [
    ('Full FT', trainable(base_full), acc_full_A),
    ('LoRA r=4', trainable(base_lora), acc_lora_A),
    ('QLoRA r=4 (int8 base)', trainable(base_qlora), acc_qlora_A),
]
header = f'{"method":24s}{"trainable":>10s}{"% of full":>10s}{"val acc":>8s}'
print(header); print('-' * len(header))
for name, n, acc in rows:
    pct = 100 * n / rows[0][1]
    print(f'{name:24s}{n:10,d}{pct:9.2f}% {acc:8.3f}')
method trainable % of full val acc
----------------------------------------------------------
Full FT 202,114 100.00% 0.980
LoRA r=4 6,274 3.10% 0.984
QLoRA r=4 (int8 base) 6,274 3.10% 0.992
7. Two tasks, two adapters, one base — adapter swap at inference
This is the punchline of LoRA in practice. Train a second adapter on Task B over the same frozen base. At inference, swap the adapter in/out depending on which task you want.
Code
# 1. Freeze a single shared base.
torch.manual_seed(0)
shared_base = TinyClassifier().to(DEVICE)
shared_state = deepcopy(shared_base.state_dict())  # for re-instantiation

# 2. Train Adapter A on Task A.
model_A = TinyClassifier().to(DEVICE); model_A.load_state_dict(shared_state)
add_lora(model_A, r=4, alpha=8, target=('qkv', 'proj'))
trA = make_data('odd_majority', 1024, seed=1); vaA = make_data('odd_majority', 512, seed=2)
print('--- Adapter A (Task A: odd majority) ---')
acc_A_on_A = train(model_A, trA, vaA, epochs=8, lr=5e-3)

# 3. Train Adapter B on Task B from the same base.
model_B = TinyClassifier().to(DEVICE); model_B.load_state_dict(shared_state)
add_lora(model_B, r=4, alpha=8, target=('qkv', 'proj'))
trB = make_data('first_lt_last', 1024, seed=3); vaB = make_data('first_lt_last', 512, seed=4)
print('--- Adapter B (Task B: first<last) ---')
acc_B_on_B = train(model_B, trB, vaB, epochs=8, lr=5e-3)

# 4. Cross-eval — does Adapter A do anything sensible on Task B and vice versa?
def eval_on(m, val):
    xv, yv = (t.to(DEVICE) for t in val)
    with torch.no_grad():
        return (m(xv).argmax(-1) == yv).float().mean().item()

print('\n=== Cross-evaluation ===')
print(f'Adapter A on Task A: {acc_A_on_A:.3f}')
print(f'Adapter A on Task B: {eval_on(model_A, vaB):.3f} (should be ~chance, 0.5)')
print(f'Adapter B on Task B: {acc_B_on_B:.3f}')
print(f'Adapter B on Task A: {eval_on(model_B, vaA):.3f} (should be ~chance, 0.5)')
--- Adapter A (Task A: odd majority) ---
trainable: 6,274 / 208,258 (3.01%)
epoch 1 loss=0.528 val_acc=0.787
epoch 2 loss=0.228 val_acc=0.938
epoch 3 loss=0.015 val_acc=0.990
epoch 4 loss=0.008 val_acc=0.963
epoch 5 loss=0.029 val_acc=0.939
epoch 6 loss=0.043 val_acc=0.994
epoch 7 loss=0.033 val_acc=0.980
epoch 8 loss=0.016 val_acc=0.961
done in 2.7s
--- Adapter B (Task B: first<last) ---
trainable: 6,274 / 208,258 (3.01%)
epoch 1 loss=0.721 val_acc=0.547
epoch 2 loss=0.699 val_acc=0.547
epoch 3 loss=0.605 val_acc=0.645
epoch 4 loss=0.426 val_acc=0.793
epoch 5 loss=0.115 val_acc=0.906
epoch 6 loss=0.168 val_acc=0.887
epoch 7 loss=0.129 val_acc=0.941
epoch 8 loss=0.072 val_acc=0.947
done in 2.5s
=== Cross-evaluation ===
Adapter A on Task A: 0.961
Adapter A on Task B: 0.523 (should be ~chance, 0.5)
Adapter B on Task B: 0.947
Adapter B on Task A: 0.506 (should be ~chance, 0.5)
Each adapter only solves the task it was trained on, as expected — they’re tiny task-specific deltas over a shared base.
8. “What if I train LoRA on Task A then continue on Task B?”
Three qualitatively different ways to handle a second supervised dataset. All are useful; people pick based on whether they need one model that does both tasks or separate specialised models.
Option 1 — separate adapters (recommended default)
Keep the base frozen. Train one LoRA per task. At inference, swap the adapter:
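A minimal sketch of what the swap looks like, assuming the `add_lora` setup above where only the LoRA `A`/`B` matrices and the classifier head are task-specific. `adapter_state` and `load_adapter` are hypothetical helper names, not a library API:

```python
import torch
import torch.nn as nn

def adapter_state(model):
    # An "adapter" is just the task-specific slice of the state_dict:
    # the LoRA factors (keys ending in .A / .B) plus the classifier head.
    return {k: v.clone() for k, v in model.state_dict().items()
            if k.endswith('.A') or k.endswith('.B') or k.startswith('head.')}

def load_adapter(model, adapter):
    # strict=False: the adapter covers only a subset of the full state_dict,
    # the shared frozen base keys are left untouched.
    model.load_state_dict(adapter, strict=False)

# Swap pattern over one shared base:
#   ad_A = adapter_state(model_A)     # a few KB per task
#   ad_B = adapter_state(model_B)
#   load_adapter(shared, ad_A)        # ...run Task A batches...
#   load_adapter(shared, ad_B)        # ...run Task B batches...
```

Saving `adapter_state(model)` with `torch.save` gives you the "tiny file per task" deployment story: one base checkpoint on disk, one small dict per task.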
Memory cost is one base + N small adapters (each <1% of base). Inference cost is unchanged. This is what we just did above. No interference.
Option 2 — sequential training on the same adapter
Continue training Adapter A on Task B. Below: train A, eval, then continue training A on B, eval again.
Code
torch.manual_seed(0)
model_seq = TinyClassifier().to(DEVICE); model_seq.load_state_dict(shared_state)
add_lora(model_seq, r=4, alpha=8, target=('qkv', 'proj'))

print('--- Step 1: train shared adapter on Task A ---')
_ = train(model_seq, trA, vaA, epochs=6, lr=5e-3, log=False)
a1, b1 = eval_on(model_seq, vaA), eval_on(model_seq, vaB)
print(f'  after A: acc(A)={a1:.3f}  acc(B)={b1:.3f}')

print('--- Step 2: continue training the SAME adapter on Task B ---')
_ = train(model_seq, trB, vaB, epochs=12, lr=5e-3, log=False)
a2, b2 = eval_on(model_seq, vaA), eval_on(model_seq, vaB)
print(f'  after B: acc(A)={a2:.3f}  acc(B)={b2:.3f}  ← Task A drops: catastrophic forgetting')
--- Step 1: train shared adapter on Task A ---
after A: acc(A)=0.998 acc(B)=0.514
--- Step 2: continue training the SAME adapter on Task B ---
after B: acc(A)=0.490 acc(B)=0.930 ← Task A drops: catastrophic forgetting
The adapter that was good at Task A drifts toward Task B and forgets A. This is the same catastrophic-forgetting story you get with full finetuning; LoRA does nothing magical to prevent it.
Option 3 — joint training (multitask)
If you want one model that does both, mix A and B batches (with a task ID feature, or two heads) and train one adapter on the joint distribution. No forgetting, but the adapter has to be a bit larger to hold both behaviours.
Code
# Joint training: stack the two tasks. We need a way to tell A from B at the
# input layer — the cleanest minimal hack is a single "task token" prepended.
class TinyClassifier2Head(TinyClassifier):
    def __init__(self):
        super().__init__()
        self.task_emb = nn.Embedding(2, D)

    def forward(self, x_with_task):
        task = x_with_task[:, 0]
        x = x_with_task[:, 1:]
        B, T = x.shape
        cls = torch.full((B, 1), self.cls_id, device=x.device, dtype=x.dtype)
        x = torch.cat([cls, x], dim=1)
        pos = torch.arange(x.size(1), device=x.device)
        h = self.tok(x) + self.pos(pos)[None]
        h[:, 0] = h[:, 0] + self.task_emb(task)  # tag the [CLS] with task id
        for blk in self.blocks:
            h = blk(h)
        return self.head(self.ln(h[:, 0]))

torch.manual_seed(0)
model_j = TinyClassifier2Head().to(DEVICE)
add_lora(model_j, r=8, alpha=16, target=('qkv', 'proj'))

xA, yA = trA; xB, yB = trB
x_jt = torch.cat([torch.cat([torch.zeros(len(xA), 1, dtype=torch.long), xA], 1),
                  torch.cat([torch.ones (len(xB), 1, dtype=torch.long), xB], 1)], 0)
y_jt = torch.cat([yA, yB])
perm = torch.randperm(len(x_jt))
x_jt, y_jt = x_jt[perm], y_jt[perm]

xA_v, yA_v = vaA; xB_v, yB_v = vaB
vAj = (torch.cat([torch.zeros(len(xA_v), 1, dtype=torch.long), xA_v], 1), yA_v)
vBj = (torch.cat([torch.ones (len(xB_v), 1, dtype=torch.long), xB_v], 1), yB_v)

print('--- Joint training: one adapter, both tasks (task-id prepended) ---')
_ = train(model_j, (x_jt, y_jt), vAj, epochs=10, lr=5e-3, log=False)
print(f'  joint acc on A: {eval_on(model_j, vAj):.3f}')
print(f'  joint acc on B: {eval_on(model_j, vBj):.3f}')
--- Joint training: one adapter, both tasks (task-id prepended) ---
joint acc on A: 0.947
joint acc on B: 0.932
Takeaways
LoRA: with r=4 we trained ~3% of the parameters and matched the full-finetune accuracy. The frozen base does the heavy lifting; the adapter learns the task-specific delta.
QLoRA: quantizing the frozen base to int8 (and to 4-bit nf4 in the real recipe) shrinks the static memory cost without hurting accuracy, because the moving parts are still float. The pedagogical pattern is the same as LoRA — only the storage of W changes.
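Back-of-envelope for this model: the wrapped layers are `qkv` (64→192) and `proj` (64→64) in each of the 4 blocks, and per-tensor int8 storage is roughly a 4× saving on exactly those tensors:

```python
# Static storage for the wrapped weight tensors in this post's model.
n_wrapped = 4 * (64 * 192 + 64 * 64)     # 65,536 weights across 8 tensors
fp32_bytes = n_wrapped * 4               # float32: 4 bytes per weight
int8_bytes = n_wrapped * 1 + 4 * 2 * 4   # int8 + one float32 scale per tensor
print(fp32_bytes, int8_bytes, f'{fp32_bytes / int8_bytes:.2f}x')  # 262144 65568 4.00x
```

At this toy scale that saves ~200 KB; the identical arithmetic at 7B scale is what turns "doesn't fit on my GPU" into "fits".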
Two supervised datasets:
Default to separate adapters per task. Keep the base frozen, train each adapter independently. Swap at inference. Tiny disk footprint per task, no interference, fully reversible.
Sequential training of one adapter behaves like full finetuning — the second task overwrites the first. Catastrophic forgetting is real with LoRA.
Joint training (mix batches, condition on task) gives you one adapter that does both, but you need a slightly larger rank and you need a way to tell the model which task to perform.
Where the real win lives. With a 7 B model and r=8 LoRA on the attention and MLP projections, a single adapter is on the order of tens of megabytes. You can ship 100 adapters per base instead of 100 fully-finetuned 14 GB checkpoints.
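The exact figure depends heavily on the rank and on which matrices you wrap; a sketch of the arithmetic with illustrative Llama-7B-ish shapes (hidden 4096, MLP 11008, 32 layers; assumptions, not measured numbers):

```python
# Adapter size = sum over wrapped matrices of r * (d_in + d_out) params.
d, ffn, layers = 4096, 11008, 32

def lora_params(r, pairs):
    """pairs: (d_in, d_out) for each wrapped matrix in one layer."""
    return layers * sum(r * (i + o) for i, o in pairs)

attn = [(d, d)] * 4            # q, k, v, o projections
mlp = [(d, ffn), (ffn, d)]     # up and down projections
p = lora_params(8, attn + mlp)
print(f'{p:,} params -> {p * 2 / 1e6:.1f} MB at fp16')  # 16,121,856 params -> 32.2 MB at fp16
```

Wrap fewer matrices (e.g. q and v only, as in the original LoRA paper) or lower the rank and the adapter shrinks proportionally; either way it is four orders of magnitude smaller than the base checkpoint.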
If you want to play with this on a real model, the same LoRALinear class above swaps in unchanged for any nn.Linear in HuggingFace transformers — just walk model.named_modules() and replace q_proj, k_proj, v_proj, o_proj.
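A hedged sketch of that walk, assuming Llama-style projection names (`q_proj`/`k_proj`/`v_proj`/`o_proj`; other architectures use different names). `add_lora_hf` is a hypothetical helper, and `wrap` would be something like `lambda lin: LoRALinear(lin, r=8, alpha=16)` using the class above:

```python
import torch.nn as nn

TARGETS = {'q_proj', 'k_proj', 'v_proj', 'o_proj'}

def add_lora_hf(model, wrap, targets=TARGETS):
    """Replace every target nn.Linear in `model` with wrap(linear)."""
    to_swap = []
    # Collect first, mutate after: same pattern as add_lora above, since
    # swapping children while iterating the module tree is unsafe.
    for mod in model.modules():
        for name, child in mod.named_children():
            if isinstance(child, nn.Linear) and name in targets:
                to_swap.append((mod, name, child))
    for mod, name, child in to_swap:
        setattr(mod, name, wrap(child))
    return model
```

For production use, the Hugging Face PEFT library packages this exact pattern (plus saving/loading and merging) behind a config object, so hand-rolling it is mostly useful for understanding what PEFT does.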