Lab 04: MQA / GQA and the KV-Cache Win
Goal: implement Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) as drop-in variants of MiniGPT's attention, and measure the KV-cache size reduction on the verb-corpus 64-token sequence.
Estimated time: 4–6 hours.
Prereq: theory 04 read; src/minimodel/attention.py (Phase 15) familiar; KV-cache from Phase 22 understood.
What you produce
- src/minimodel/attention_mqa_gqa.py: the variant attention module.
- experiments/27-mqa-gqa-kv-cache/ containing:
  - bench.py: runs MHA / GQA / MQA forward passes on the same verb-corpus 64-token sequence and measures (a) KV-cache bytes per generated token and (b) output PPL drift vs. the MHA reference.
  - results.json: measurements.
  - kv_bytes.png: bar chart of bytes per token for each variant.
  - manifest.json.
  - README.md: interpretation.
- Tests in tests/test_minimodel_mqa_gqa.py (Claude scaffolds them as failing tests).
The math (recap from theory 04)
Standard Multi-Head Attention (MHA) has n_heads independent Q, K, and V heads. KV cache per token, per layer = 2 × n_heads × d_head × bytes_per_element (one K vector and one V vector per head).
Multi-Query Attention (MQA): all heads share one K head and one V head. Cache = 2 × 1 × d_head × bytes_per_element, an n_heads× reduction.
Grouped-Query Attention (GQA): heads are split into n_kv_heads groups; each group shares one K head and one V head. Cache = 2 × n_kv_heads × d_head × bytes_per_element, an (n_heads / n_kv_heads)× reduction.
For MiniGPT with n_heads = 4:
- MHA cache per token = 2 × 4 × d_head × 2 bytes (fp16) = 16 × d_head bytes.
- GQA with n_kv_heads = 2: 2 × 2 × d_head × 2 = 8 × d_head bytes. 2× reduction.
- MQA: 2 × 1 × d_head × 2 = 4 × d_head bytes. 4× reduction.
For d_head = 16 (typical MiniGPT), the per-token, per-layer cache sizes are 256 → 128 → 64 bytes. On a 64-token sequence, the per-layer totals are 16 KiB → 8 KiB → 4 KiB. The absolute numbers are small, but the ratio is what generalizes to production.
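The per-layer arithmetic above fits in a few lines of plain Python. This is a sketch for checking numbers, not part of src/minimodel; the helper name kv_bytes_per_token is hypothetical, and it assumes fp16 (2 bytes per element) and a single layer.

```python
def kv_bytes_per_token(n_kv_heads: int, d_head: int, bytes_per_element: int = 2) -> int:
    """Bytes one token adds to ONE layer's cache: a K and a V vector per KV head."""
    return 2 * n_kv_heads * d_head * bytes_per_element


N_HEADS, D_HEAD, SEQ_LEN = 4, 16, 64  # MiniGPT figures from the text; fp16 by default
mha_bytes = kv_bytes_per_token(N_HEADS, D_HEAD)

for name, n_kv in [("MHA", N_HEADS), ("GQA", 2), ("MQA", 1)]:
    per_token = kv_bytes_per_token(n_kv, D_HEAD)
    print(f"{name}: {per_token} B/token, {per_token * SEQ_LEN // 1024} KiB per layer "
          f"for {SEQ_LEN} tokens, reduction vs MHA: {mha_bytes // per_token}x")
```

Multiplying by the number of layers gives the whole-model cache; the ratios between variants do not change.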
TODOs
Block A: design the module
- class AttentionVariant(nn.Module) accepts a config object with d_model, n_heads, n_kv_heads, d_head, dropout, causal.
- When n_kv_heads == n_heads, it behaves identically to MHA (Phase 15). When n_kv_heads == 1, it is MQA. Otherwise it is GQA, which requires n_heads to be divisible by n_kv_heads.
- The Q projection produces n_heads × d_head features per token; the K and V projections produce n_kv_heads × d_head features per token.
- Attention: each Q head looks up its group's K and V (broadcast or repeat-interleave the KV heads to match the Q heads inside the matmul). Stay numerically equivalent to MHA when n_kv_heads == n_heads.
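A minimal PyTorch sketch of Block A follows. It is not the Phase 15 module or the reference solution: AttnConfig is a hypothetical dataclass, the projections are assumed bias-free, and the compact K/V are kept on a hypothetical kv_cache attribute (recomputed each forward, not an incremental decode cache). It only shows where n_kv_heads enters the projection shapes and where the repeat-interleave happens.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class AttnConfig:
    """Hypothetical config; field names follow the Block A spec."""
    d_model: int = 64
    n_heads: int = 4
    n_kv_heads: int = 4
    d_head: int = 16
    dropout: float = 0.0
    causal: bool = True


class AttentionVariant(nn.Module):
    """MHA when n_kv_heads == n_heads, MQA when n_kv_heads == 1, GQA otherwise."""

    def __init__(self, cfg: AttnConfig):
        super().__init__()
        if cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads")
        self.cfg = cfg
        self.n_rep = cfg.n_heads // cfg.n_kv_heads  # query heads served by each KV head
        self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * cfg.d_head, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.n_kv_heads * cfg.d_head, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.n_kv_heads * cfg.d_head, bias=False)
        self.o_proj = nn.Linear(cfg.n_heads * cfg.d_head, cfg.d_model, bias=False)
        self.attn_dropout = nn.Dropout(cfg.dropout)
        self.kv_cache = None  # hypothetical attribute: compact (K, V) from the last forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        c = self.cfg
        q = self.q_proj(x).view(B, T, c.n_heads, c.d_head)
        k = self.k_proj(x).view(B, T, c.n_kv_heads, c.d_head)
        v = self.v_proj(x).view(B, T, c.n_kv_heads, c.d_head)
        self.kv_cache = (k, v)  # stored BEFORE expansion: (B, T, n_kv_heads, d_head)

        # Put heads before time: (B, heads, T, d_head).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        # KV head g serves the n_rep consecutive query heads g*n_rep ... g*n_rep + n_rep - 1.
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        scores = q @ k.transpose(-2, -1) / c.d_head ** 0.5  # (B, n_heads, T, T)
        if c.causal:
            future = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(B, T, c.n_heads * c.d_head)
        return self.o_proj(out)


torch.manual_seed(0)
x = torch.randn(2, 64, 64)  # (B, T, d_model) for a 64-token sequence
for n_kv in (4, 2, 1):  # MHA, GQA, MQA
    attn = AttentionVariant(AttnConfig(n_kv_heads=n_kv))
    y = attn(x)
    print(n_kv, tuple(y.shape), tuple(attn.kv_cache[0].shape))
```

The design choice that matters is caching K and V in their compact (B, T, n_kv_heads, d_head) form and expanding only inside the matmul. Caching the expanded tensors would hand back MHA-sized memory and erase the saving.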
Block B: implement and test
- Forward pass shape: (B, T, d_model) → (B, T, d_model). Causal mask supported.
- Test: with n_kv_heads = n_heads and the same weights loaded, output matches src/minimodel/attention.py (Phase 15) to 1e-5.
- Test: with n_kv_heads = 1, parameter count drops by 2 × (n_heads - 1) × d_head × d_model (the K and V projection weight savings).
- Test: the KV cache tensors allocated by the module have the expected shape (B, T, n_kv_heads, d_head) for K, and the same for V.
Block C: measure on the verb-corpus sequence
- Load MiniGPT (Phase 17 checkpoint) plus a verb-corpus held-out sentence of length 64 tokens.
- Build three variants of MiniGPT, swapping only the attention module: MHA / GQA (n_kv_heads = 2) / MQA (n_kv_heads = 1).
- For each: run forward; measure KV cache total bytes; compute output PPL on the next-token distribution.
- Note: PPL will likely degrade under MQA/GQA because the head sharing changes the model's representational capacity. This is expected. Phase 28 (LoRA) is where you would re-tune the variants to recover; this lab just measures the trade-off.
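The lab does not say how the shared K/V weights are derived from the MHA checkpoint. One common choice, assumed here, is the mean-pooling conversion from the GQA paper: average the K (and V) projection weights of the heads in each group. The sketch below also includes hypothetical helpers for the two Block C measurements. It assumes separate nn.Linear K/V weights with head h in output rows h·d_head to (h+1)·d_head - 1; if Phase 15 uses a fused QKV projection, slice it first.

```python
import math

import torch
import torch.nn.functional as F


def mean_pool_kv_weight(w: torch.Tensor, n_heads: int, n_kv_heads: int, d_head: int) -> torch.Tensor:
    """Collapse an MHA K or V weight (n_heads * d_head, d_model) to (n_kv_heads * d_head, d_model).

    Assumes nn.Linear layout (out_features, in_features) with head h occupying output
    rows h*d_head ... (h+1)*d_head - 1. Each group of consecutive heads is averaged,
    which matches the repeat_interleave grouping in the forward pass.
    """
    group = n_heads // n_kv_heads
    d_model = w.shape[1]
    grouped = w.reshape(n_kv_heads, group, d_head, d_model)
    return grouped.mean(dim=1).reshape(n_kv_heads * d_head, d_model)


def kv_cache_nbytes(caches) -> int:
    """Total bytes held by a list of per-layer (K, V) cache tensors."""
    return sum(t.numel() * t.element_size() for k, v in caches for t in (k, v))


def perplexity(logits: torch.Tensor, tokens: torch.Tensor) -> float:
    """Next-token PPL: logits (B, T, vocab) at position t score token t + 1."""
    vocab = logits.size(-1)
    nll = F.cross_entropy(logits[:, :-1].reshape(-1, vocab), tokens[:, 1:].reshape(-1))
    return math.exp(nll.item())


# Example: an fp16 MQA cache for one layer and a 64-token sequence.
k = torch.zeros(1, 64, 1, 16, dtype=torch.float16)  # (B, T, n_kv_heads, d_head)
v = torch.zeros_like(k)
print(kv_cache_nbytes([(k, v)]))  # → 4096 bytes, i.e. 4 KiB
```

PPL drift for a variant can then be reported as perplexity(variant) / perplexity(mha) - 1. Measuring bytes from the tensors (numel × element_size) rather than from the formula is what makes the 1% stop condition a real check.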
Block D: interpret in README.md
Four questions:
- What KV-cache size reduction did you measure for MQA vs. MHA? Expect ≈ n_heads× (4× for n_heads = 4).
- What's the PPL gap between MQA and MHA on the verb-corpus eval? Expect 5–20% degradation if the model was not trained with MQA in mind (which is the case here: we're swapping at inference time).
- At what production scale does MQA become a no-brainer? Compute the KV cache for a 32-layer / 32-head / 128-d_head model at sequence length 8192 (assume fp16 and batch size 1). Compare MHA vs. MQA total cache bytes. The number should be in GiB; that's the answer.
- What's the relationship between MQA and Flash Attention? Are they orthogonal? Composable? (Hint: yes to both. MQA reduces the K/V bytes that must be stored and read; Flash Attention avoids materializing the T × T score matrix S = QKᵀ in memory. Different bottlenecks.)
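For the production-scale question, the per-layer formula just gains a layer count and a sequence length. A sketch, assuming fp16 (2 bytes per element) and batch size 1; total_kv_bytes is a hypothetical helper name.

```python
def total_kv_bytes(n_layers: int, n_kv_heads: int, d_head: int, seq_len: int,
                   batch: int = 1, bytes_per_element: int = 2) -> int:
    """Whole-model KV cache: K and V for every layer, KV head, position, and sequence."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_element


GiB, MiB = 2**30, 2**20
mha = total_kv_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=8192)
mqa = total_kv_bytes(n_layers=32, n_kv_heads=1, d_head=128, seq_len=8192)
print(f"MHA: {mha / GiB:.1f} GiB  MQA: {mqa / MiB:.0f} MiB  ratio: {mha // mqa}x")
# → MHA: 4.0 GiB  MQA: 128 MiB  ratio: 32x
```

Batch size multiplies both figures linearly, so the absolute gap grows with the number of concurrent sequences.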
Constraints
- No retraining. This is a structural swap at inference time. Measure the degradation honestly.
- CPU is fine. The verb sequence is short; no GPU needed for this lab. (Phase 27's GPU labs are 01 and 02.)
- Same MiniGPT checkpoint for all three variants — only the attention module changes.
Note: MQA/GQA are not free speedups. They share K/V across heads, and that reduces capacity. The real choice is: how much PPL can you give up in exchange for how much KV cache? In small models like MiniGPT, MQA is aggressive. At 70B+, almost all modern models use GQA.
Stop conditions
- src/minimodel/attention_mqa_gqa.py implemented; tests pass.
- All five experiment files committed.
- KV-cache reduction matches the theoretical formula to within 1%.
- README answers all four questions.
Pitfalls
- Forgot to repeat-interleave KV before the matmul. If Q has 4 heads and KV has fewer, you need to expand KV to 4 heads (via K.repeat_interleave(n_heads // n_kv_heads, dim=heads_dim), i.e. a factor of 4 for MQA, or an einsum equivalent) before the dot product. For GQA the shapes otherwise don't match; MQA alone can get away with letting a size-1 head dimension broadcast.
- PPL didn't change at all. You probably did not actually swap the attention; check that id(model.layers[0].attn) differs from the MHA reference.
- Parameter count drop doesn't match. If the K and V projections have biases, those shrink too (by 2 × (n_heads - 1) × d_head for MQA), so the total drop exceeds the weight-only formula. Either ignore biases or compare weight-only parameter counts.
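The first pitfall names repeat_interleave rather than repeat, and for GQA the difference is silent: torch.Tensor.repeat tiles the KV heads instead of grouping them, and both produce the same shape, so nothing errors. A toy-sized sketch that makes the two orderings visible:

```python
import torch

n_heads, n_kv_heads = 4, 2
n_rep = n_heads // n_kv_heads

# Label each KV head with its own index so the expanded order is visible.
kv = torch.arange(n_kv_heads).view(1, n_kv_heads, 1, 1)  # (B, n_kv_heads, T, d_head)

grouped = kv.repeat_interleave(n_rep, dim=1).flatten().tolist()
tiled = kv.repeat(1, n_rep, 1, 1).flatten().tolist()
print("repeat_interleave:", grouped)  # → [0, 0, 1, 1]: query head h reads KV head h // n_rep
print("repeat:           ", tiled)    # → [0, 1, 0, 1]: same shape, different grouping
```

For MQA (one KV head) the two orderings coincide. For GQA, the ordering has to agree with how the shared K/V weights were built; if they were pooled from consecutive heads, repeat_interleave is the consistent choice.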
When to consult solutions/
After all stop conditions are met. solutions/04-mqa-gqa-ref.md (phase open) walks through the repeat-interleave logic.
End of Phase 27 labs. Write PHASE_27_REPORT.md next.