

Lab 04 — MQA / GQA and the KV-Cache Win

Goal: implement Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) as drop-in variants of MiniGPT's attention, and measure the KV-cache size reduction on the verb-corpus 64-token sequence.

Estimated time: 4–6 hours.

Prereq: theory 04 read; src/minimodel/attention.py (Phase 15) familiar; KV-cache from Phase 22 understood.


What you produce

src/minimodel/attention_mqa_gqa.py — the variant attention module.

experiments/27-mqa-gqa-kv-cache/ containing:

  • bench.py — runs MHA / GQA / MQA forward on the same verb-corpus 64-token sequence and measures: (a) KV-cache bytes per generated token, (b) output PPL drift vs MHA reference.
  • results.json — measurements.
  • kv_bytes.png — bar chart, bytes per token for each variant.
  • manifest.json.
  • README.md — interpretation.

Tests go in tests/test_minimodel_mqa_gqa.py (Claude scaffolds them as failing tests).

The math (recap from theory 04)

Standard Multi-Head Attention (MHA) has n_heads independent Q, K, V projections. KV cache per token, per layer = 2 × n_heads × d_head × bytes_per_element (the factor 2 counts K and V).

Multi-Query Attention (MQA): all heads share a single K head and a single V head. Cache = 2 × 1 × d_head × bytes_per_element, an n_heads× reduction.

Grouped-Query Attention (GQA): heads are split into n_kv_heads groups; each group shares one K and one V head. Cache = 2 × n_kv_heads × d_head × bytes_per_element, an (n_heads / n_kv_heads)× reduction.

For MiniGPT with n_heads = 4:

  • MHA cache per token = 2 × 4 × d_head × 2 bytes (fp16) = 16 × d_head bytes.
  • GQA with n_kv_heads = 2: 2 × 2 × d_head × 2 = 8 × d_head bytes. 2× reduction.
  • MQA: 2 × 1 × d_head × 2 = 4 × d_head bytes. 4× reduction.

For d_head = 16 (typical MiniGPT), the per-token, per-layer cache sizes are 256 → 128 → 64 bytes. Over a 64-token sequence the per-layer cache totals 16 KiB → 8 KiB → 4 KiB. Small absolute numbers, but the ratio is what generalizes to production.
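The arithmetic above can be checked with a small helper. This is a sketch: the name `theoretical_kv_bytes` and its signature are hypothetical (not part of the lab scaffold), fp16 (2 bytes per element) is the default, and for MHA you pass n_kv_heads = n_heads.

```python
def theoretical_kv_bytes(n_kv_heads: int, d_head: int, seq_len: int = 1,
                         n_layers: int = 1, batch: int = 1,
                         bytes_per_element: int = 2) -> int:
    """KV-cache size in bytes; the leading 2 counts one K and one V tensor."""
    return 2 * n_kv_heads * d_head * bytes_per_element * seq_len * n_layers * batch


# MiniGPT numbers from the text: n_heads = 4, d_head = 16, fp16.
for name, n_kv in [("MHA", 4), ("GQA", 2), ("MQA", 1)]:
    per_token = theoretical_kv_bytes(n_kv, d_head=16)
    per_seq = theoretical_kv_bytes(n_kv, d_head=16, seq_len=64)
    print(f"{name}: {per_token} B/token/layer, {per_seq // 1024} KiB per 64-token sequence per layer")
```

The same function covers Block D question 3: set n_layers, the head counts, d_head and seq_len to the values given there.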

TODOs

Block A — design the module

  • class AttentionVariant(nn.Module) accepts a config object with d_model, n_heads, n_kv_heads, d_head, dropout, causal.
  • When n_kv_heads == n_heads, it behaves identically to MHA (Phase 15). When n_kv_heads == 1, it is MQA. Otherwise it is GQA, which requires n_heads to be divisible by n_kv_heads.
  • The Q projection produces n_heads × d_head per token; K and V projections produce n_kv_heads × d_head per token.
  • Attention: each Q head looks up the corresponding group's K, V (broadcast / repeat-interleave the KV heads to match Q heads inside the matmul). Stay numerically equivalent to MHA when n_kv_heads == n_heads.
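A minimal PyTorch sketch of Block A follows. It assumes bias-free Linear projections, a (B, heads, T, d_head) layout inside the matmuls, and that forward returns the output plus the un-expanded (K, V) pair so Block B's cache-shape test can inspect it. `AttnConfig`, its default sizes (d_model = 64) and the return signature are hypothetical; match them to whatever the Phase 15 module actually uses.

```python
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class AttnConfig:
    # Hypothetical config; fields mirror Block A, defaults are assumed MiniGPT sizes.
    d_model: int = 64
    n_heads: int = 4
    n_kv_heads: int = 4
    d_head: int = 16
    dropout: float = 0.0
    causal: bool = True


class AttentionVariant(nn.Module):
    """MHA if n_kv_heads == n_heads, MQA if n_kv_heads == 1, GQA otherwise."""

    def __init__(self, cfg: AttnConfig):
        super().__init__()
        if cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads")
        self.cfg = cfg
        self.group = cfg.n_heads // cfg.n_kv_heads  # Q heads sharing one KV head
        self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * cfg.d_head, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.n_kv_heads * cfg.d_head, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.n_kv_heads * cfg.d_head, bias=False)
        self.o_proj = nn.Linear(cfg.n_heads * cfg.d_head, cfg.d_model, bias=False)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor):
        B, T, _ = x.shape
        c = self.cfg
        q = self.q_proj(x).view(B, T, c.n_heads, c.d_head)
        # K and V keep only n_kv_heads heads: this is what a KV cache would store.
        k = self.k_proj(x).view(B, T, c.n_kv_heads, c.d_head)
        v = self.v_proj(x).view(B, T, c.n_kv_heads, c.d_head)
        # Expand KV heads so Q head h reads KV head h // group.
        k_exp = k.repeat_interleave(self.group, dim=2)
        v_exp = v.repeat_interleave(self.group, dim=2)
        # Move heads ahead of time: (B, n_heads, T, d_head).
        q, k_exp, v_exp = (t.transpose(1, 2) for t in (q, k_exp, v_exp))
        scores = q @ k_exp.transpose(-2, -1) / math.sqrt(c.d_head)
        if c.causal:
            # True above the diagonal marks future positions, which are masked out.
            future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)
            scores = scores.masked_fill(future, float("-inf"))
        attn = self.drop(F.softmax(scores, dim=-1))
        out = (attn @ v_exp).transpose(1, 2).reshape(B, T, c.n_heads * c.d_head)
        return self.o_proj(out), (k, v)
```

Because `repeat_interleave(1, ...)` is a no-op, the n_kv_heads == n_heads path reduces to ordinary MHA, which is what Block B's equivalence test relies on once the same weights are loaded.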

Block B — implement and test

  • Forward pass shape: (B, T, d_model) → (B, T, d_model). Causal mask supported.
  • Test: with n_kv_heads = n_heads and the same weights loaded, output matches src/minimodel/attention.py (Phase 15) to within 1e-5.
  • Test: with n_kv_heads = 1, parameter count drops by (2 × (n_heads - 1) × d_head × d_model) (the K and V projection weight savings).
  • Test: KV cache tensor allocated by the module has the expected shape (B, T, n_kv_heads, d_head) for K and same for V.
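The expected parameter drop in the third test can be cross-checked by counting weights directly. This sketch assumes bias-free Q, K, V and output projections and d_model = 64 (= n_heads × d_head); both are assumptions about MiniGPT's config, and `attn_weight_params` is a hypothetical helper.

```python
def attn_weight_params(d_model: int, n_heads: int, n_kv_heads: int, d_head: int) -> int:
    """Weight-only parameter count of bias-free Q, K, V, O projections."""
    q = d_model * n_heads * d_head          # Q: d_model -> n_heads * d_head
    kv = 2 * d_model * n_kv_heads * d_head  # K and V: d_model -> n_kv_heads * d_head each
    o = n_heads * d_head * d_model          # output: n_heads * d_head -> d_model
    return q + kv + o


d_model, n_heads, d_head = 64, 4, 16  # assumed MiniGPT sizes
mha = attn_weight_params(d_model, n_heads, n_heads, d_head)
mqa = attn_weight_params(d_model, n_heads, 1, d_head)
expected_drop = 2 * (n_heads - 1) * d_head * d_model
print(mha, mqa, mha - mqa, expected_drop)  # 16384 10240 6144 6144
```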

Block C — measure on the verb-corpus sequence

  • Load MiniGPT (Phase 17 checkpoint) plus a verb-corpus held-out sentence of length 64 tokens.
  • Build three variants of MiniGPT, swapping only the attention module: MHA / GQA (n_kv_heads = 2) / MQA (n_kv_heads = 1).
  • For each: run forward; measure KV cache total bytes; compute output PPL on the next-token distribution.
  • Note: PPL will likely degrade under MQA/GQA because the head sharing changes the model's representational capacity. This is expected. Phase 28 (LoRA) is where you would re-tune the variants to recover; this lab just measures the trade-off.
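Block C needs K/V weights for fewer heads than the MHA checkpoint has, and the lab does not prescribe how to derive them. One option, the conversion the GQA paper (Ainslie et al., 2023) used before uptraining, is to mean-pool the K and V heads within each group. The sketch below shows that plus two measurement helpers; all three function names are hypothetical, the weight layout assumes nn.Linear with heads stored consecutively along the output rows, and PPL is taken as exp of the mean next-token cross-entropy (so a 64-token sentence gives 63 predictions once targets are shifted by one).

```python
import math

import torch
import torch.nn.functional as F


def mean_pool_kv_weight(w: torch.Tensor, n_heads: int, n_kv_heads: int, d_head: int) -> torch.Tensor:
    """Collapse an MHA K or V weight (n_heads * d_head, d_model) to n_kv_heads heads.

    Consecutive heads are averaged, so Q head h maps to KV head h // (n_heads // n_kv_heads),
    matching a repeat_interleave expansion at attention time.
    """
    group = n_heads // n_kv_heads
    d_model = w.shape[1]
    pooled = w.view(n_kv_heads, group, d_head, d_model).mean(dim=1)
    return pooled.reshape(n_kv_heads * d_head, d_model)


def measured_kv_bytes(kv_per_layer) -> int:
    """Total bytes of the K and V tensors actually held, summed over layers."""
    return sum(k.numel() * k.element_size() + v.numel() * v.element_size()
               for k, v in kv_per_layer)


def perplexity(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """exp(mean next-token cross-entropy); logits (B, T, V), targets (B, T)."""
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return math.exp(loss.item())
```

Comparing `measured_kv_bytes` against the formula gives the 1% stop-condition check directly.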

Block D — interpret in README.md

Four questions:

  1. What KV-cache size reduction did you measure for MQA vs MHA? Expect ≈ n_heads× (4× for n_heads=4).
  2. What's the PPL gap MQA vs MHA on the verb-corpus eval? Expect 5–20% degradation if the model was not trained with MQA in mind (which is the case here — we're swapping at inference time).
  3. At what production scale does MQA become a no-brainer? Compute KV cache for a 32-layer / 32-head / 128-d_head model at sequence length 8192. Compare MHA vs MQA total cache bytes. The number should be in GiB; that's the answer.
  4. What's the relationship between MQA and Flash Attention? Are they orthogonal? Composable? (Hint: yes to both — MQA reduces K/V bytes, Flash reduces the materialized S matrix. Different bottlenecks.)

Constraints

  • No retraining. This is a structural swap at inference time. Measure the degradation honestly.
  • CPU is fine. The verb sequence is short; no GPU needed for this lab. (Phase 27's GPU labs are 01 and 02.)
  • Same MiniGPT checkpoint for all three variants — only the attention module changes.

MQA/GQA are not free speedups: they share K/V across heads, and that reduces capacity. The real choice is how much PPL you can give up in exchange for how much KV cache. In small models like MiniGPT, MQA is aggressive. At 70B+, almost all modern models use GQA.

Stop conditions

  • src/minimodel/attention_mqa_gqa.py implemented; tests pass.
  • All five experiment files committed.
  • KV-cache reduction matches the theoretical formula to within 1%.
  • README answers all four questions.

Pitfalls

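  • Forgot to repeat-interleave KV before matmul. If Q has 4 heads and KV has 1 head (MQA), you need to expand KV to 4 heads (via K.repeat_interleave(4, dim=heads_dim) or an einsum equivalent) before the dot product. With GQA, skipping this gives a shape mismatch; with MQA's single KV head, torch.matmul may silently broadcast the size-1 head dimension, so a missing expand can go unnoticed until you switch to GQA.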
  • PPL didn't change at all. You probably did not actually swap the attention — check id(model.layers[0].attn) is different from the MHA reference.
  • Parameter count drop doesn't match. Check that any K/V biases were resized to n_kv_heads × d_head, and remember that the Block B formula counts weights only: ignore biases or compute a weight-only parameter count.
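A related trap for GQA: `.repeat` and `.repeat_interleave` both produce the right shape but assign KV heads to Q heads differently, so the choice must match how the K/V weights were grouped (for example, averaging consecutive heads pairs with `repeat_interleave`). For MQA the two coincide. The sizes below are illustrative.

```python
import torch

n_heads, n_kv_heads = 4, 2
group = n_heads // n_kv_heads

# Fill each KV head with its own index so the expansion shows which KV head each Q head receives.
kv = torch.arange(n_kv_heads).view(1, 1, n_kv_heads, 1)  # (B, T, n_kv_heads, d_head)

interleaved = kv.repeat_interleave(group, dim=2)  # Q heads 0,1 -> KV 0; Q heads 2,3 -> KV 1
tiled = kv.repeat(1, 1, group, 1)                 # Q heads 0,2 -> KV 0; Q heads 1,3 -> KV 1

print(interleaved.flatten().tolist())  # [0, 0, 1, 1]
print(tiled.flatten().tolist())        # [0, 1, 0, 1]
```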

When to consult solutions/

After all stop conditions are met. solutions/04-mqa-gqa-ref.md (phase open) walks through the repeat-interleave logic.


End of Phase 27 labs. Write PHASE_27_REPORT.md next.