
Break 00 — Force GQA to kv_heads = 1 (extreme MQA collapse)

We force kv_heads = 1 in the GQA attention block: Mini-GPT's 4 heads share a single key/value. The KV cache shrinks 4× relative to full MHA, but every head attends with the same K/V, so the head diversity we trained in Phase 15 collapses. PPL jumps and samples turn repetitive.

This /break exercise targets the quality vs memory trade-off in GQA. The bug is one number; the failure is a measurable PPL hit plus a qualitative loss of head diversity.

Anchors: theory/04-gqa-mqa-mla.md, theory/05-flash-walkthrough-and-gqa-math.md, .claude/commands/break.md.


Hypothesis

The learner predicts: "Setting kv_heads = 1 (MQA) on a model trained with n_heads = 4 removes 3 of the 4 K, V projections. All four query heads now share one K, V. The model loses its ability to attend to different positions per head; the only diversity left is in Q and W_O. PPL goes up; sample diversity goes down."

The break

In src/minimodel/blocks/attention_gqa.py:

 class GQABlock(Module):
-    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 2):
+    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 1):  # /break: extreme MQA
         super().__init__()
         self.n_heads = n_heads
         self.kv_heads = kv_heads
         self.d_h = d_model // n_heads
         self.W_q = Linear(d_model, n_heads * self.d_h, bias=False)
         self.W_k = Linear(d_model, kv_heads * self.d_h, bias=False)
         self.W_v = Linear(d_model, kv_heads * self.d_h, bias=False)
         self.W_o = Linear(n_heads * self.d_h, d_model, bias=False)

One number changed: kv_heads: 2 → 1. The forward already handles the K, V broadcast across query groups; no other code edit is needed. This is the cleanest possible break for this concept.
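The K, V broadcast that the forward relies on can be sketched in plain Python. This is a minimal single-sequence sketch, not the code in attention_gqa.py. The helpers `kv_head_for` and `gqa_attention` are hypothetical, and tensors are nested lists shaped [heads][seq][d_h]. Query head h reads the K, V of group h // (n_heads // kv_heads), so with kv_heads = 1 every query head scores against the same keys.

```python
import math


def kv_head_for(q_head: int, n_heads: int, kv_heads: int) -> int:
    """Return the index of the K/V head that query head `q_head` reads (hypothetical helper)."""
    assert n_heads % kv_heads == 0, "n_heads must be divisible by kv_heads"
    group_size = n_heads // kv_heads  # query heads sharing one K/V head
    return q_head // group_size


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gqa_attention(q, k, v, n_heads: int, kv_heads: int):
    """Causal grouped-query attention over one sequence.

    q: [n_heads][T][d_h]; k, v: [kv_heads][T][d_h] (nested lists).
    Returns (out [n_heads][T][d_h], weights [n_heads][T][T]).
    """
    T, d_h = len(q[0]), len(q[0][0])
    out, weights = [], []
    for h in range(n_heads):
        g = kv_head_for(h, n_heads, kv_heads)  # the K/V "broadcast": shared group index
        head_out, head_w = [], []
        for t in range(T):
            # Causal mask: position t scores keys 0..t only.
            scores = [
                sum(a * b for a, b in zip(q[h][t], k[g][s])) / math.sqrt(d_h)
                for s in range(t + 1)
            ]
            w = softmax(scores)
            head_w.append(w + [0.0] * (T - t - 1))  # zero weight on masked positions
            head_out.append(
                [sum(w[s] * v[g][s][i] for s in range(t + 1)) for i in range(d_h)]
            )
        out.append(head_out)
        weights.append(head_w)
    return out, weights


print([kv_head_for(h, 4, 2) for h in range(4)])  # GQA-2 → [0, 0, 1, 1]
print([kv_head_for(h, 4, 1) for h in range(4)])  # MQA   → [0, 0, 0, 0]
```

With kv_heads = 1, two heads with different queries still produce different weights. Any remaining diversity therefore comes from W_q alone.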

Predict, then run

For Mini-GPT (n_heads = 4, kv_heads = 1 = MQA), the per-token KV cache is 2 (one K and one V) · n_layers · kv_heads · d_h · 4 bytes per value (fp32). It drops from 2·n_layers·2·d_h·4 = 1024 B with GQA-2 to 2·n_layers·1·d_h·4 = 512 B with MQA. That is half the GQA-2 cache and a quarter of full MHA. But all 4 query heads now share a single K, V, so their attention patterns can differ only through their query projections.
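The cache formula can be checked with a small function. This sketch assumes the trailing factor 4 is bytes per fp32 value. Mini-GPT's actual n_layers and d_h are not stated here, so N_LAYERS = 4 and D_H = 16 are hypothetical placeholders. They were chosen only because their product, 64, reproduces the 1024 B GQA-2 figure.

```python
def kv_cache_bytes_per_token(n_layers: int, kv_heads: int, d_h: int,
                             bytes_per_value: int = 4) -> int:
    """Bytes of K and V cached per generated token.

    Factor 2 = one K vector and one V vector per K/V head per layer.
    bytes_per_value = 4 assumes fp32 storage.
    """
    return 2 * n_layers * kv_heads * d_h * bytes_per_value


N_LAYERS, D_H = 4, 16  # hypothetical: any pair with N_LAYERS * D_H == 64 gives 1024 B for GQA-2
for name, kv in [("mha", 4), ("gqa-2", 2), ("mqa-kv1", 1)]:
    print(name, kv_cache_bytes_per_token(N_LAYERS, kv, D_H))
# mha 2048
# gqa-2 1024
# mqa-kv1 512
```

The ratios between configurations depend only on kv_heads, so the 2× (vs GQA-2) and 4× (vs MHA) savings hold whatever the real n_layers and d_h are.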

Predictions

  • PPL on the §A13 eval set: noticeably higher than the GQA-2 baseline. Rough estimate: GQA-2 ≈ 5.20; MQA ≈ 5.55 (a ~7% PPL hit, vs ~1-2% for GQA-2 over MHA; the extreme collapse is much worse than the moderate one).
  • Sample diversity: sampling the same prompt 10 times yields the same continuation 6-7 times. The model becomes much more "confident" in a degenerate way because most query heads see the same attention pattern.
  • Per-head attention entropy: at layer 1, the entropies of the 4 heads' attention distributions converge to within ~50% of a common value (vs widely varied entropies in the GQA-2 baseline).
  • KV-cache bytes: exactly half of the GQA-2 baseline. The memory win is real and not in dispute.
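Perplexity is the exponential of the mean per-token negative log-likelihood. The "~7% hit" is the relative increase (MQA − GQA-2) / GQA-2. Below is a minimal sketch of both. The helper names are hypothetical, and the 5.20 / 5.55 inputs are the predictions above, not measured results.

```python
import math


def perplexity(token_nlls):
    """PPL = exp(mean negative log-likelihood per token, in nats)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


def relative_increase(baseline: float, variant: float) -> float:
    """Fractional change of variant over baseline (0.07 means +7%)."""
    return (variant - baseline) / baseline


# Predicted values from the list above, not measured results.
print(f"{relative_increase(5.20, 5.55):.1%}")  # → 6.7%
```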

Write predictions in learners/borja/phase-27/notes/breaks.md before running.

Observe

Run the Phase 27 eval with the broken config:

just exp 27-gqa --variant mqa-kv1

Diagnostics to plot:

  1. PPL bar chart: mha (baseline), gqa-2 (Phase 27 default), mqa-kv1 (this break). Three bars; rightmost should be highest.
  2. Per-head attention entropy at layer 1: 4-line plot, one per query head. In MHA/GQA-2 the lines diverge; in MQA the 4 lines collapse onto each other after the first few positions (because they share K).
  3. Sample 10 continuations of "I work" at temperature 0.8; count distinct outputs. Baseline ≈ 7; broken ≈ 3.
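Diagnostics 2 and 3 reduce to a few lines once you have the layer-1 attention matrices and the sampled strings. How those are pulled out of Mini-GPT (a hook, a return flag) is not shown here and is an assumption. The helper names below are hypothetical. Each head's curve is its attention entropy (in nats) at every query position. A per-position spread near zero means the heads' curves have collapsed onto each other.

```python
import math


def row_entropy(probs):
    """Shannon entropy (nats) of one attention row; masked zeros contribute 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def entropy_by_position(weights):
    """One head's entropy curve; `weights` is its [T][T] causal attention matrix."""
    return [row_entropy(row) for row in weights]


def head_spread(curves):
    """Per position, max minus min entropy across heads (near 0 = heads look alike)."""
    return [max(col) - min(col) for col in zip(*curves)]


def distinct_count(samples):
    """How many different continuations appear among N samples of one prompt."""
    return len(set(samples))
```

Plot `entropy_by_position` once per query head to get the 4-line chart. `distinct_count` over the 10 "I work" samples gives the number compared against the baseline.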

Symptom Borja will see

  • PPL jumps by 5-10% on the §A13 eval set.
  • Sample diversity collapses (most continuations identical).
  • KV-cache bytes drop by exactly 2× compared to GQA-2 (the memory win is real; that's the trade-off being illustrated).
  • No crash in training or evaluation. The change is silent and surfaces only in these metrics and in the smaller inference-time KV cache.

Hidden cause (one sentence)

Setting kv_heads = 1 makes all n_heads query heads share a single K, V projection, so the heads can differ only through their query projections, which erodes per-head attention-pattern diversity and leaves the model a single head's worth of key/value capacity.

Hint cascade

  1. Plot attention entropy per head at layer 1. Are the heads still distinguishable?
  2. Print self.W_k.weight.shape. How many independent key projections does the model actually have?
  3. Re-read theory/04-gqa-mqa-mla.md §"GQA quality scaling". What kv_heads / n_heads ratio do Ainslie et al. use, and on a model with how many query heads? Does that ratio transfer to Mini-GPT's 4 heads, and what value did you set?
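For hint 2, the number of independent key projections follows from the weight shape. This sketch assumes minimodel's `Linear` stores `weight` in the PyTorch-style (out_features, in_features) layout, i.e. (kv_heads · d_h, d_model). Check that before trusting it. `independent_kv_heads` is a hypothetical helper, and d_model = 64 in the checks is a placeholder.

```python
def independent_kv_heads(w_k_shape, d_model: int, n_heads: int) -> int:
    """Infer kv_heads from W_k's weight shape.

    Assumes an (out_features, in_features) layout, i.e. (kv_heads * d_h, d_model).
    """
    out_features, in_features = w_k_shape
    assert in_features == d_model, "unexpected layout: expected (out, in)"
    d_h = d_model // n_heads
    assert out_features % d_h == 0, "out_features is not a multiple of d_h"
    return out_features // d_h


# Hypothetical d_model = 64, n_heads = 4, so d_h = 16.
print(independent_kv_heads((16, 64), d_model=64, n_heads=4))  # MQA   → 1
print(independent_kv_heads((32, 64), d_model=64, n_heads=4))  # GQA-2 → 2
```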

Fix diff

 class GQABlock(Module):
-    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 1):
+    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 2):  # restored: GQA-2 is the default

Or — if exploring further — re-run with kv_heads = 4 (full MHA) and kv_heads = 2 (GQA-2) to chart the actual Pareto frontier on Mini-GPT.
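Once the three runs are in, the Pareto frontier is the set of configs that no other config beats on both cache bytes and PPL. Here is a minimal sketch; `pareto_frontier` is a hypothetical helper. In the usage line, the GQA-2 and MQA PPLs are this page's predictions, and the MHA PPL is a placeholder. Replace all three with your measured values.

```python
def pareto_frontier(points):
    """Keep configs not dominated on (cache bytes, PPL); lower is better for both.

    points: list of (label, cache_bytes, ppl). Returns the frontier sorted by cache_bytes.
    """
    frontier = []
    for label, mem, ppl in points:
        dominated = any(
            m2 <= mem and p2 <= ppl and (m2 < mem or p2 < ppl)
            for _, m2, p2 in points
        )
        if not dominated:
            frontier.append((label, mem, ppl))
    return sorted(frontier, key=lambda p: p[1])


# Placeholder/predicted numbers, not measured results.
points = [("mha", 2048, 5.12), ("gqa-2", 1024, 5.20), ("mqa-kv1", 512, 5.55)]
print([label for label, _, _ in pareto_frontier(points)])  # → ['mqa-kv1', 'gqa-2', 'mha']
```

All three configs can sit on the frontier at once. The frontier only shows which trades are available. Choosing among them is a judgement about how much PPL a 2× cache saving is worth.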

Why this teaches the concept

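The original GQA paper (Ainslie et al., 2023) runs T5-XXL, which has 64 query heads, with 8 KV groups (kv_heads = n_heads / 8). It finds quality close to MHA at speed comparable to MQA, while MQA loses quality. That ratio does not transfer mechanically to a 4-head model. Applied literally to Mini-GPT, n_heads / 8 = 4/8 = 0.5, which is less than one KV head, so the rule gives no usable answer. With only 4 query heads there is little K/V capacity to spare: GQA-2 already halves the K, V heads, and kv_heads = 1 leaves a single one. Pushing to kv_heads = 1 takes the extreme end of the range, on purpose. This is the kind of trade-off where the abstract "MQA saves memory" line in a survey paper hides a real quality cost. Seeing the PPL number jump on a model you trained yourself makes the trade-off concrete in a way that no survey can.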
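Applying the n_heads / 8 ratio literally shows why it cannot simply be copied to Mini-GPT. This sketch assumes the paper's T5-XXL setting of 64 query heads with 8 KV groups. `ratio_rule_kv_heads` and `is_realizable` are hypothetical helpers. A GQA config needs a whole number of K/V heads, at least one, that divides n_heads.

```python
from fractions import Fraction


def ratio_rule_kv_heads(n_heads: int, divisor: int = 8) -> Fraction:
    """kv_heads suggested by applying kv_heads = n_heads / divisor literally."""
    return Fraction(n_heads, divisor)


def is_realizable(kv_heads: Fraction, n_heads: int) -> bool:
    """True if kv_heads is a whole number >= 1 that divides n_heads."""
    return kv_heads.denominator == 1 and kv_heads >= 1 and n_heads % int(kv_heads) == 0


for n in (64, 4):  # T5-XXL-sized vs Mini-GPT
    kv = ratio_rule_kv_heads(n)
    print(n, kv, is_realizable(kv, n))
# 64 8 True
# 4 1/2 False
```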

Reference

  • Ainslie et al., GQA (arXiv:2305.13245), Figure 3 — quality vs kv_heads / n_heads.
  • Shazeer, Fast Transformer Decoding: One Write-Head is All You Need (arXiv:1911.02150): the original MQA paper, which reports a small but real quality hit relative to multi-head attention.

Next: restore GQA-2 and run lab/04-mqa-gqa.md for the full ablation.