Break 00 — Force GQA to kv_heads = 1 (extreme MQA collapse)¶
We force kv_heads = 1 in the GQA attention block: Mini-GPT's 4 heads share a single key/value pair. The KV cache shrinks 4× relative to full MHA, but all heads attend with the same K/V, so the head diversity we trained in Phase 15 collapses. PPL jumps and samples become repetitive.
This /break exercise targets the quality vs memory trade-off in GQA. The bug is one number; the failure is a measurable PPL hit plus a qualitative loss of head diversity.
Anchors: theory/04-gqa-mqa-mla.md, theory/05-flash-walkthrough-and-gqa-math.md, .claude/commands/break.md.
Hypothesis¶
The learner predicts: "Setting kv_heads = 1 (MQA) on a model trained with n_heads = 4 removes 3 of the 4 K, V projections. All four query heads now share one K, V. The model loses its ability to attend to different positions per head; the only diversity left is in Q and W_O. PPL goes up; sample diversity goes down."
The break¶
In src/minimodel/blocks/attention_gqa.py:
 class GQABlock(Module):
-    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 2):
+    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 1):  # /break: extreme MQA
         super().__init__()
         self.n_heads = n_heads
         self.kv_heads = kv_heads
         self.d_h = d_model // n_heads
         self.W_q = Linear(d_model, n_heads * self.d_h, bias=False)
         self.W_k = Linear(d_model, kv_heads * self.d_h, bias=False)
         self.W_v = Linear(d_model, kv_heads * self.d_h, bias=False)
         self.W_o = Linear(n_heads * self.d_h, d_model, bias=False)
One number changed: kv_heads: 2 → 1. The forward already handles the K, V broadcast across query groups; no other code edit is needed. This is the cleanest possible break for this concept.
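The K, V broadcast comes down to an index mapping from query heads to K, V heads. The following is a minimal sketch, assuming the contiguous grouping convention common in GQA implementations; the actual forward in attention_gqa.py is not shown here and may differ, and `kv_head_for_query` and `kv_assignment` are hypothetical helper names.

```python
def kv_head_for_query(q_head: int, n_heads: int, kv_heads: int) -> int:
    """Return the index of the K, V head that query head `q_head` reads from.

    Assumes contiguous groups: with n_heads = 4 and kv_heads = 2,
    query heads 0-1 share K, V head 0 and query heads 2-3 share head 1.
    """
    if n_heads % kv_heads != 0:
        raise ValueError("n_heads must be divisible by kv_heads")
    group_size = n_heads // kv_heads  # query heads per K, V head
    return q_head // group_size


def kv_assignment(n_heads: int, kv_heads: int) -> list[int]:
    """K, V head index used by each query head, in query-head order."""
    return [kv_head_for_query(q, n_heads, kv_heads) for q in range(n_heads)]


if __name__ == "__main__":
    for kv_heads in (4, 2, 1):  # MHA, GQA-2, MQA
        print(kv_heads, kv_assignment(4, kv_heads))
```

With kv_heads = 1 every query head maps to index 0, which is the collapse this break is designed to expose.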
Predict, then run¶
For Mini-GPT (n_heads = 4), the per-token KV cache is 2·n_layers·kv_heads·d_h·4 bytes, where the leading 2 counts K and V and the trailing 4 is the bytes per value (fp32). It drops from 2·n_layers·2·d_h·4 = 1024 B (GQA-2) to 2·n_layers·1·d_h·4 = 512 B (MQA, kv_heads = 1), a further 2× memory saving. But all 4 query heads now share a single K, V, so any difference between their attention patterns has to come from the Q projections alone.
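The byte counts can be checked with a few lines of arithmetic. This is a sketch under an assumption: the text only implies n_layers · d_h = 64, so the split into 4 layers of head dimension 16 below is a hypothetical choice, and `kv_cache_bytes_per_token` is an illustrative name.

```python
def kv_cache_bytes_per_token(n_layers: int, kv_heads: int, d_h: int,
                             bytes_per_value: int = 4) -> int:
    """Bytes of KV cache added per generated token.

    The leading 2 counts the K and the V tensor; bytes_per_value = 4
    corresponds to fp32.
    """
    return 2 * n_layers * kv_heads * d_h * bytes_per_value


# Hypothetical shape: only the product n_layers * d_h = 64 is implied by the
# figures in the text; 4 layers x 16 dims is an assumption for illustration.
N_LAYERS, D_H = 4, 16

mha = kv_cache_bytes_per_token(N_LAYERS, kv_heads=4, d_h=D_H)
gqa2 = kv_cache_bytes_per_token(N_LAYERS, kv_heads=2, d_h=D_H)
mqa = kv_cache_bytes_per_token(N_LAYERS, kv_heads=1, d_h=D_H)
print(mha, gqa2, mqa)
```

Because the cache size is linear in kv_heads, the ratios (2× from GQA-2 to MQA, 4× from MHA to MQA) hold for any layer count and head dimension.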
Predictions¶
- PPL on §A13 eval set: noticeably higher than the GQA-2 baseline. Rough estimate: GQA-2 ≈ 5.20; MQA ≈ 5.55 (a 7% PPL hit, vs ~1-2% from GQA-2 over MHA — the extreme collapse is much worse than the moderate one).
- Sample diversity: same prompt sampled 10 times produces 6-7 of the same continuation. The model becomes much more "confident" in a degenerate way because most query heads see the same attention pattern.
- Per-head attention entropy: the attention entropies of the 4 heads in block 1 converge to within ~50% of one another (vs. widely varying entropies in the GQA-2 baseline).
- KV-cache bytes: exactly half of the GQA-2 baseline. The memory win is real and not in dispute.
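The 7% figure in the PPL prediction is the relative increase between the two rough estimates. A small sketch of that arithmetic, plus the usual definition of perplexity as the exponential of the mean per-token negative log-likelihood (function names are illustrative, not from the repo):

```python
import math


def perplexity(mean_nll: float) -> float:
    """Perplexity from the mean per-token negative log-likelihood (in nats)."""
    return math.exp(mean_nll)


def relative_ppl_hit(baseline_ppl: float, broken_ppl: float) -> float:
    """Fractional PPL increase of the broken config over the baseline."""
    return (broken_ppl - baseline_ppl) / baseline_ppl


# The two rough estimates quoted in the prediction (not measurements).
hit = relative_ppl_hit(baseline_ppl=5.20, broken_ppl=5.55)
print(f"{hit:.1%}")
```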
Write predictions in learners/borja/phase-27/notes/breaks.md before running.
Observe¶
Run the Phase 27 eval with the broken config:
Diagnostics to plot:
- PPL bar chart: mha (baseline), gqa-2 (Phase 27 default), mqa-kv1 (this break). Three bars; rightmost should be highest.
- Per-head attention entropy at layer 1: 4-line plot, one per query head. In MHA/GQA-2 the lines diverge; in MQA the 4 lines collapse onto each other after the first few positions (because they share K).
- Sample 10 continuations of "I work" at temperature 0.8; count distinct outputs. Baseline ≈ 7; broken ≈ 3.
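For the entropy diagnostic, each point on a line is the Shannon entropy of one query head's attention distribution at one position. A minimal sketch of that computation, assuming the attention weights are available as plain lists of probabilities (how they are extracted from the model is not covered here, and the helper names are hypothetical):

```python
import math


def attention_entropy(weights: list[float]) -> float:
    """Shannon entropy (nats) of one attention distribution over positions."""
    if not math.isclose(sum(weights), 1.0, rel_tol=1e-6):
        raise ValueError("attention weights must sum to 1")
    # Zero-probability positions contribute nothing (0 * log 0 := 0).
    return -sum(p * math.log(p) for p in weights if p > 0.0)


def entropy_spread(per_head_weights: list[list[float]]) -> float:
    """Max minus min entropy across heads at one query position.

    A spread near 0 means the heads are indistinguishable by this metric.
    """
    entropies = [attention_entropy(w) for w in per_head_weights]
    return max(entropies) - min(entropies)


# Hand-made toy distributions, not model output: one diffuse head, one peaked.
toy_heads = [[0.25, 0.25, 0.25, 0.25], [1.0, 0.0, 0.0, 0.0]]
print(entropy_spread(toy_heads))
```

Plotting `attention_entropy` per head against position gives the 4-line plot; `entropy_spread` reduces it to one number per position if a single curve is easier to compare across configs.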
Symptom Borja will see¶
- PPL jumps by 5-10% on §A13 eval set.
- Sample diversity collapses (most continuations identical).
- KV-cache bytes drop by exactly 2× compared to GQA-2 (the memory win is real — that's the trade-off being illustrated).
- No training-time crash; this is an inference-time architectural change.
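The diversity symptom can be turned into a number by counting distinct continuations. A small sketch, assuming the sampled continuations have already been collected as strings; `diversity_report` is a hypothetical helper and the sample strings below are placeholders, not model output.

```python
from collections import Counter


def diversity_report(samples: list[str]) -> dict:
    """Count distinct continuations and how dominant the most common one is."""
    if not samples:
        raise ValueError("need at least one sample")
    counts = Counter(s.strip() for s in samples)  # ignore stray whitespace
    _, top_count = counts.most_common(1)[0]
    return {
        "distinct": len(counts),
        "most_common_count": top_count,
        "most_common_share": top_count / len(samples),
    }


# Placeholder strings standing in for 10 sampled continuations.
placeholder = ["continuation A"] * 7 + ["continuation B"] * 2 + ["continuation C"]
print(diversity_report(placeholder))
```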
Hidden cause (one sentence)¶
Setting kv_heads = 1 makes all n_heads query heads share a single K, V projection, eliminating per-head attention pattern diversity and collapsing the model to a single effective head's worth of routing capacity.
Hint cascade¶
- Plot attention entropy per head at layer 1. Are the heads still distinguishable?
- Print self.W_k.weight.shape. How many independent key projections does the model actually have?
- Re-read theory/04-gqa-mqa-mla.md §"GQA quality scaling". What do Ainslie et al. find as the safe kv_heads / n_heads ratio? What value did you set?
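The second hint can be read straight off the weight shape. A sketch, assuming `Linear` follows the PyTorch convention of storing weights as (out_features, in_features); the values d_model = 64 and d_h = 16 are hypothetical and only serve the example.

```python
def independent_kv_projections(weight_shape: tuple[int, int], d_h: int) -> int:
    """Number of independent key (or value) heads implied by a weight shape.

    Assumes the (out_features, in_features) layout, so out_features equals
    kv_heads * d_h as in the constructor shown in the break.
    """
    out_features, _in_features = weight_shape
    if out_features % d_h != 0:
        raise ValueError("out_features is not a multiple of d_h")
    return out_features // d_h


# Hypothetical sizes: d_model = 64, n_heads = 4, so d_h = 16.
print(independent_kv_projections((16, 64), d_h=16))  # shape under kv_heads = 1
print(independent_kv_projections((32, 64), d_h=16))  # shape under kv_heads = 2
```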
Fix diff¶
 class GQABlock(Module):
-    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 1):  # /break: extreme MQA
+    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 2):  # restored: GQA-2 is the default
Or — if exploring further — re-run with kv_heads = 4 (full MHA) and kv_heads = 2 (GQA-2) to chart the actual Pareto frontier on Mini-GPT.
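Once each config has a measured (KV-cache bytes, PPL) pair, the frontier is the set of configs that no other config beats on both axes. A minimal sketch of that filter; `Config` and `pareto_front` are hypothetical names, and the numbers in the demo are placeholders chosen to exercise the code, not measurements.

```python
from typing import NamedTuple


class Config(NamedTuple):
    name: str
    kv_cache_bytes: int  # lower is better
    ppl: float           # lower is better


def pareto_front(configs: list[Config]) -> list[Config]:
    """Configs that are not dominated, sorted by ascending cache size.

    A config is dominated if another one is at least as good on both axes
    and strictly better on at least one.
    """
    front = []
    for c in configs:
        dominated = any(
            o.kv_cache_bytes <= c.kv_cache_bytes and o.ppl <= c.ppl
            and (o.kv_cache_bytes < c.kv_cache_bytes or o.ppl < c.ppl)
            for o in configs
        )
        if not dominated:
            front.append(c)
    return sorted(front, key=lambda c: c.kv_cache_bytes)


# Placeholder points only; replace with the numbers from your own runs.
demo = [
    Config("cfg-a", 2048, 5.0),
    Config("cfg-b", 1024, 5.2),
    Config("cfg-c", 512, 5.5),
    Config("cfg-d", 1024, 5.6),  # same memory as cfg-b but worse PPL
]
print([c.name for c in pareto_front(demo)])
```

If memory and PPL trade off monotonically across kv_heads = 4, 2, 1, all three configs sit on the frontier and the choice is a question of which cost matters more.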
Why this teaches the concept¶
The original GQA paper (Ainslie et al., 2023) reports that kv_heads = n_heads / 8 is the safe operating point; below that, quality degrades visibly. Mini-GPT is too small to hit that point exactly: with n_heads = 4, n_heads / 8 = 0.5 is not a valid head count, so GQA-2 (kv_heads / n_heads = 2/4 = 0.5) is the moderate setting and kv_heads = 1 is the hard floor, where every query head shares one K, V. Pushing to kv_heads = 1 goes to that extreme on purpose. This is the kind of trade-off where the abstract "MQA saves memory" line in a survey paper hides a real quality cost. Seeing the PPL number jump on a model you trained yourself makes the trade-off concrete in a way that no survey can.
Reference¶
- Ainslie et al., GQA (arXiv:2305.13245), Figure 3: quality vs kv_heads / n_heads.
- Shazeer, Fast Transformer Decoding: One Write-Head is All You Need (arXiv:1911.02150), the original MQA paper; reports a small but real quality hit on its translation and language-modeling benchmarks.
Next: restore GQA-2 and run lab/04-mqa-gqa.md for the full ablation.