04 — GQA, MQA, MLA: Sharing K and V
Three tricks for shrinking the KV cache by sharing K and V across heads. MQA: one K/V for all heads. GQA: one K/V per group of heads. MLA: K/V compressed into a latent space. Each one intervenes on the bytes that cross memory for every decoded token.
The KV cache as a roofline term
During autoregressive decode (one token at a time), per-step attention works as follows:
- The new token's `Q` has shape `(1, n_heads, d)`.
- The cached `K`, `V` have shape `(N, n_heads, d)`, where `N` is the current sequence length.
- For each head, compute `softmax((Q · K^T) / √d) · V`.
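For reference, one decode step of standard MHA can be sketched in NumPy as below. This is a minimal illustration with toy sizes, not the course's lab code; the function name `mha_decode_step` is our own, and the shapes follow the list above with the leading query dimension of 1 dropped.

```python
import numpy as np

def mha_decode_step(q, K, V):
    """One decode step of multi-head attention.

    q: (n_heads, d)       query of the new token
    K, V: (N, n_heads, d) cached keys/values for N positions
    returns: (n_heads, d)
    """
    d = q.shape[-1]
    # scores[h, n] = q[h] . K[n, h] / sqrt(d)
    scores = np.einsum("hd,nhd->hn", q, K) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)            # softmax over positions
    # out[h] = sum_n p[h, n] * V[n, h]
    return np.einsum("hn,nhd->hd", p, V)

rng = np.random.default_rng(0)
N, n_heads, d = 16, 4, 8
q = rng.standard_normal((n_heads, d))
K = rng.standard_normal((N, n_heads, d))
V = rng.standard_normal((N, n_heads, d))
out = mha_decode_step(q, K, V)
print(out.shape)  # (4, 8)
```

Note that every one of the N cached positions, for every head, has to be read to produce a single output row per head.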
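The compute is O(N · d) per head per step, which is tiny. The bottleneck is loading K and V from HBM. For batch=1, single-token-decode attention is extremely bandwidth-bound: per head, Q · K^T costs about 2 N d FLOPs and the weighted sum over V another 2 N d, so roughly 4 N d FLOPs, while K and V in fp16 amount to 2 × N × d × 2 = 4 N d bytes moved from HBM. Intensity I ≈ 1 FLOP/byte, against GPU ridge points of well over 100 FLOPs/byte. Memory-bound by roughly two orders of magnitude.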
The KV cache is the dominant memory traffic of each attention step.
If we shrink the KV cache, per-step bytes drop proportionally, intensity rises proportionally, and decode tokens/sec rises proportionally. This is the entire premise of GQA/MQA/MLA.
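The claim that intensity rises in proportion can be checked with a back-of-the-envelope helper. This is a sketch under simplifying assumptions of our own: per query head it counts 2 N d FLOPs for Q · K^T and 2 N d for the weighted sum over V, it counts only K/V bytes (not Q, the output, or the weights), and it ignores softmax. The function and its `n_kv_heads` parameter are hypothetical names. Because sharing K/V leaves the FLOPs unchanged and divides the bytes, intensity comes out as n_heads / n_kv_heads.

```python
def decode_attention_intensity(N, n_heads, n_kv_heads, d, bytes_per_elem=2):
    """Approximate FLOPs/byte of one decode step of attention for one layer,
    counting only K/V loads as memory traffic (a simplifying assumption)."""
    # Per query head: Q @ K^T is N*d multiply-adds, P @ V another N*d.
    flops = n_heads * (2 * N * d + 2 * N * d)
    # K and V each hold N * n_kv_heads * d elements.
    kv_bytes = 2 * N * n_kv_heads * d * bytes_per_elem
    return flops / kv_bytes

for name, kv_heads in [("MHA", 32), ("GQA (8 KV heads)", 8), ("MQA", 1)]:
    print(name, decode_attention_intensity(4096, 32, kv_heads, 128))
# MHA 1.0, GQA (8 KV heads) 4.0, MQA 32.0
```

Even the MQA figure sits far below a modern GPU's ridge point, which is why shrinking bytes translates almost directly into decode speed.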
Multi-Query Attention (MQA, Shazeer 2019)
The change: all attention heads share a single K, V pair.
- Standard MHA: `K, V ∈ ℝ^{N × n_heads × d}`. KV cache bytes per layer = 2 × N × n_heads × d × 2 (fp16) = 4 N · n_heads · d.
- MQA: `K, V ∈ ℝ^{N × 1 × d}`. KV cache bytes per layer = 4 N d, which is n_heads× smaller.
The math: each query head i computes softmax(Q_i K^T / √d) V, where K, V are the same for all i. The flexibility of an independent K per head is given up; in exchange, the KV cache shrinks by a factor of n_heads (32× for a 32-head model).
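As code, MQA is MHA with the head axis dropped from K and V: every query head scores against the same key matrix and mixes the same value matrix. A minimal NumPy sketch for one decode step (the name `mqa_decode_step` is ours, not from any library):

```python
import numpy as np

def mqa_decode_step(q, K, V):
    """Multi-query attention for one decode step.

    q: (n_heads, d); K, V: (N, d), a single K/V shared by every head.
    """
    d = q.shape[-1]
    scores = q @ K.T / np.sqrt(d)              # (n_heads, N): same K for all heads
    scores -= scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ V                               # (n_heads, d): same V for all heads

rng = np.random.default_rng(1)
N, n_heads, d = 16, 4, 8
q = rng.standard_normal((n_heads, d))
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))
out = mqa_decode_step(q, K, V)

# Reference: run each head separately against the shared K, V.
ref = []
for h in range(n_heads):
    s = K @ q[h] / np.sqrt(d)
    w = np.exp(s - s.max())
    ref.append((w / w.sum()) @ V)
print(np.allclose(out, np.stack(ref)))   # True

# The cache holds N*d values instead of N*n_heads*d.
print((N * n_heads * d) // K.size)       # 4 (= n_heads)
```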
Quality cost: MQA models are usually trained from scratch (not retrofitted from MHA). Quality is slightly worse than MHA at the same parameter count — but the model can be made bigger with the saved memory, so the trade is net-positive in practice. PaLM, Falcon, and others use MQA.
Grouped-Query Attention (GQA, Ainslie 2023)
The change: heads are grouped; each group shares one K, V pair.
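`n_groups = n_heads / group_size`, where `group_size` is the number of query heads sharing one K/V pair. A common choice is 8 KV groups; with 32 query heads that is `group_size = 4`, a 4× reduction. `K, V ∈ ℝ^{N × n_groups × d}`. KV cache bytes per layer = `4 N · n_groups · d`, which is `group_size`× smaller than MHA.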
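GQA is a Pareto improvement on the MHA-MQA spectrum. With 32 query heads and 8 KV groups (group_size = 4), the KV cache is 4× smaller than MHA's (32 K/V heads) and 8× larger than MQA's (1 K/V head), and quality stays closer to MHA. LLaMA 2 70B uses GQA with 8 KV heads shared by 64 query heads (group_size = 8). Mistral 7B uses 8 KV heads for 32 query heads (group_size = 4).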
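A sketch of GQA for one decode step in NumPy, assuming the common contiguous layout where query head h reads KV group h // group_size (the function name is ours). Setting n_groups = n_heads recovers MHA, and n_groups = 1 recovers MQA.

```python
import numpy as np

def gqa_decode_step(q, K, V):
    """Grouped-query attention for one decode step.

    q: (n_heads, d); K, V: (N, n_groups, d), with n_heads divisible by n_groups.
    Query heads g*group_size ... (g+1)*group_size - 1 all read KV group g.
    """
    n_heads, d = q.shape
    n_groups = K.shape[1]
    group_size = n_heads // n_groups
    qg = q.reshape(n_groups, group_size, d)                 # query heads, grouped
    scores = np.einsum("gsd,ngd->gsn", qg, K) / np.sqrt(d)  # (n_groups, group_size, N)
    scores -= scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    out = np.einsum("gsn,ngd->gsd", p, V)                   # (n_groups, group_size, d)
    return out.reshape(n_heads, d)

rng = np.random.default_rng(2)
N, n_heads, n_groups, d = 16, 8, 2, 4
q = rng.standard_normal((n_heads, d))
K = rng.standard_normal((N, n_groups, d))
V = rng.standard_normal((N, n_groups, d))
out = gqa_decode_step(q, K, V)

# Cross-check against plain MHA with each KV group copied to its query heads.
K_rep = np.repeat(K, n_heads // n_groups, axis=1)           # (N, n_heads, d)
V_rep = np.repeat(V, n_heads // n_groups, axis=1)
s = np.einsum("hd,nhd->hn", q, K_rep) / np.sqrt(d)
w = np.exp(s - s.max(axis=-1, keepdims=True))
w /= w.sum(axis=-1, keepdims=True)
ref = np.einsum("hn,nhd->hd", w, V_rep)
print(np.allclose(out, ref))  # True
```

The repeated tensors exist only for the check; the point of GQA is that the cache stores the un-repeated `(N, n_groups, d)` arrays.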
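Retrofit recipe: to convert an MHA-trained model to GQA, average (mean-pool) the K and V projection weights within each group, then continue training ("uptraining") to recover quality. The GQA paper uptrains on about 5% of the original pre-training compute, which is cheap relative to training from scratch.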
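The averaging step of that recipe is a reshape and a mean over the projection's output columns. A sketch assuming the common head-major column layout (head 0's d columns first); the helper name is hypothetical, and the uptraining that follows is not shown.

```python
import numpy as np

def mean_pool_kv_proj(W, n_heads, n_groups):
    """Mean-pool an MHA K or V projection into a GQA projection.

    W: (hidden, n_heads * d), columns laid out head-major.
    returns: (hidden, n_groups * d)
    """
    hidden, cols = W.shape
    d = cols // n_heads
    group_size = n_heads // n_groups
    # Split columns into (group, head-within-group, d) and average the heads.
    W = W.reshape(hidden, n_groups, group_size, d)
    return W.mean(axis=2).reshape(hidden, n_groups * d)

rng = np.random.default_rng(3)
hidden, n_heads, n_groups, d = 64, 32, 8, 16
W_K = rng.standard_normal((hidden, n_heads * d))
W_K_gqa = mean_pool_kv_proj(W_K, n_heads, n_groups)
print(W_K.shape, "->", W_K_gqa.shape)   # (64, 512) -> (64, 128)
```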
Multi-Latent Attention (MLA, DeepSeek-V2 2024)
The change: project K, V to a low-rank latent space; cache the latent; reconstruct on the fly.
- Per-token KV state: a single low-rank vector `c ∈ ℝ^{d_c}` with `d_c < d × n_heads`.
- At attention time, reconstruct `K, V` from `c` via learned projection matrices.
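The KV cache becomes N × d_c instead of 2 × N × n_heads × d (separate K and V). For DeepSeek-V2: d_c = 512 and n_heads × d = 128 × 128 = 16384, so MHA would cache 2 × 16384 = 32768 values per token per layer versus 512 for the latent (576 counting DeepSeek-V2's small 64-dim decoupled RoPE key), a reduction of roughly 57 to 64× vs MHA. Quality is substantially better than MQA's "1 KV head" because the up-projections are learned per head: each head still gets its own K and V, reconstructed from a shared latent chosen by training, instead of every head reading the same K/V.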
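A toy sketch of the MLA caching path, with made-up sizes and random weights standing in for learned ones. It omits DeepSeek-V2's decoupled RoPE key and its query compression. `W_K_up` follows the drill problems' notation; `W_down` and `W_V_up` are hypothetical names of our own.

```python
import numpy as np

rng = np.random.default_rng(4)
hidden, n_heads, d, d_c = 256, 8, 32, 64   # toy sizes; d_c < n_heads * d

# Hypothetical MLA weights (random here; learned in a real model).
W_down = rng.standard_normal((d_c, hidden)) / np.sqrt(hidden)    # hidden -> latent
W_K_up = rng.standard_normal((n_heads * d, d_c)) / np.sqrt(d_c)  # latent -> all K heads
W_V_up = rng.standard_normal((n_heads * d, d_c)) / np.sqrt(d_c)  # latent -> all V heads

N = 10
X = rng.standard_normal((N, hidden))     # hidden states of N past tokens

# What gets cached: one d_c-dim latent per token (per layer).
C = X @ W_down.T                          # (N, d_c)

# At attention time, rebuild per-head K and V from the latents.
K = (C @ W_K_up.T).reshape(N, n_heads, d)
V = (C @ W_V_up.T).reshape(N, n_heads, d)

mha_cache = 2 * N * n_heads * d           # values an MHA layer would cache (K and V)
mla_cache = C.size                        # values MLA caches
print(mha_cache // mla_cache)             # 8
```

In this toy configuration the reduction is 2 · n_heads · d / d_c = 8×; the same formula gives 64× for the DeepSeek-V2 sizes before the RoPE key is added.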
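Cost: the reconstruction projection adds compute per attention call. Mathematically, MLA composes with Flash: the projection is a small additional matmul inside the kernel. The DeepSeek-V2 paper goes further: by associativity, W_K_up can be absorbed into the query projection and W_V_up into the output projection, so at inference K and V need not be materialized at all.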
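One reason the reconstruction need not be expensive is associativity: for head h, q_h · (W_h c) = (W_h^T q_h) · c, where W_h is that head's d × d_c slice of W_K_up. The sketch below checks this numerically for the attention scores (the 1/√d scale is omitted because it is identical on both paths; sizes are arbitrary toy values). Scoring in latent space costs one d × d_c multiply per head per step instead of reconstructing K for all N cached positions.

```python
import numpy as np

rng = np.random.default_rng(5)
n_heads, d, d_c, N = 4, 16, 32, 12
W_K_up = rng.standard_normal((n_heads * d, d_c))
C = rng.standard_normal((N, d_c))             # cached latents
q = rng.standard_normal((n_heads, d))         # new token's query, per head

# Path 1: reconstruct K from the latents, then score.
K = (C @ W_K_up.T).reshape(N, n_heads, d)
scores_reconstruct = np.einsum("hd,nhd->hn", q, K)

# Path 2: fold W_K_up into the query, then score directly against the latents.
W_heads = W_K_up.reshape(n_heads, d, d_c)     # per-head slice: (d, d_c)
q_latent = np.einsum("hd,hdc->hc", q, W_heads)  # (n_heads, d_c)
scores_absorbed = q_latent @ C.T              # (n_heads, N)

print(np.allclose(scores_reconstruct, scores_absorbed))  # True
```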
Quality: in the DeepSeek-V2 ablations, MLA matches or beats MHA while caching far less KV. The DeepSeek-V2 paper, which introduced MLA, is one of 2024's most important attention papers.
Where these compose with Flash
Flash Attention tiles Q, K, V and processes attention in SRAM. GQA/MQA/MLA shrink K and V. They compose freely: Flash with GQA is Flash where each Q-tile's inner loop reads shared K/V tiles (K/V indexed by group, not by head). Standard implementations such as FlashAttention-2 support this.
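To see how K/V indexed by group slots into Flash's inner loop, here is a NumPy sketch of one decode query processed tile by tile with an online softmax. It is a readability sketch of the technique, not the FlashAttention-2 kernel: a real kernel does this per thread block in SRAM, and the tile size and names here are our own.

```python
import numpy as np

def flash_gqa_decode(q, K, V, tile=4):
    """Decode attention for one new token, streaming K/V in tiles with an
    online softmax (Flash-style), using the GQA layout.

    q: (n_heads, d); K, V: (N, n_groups, d). Each K/V tile is loaded once and
    indexed by group, so all query heads of a group reuse the same tile.
    """
    n_heads, d = q.shape
    N, n_groups, _ = K.shape
    group_size = n_heads // n_groups
    m = np.full(n_heads, -np.inf)     # running max score per head
    denom = np.zeros(n_heads)         # running softmax denominator per head
    acc = np.zeros((n_heads, d))      # running unnormalized output per head
    for start in range(0, N, tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]  # one tile "load"
        for h in range(n_heads):
            g = h // group_size                      # KV group for this head
            s = Kt[:, g, :] @ q[h] / np.sqrt(d)      # scores for this tile
            m_new = max(m[h], s.max())
            rescale = np.exp(m[h] - m_new)           # correct earlier partial sums
            p = np.exp(s - m_new)
            denom[h] = denom[h] * rescale + p.sum()
            acc[h] = acc[h] * rescale + p @ Vt[:, g, :]
            m[h] = m_new
    return acc / denom[:, None]

rng = np.random.default_rng(6)
N, n_heads, n_groups, d = 10, 8, 2, 4   # N not a multiple of tile, on purpose
q = rng.standard_normal((n_heads, d))
K = rng.standard_normal((N, n_groups, d))
V = rng.standard_normal((N, n_groups, d))
out = flash_gqa_decode(q, K, V, tile=4)

# Reference: full softmax with each KV group copied to its query heads.
K_rep = np.repeat(K, n_heads // n_groups, axis=1)
V_rep = np.repeat(V, n_heads // n_groups, axis=1)
s = np.einsum("hd,nhd->hn", q, K_rep) / np.sqrt(d)
w = np.exp(s - s.max(axis=-1, keepdims=True))
ref = np.einsum("hn,nhd->hd", w / w.sum(axis=-1, keepdims=True), V_rep)
print(np.allclose(out, ref))  # True
```

The only GQA-specific line is `g = h // group_size`; everything else is the ordinary online-softmax recurrence.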
Where these compose with Paged
PagedAttention splits each sequence's KV cache into fixed-size blocks of token positions, per layer. With GQA, each block holds K/V for the n_groups KV heads rather than all n_heads query heads. With MLA, blocks hold latent vectors, not K/V. The block_manager logic generalizes; vLLM's MLA support is recent (mid-2024 onward).
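Because the block manager only counts blocks, the variants differ solely in how many bytes one block occupies. A small sketch of that arithmetic; the helper is hypothetical, and the 16-token block size and the model sizes (32 query heads, 8 KV groups, d = 128, d_c = 512, fp16) are assumptions chosen for illustration.

```python
def kv_block_bytes(block_tokens, values_per_token, bytes_per_elem=2):
    """Bytes one KV block occupies for one layer (hypothetical helper)."""
    return block_tokens * values_per_token * bytes_per_elem

BLOCK = 16  # tokens per block: an assumed value for illustration
n_heads, n_groups, d, d_c = 32, 8, 128, 512
layouts = {
    "MHA": 2 * n_heads * d,    # K and V for every query head
    "GQA": 2 * n_groups * d,   # K and V for every KV group
    "MLA": d_c,                # one latent vector per token
}
for name, per_token in layouts.items():
    print(name, kv_block_bytes(BLOCK, per_token) // 1024, "KiB per block")
# MHA 256, GQA 64, MLA 16 (KiB per block)
```

Under these assumptions the same memory pool holds 4× (GQA) or 16× (MLA) more tokens than MHA, with no change to the allocator itself.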
Roofline numbers
For a 7B model (32 query heads, d = 128, fp16) at decode with N=4096, bytes counted per token per layer:
| Variant | KV heads × d (or latent size) | KV bytes/token/layer (fp16) | Decode tokens/sec proxy |
|---|---|---|---|
| MHA | 32 × 128 = 4096 | 16 KiB | 1× baseline |
| GQA-8 | 8 KV groups × 128 = 1024 (32 query heads, 4 share each KV pair) | 4 KiB | ~4× |
| MQA | 1 × 128 = 128 | 0.5 KiB | ~32× |
| MLA (d_c=512) | latent 512 | 1 KiB | ~16× |
(The "tokens/sec proxy" is the inverse of KV bytes per token, assuming bandwidth-bound decode and ignoring per-step compute and weight loads. Real measurements vary, but the ordering holds.)
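The table rows can be recomputed directly. A sketch assuming fp16 and d = 128 as in the table; "GQA-8" means 8 KV heads, and the MLA row ignores DeepSeek-V2's extra RoPE key. The helper name is our own.

```python
KIB = 1024
d, bytes_per_elem = 128, 2   # head dim 128, fp16

def kv_bytes_per_token(kv_heads=None, latent=None):
    """Per-layer KV cache bytes for one token.

    Head-sharing variants cache K and V for `kv_heads` heads; MLA caches a
    single latent of size `latent`.
    """
    if latent is not None:
        return latent * bytes_per_elem
    return 2 * kv_heads * d * bytes_per_elem

rows = {
    "MHA": kv_bytes_per_token(kv_heads=32),
    "GQA-8": kv_bytes_per_token(kv_heads=8),
    "MQA": kv_bytes_per_token(kv_heads=1),
    "MLA (d_c=512)": kv_bytes_per_token(latent=512),
}
base = rows["MHA"]
for name, b in rows.items():
    print(f"{name:14s} {b / KIB:5.2f} KiB  proxy {base / b:4.0f}x")
```

This prints 16, 4, 0.5, and 1 KiB, i.e. proxies of 1×, 4×, 32×, and 16×.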
Sliding window vs KV reduction — the wrong question
A common confusion: sliding window and GQA/MQA/MLA both shrink KV traffic, so aren't they the same thing?
No:
- Sliding window reduces the number of positions attended to: each query sees W positions, not N.
- GQA/MQA/MLA reduce the per-position KV size: each position contributes fewer bytes.
They're orthogonal axes:
| | Full positions (N) | Sliding window (W) |
|---|---|---|
| Standard MHA | N · n_heads · d KV | W · n_heads · d KV |
| MQA | N · d KV | W · d KV |
| MLA | N · d_c KV | W · d_c KV |
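Mistral 7B uses GQA-8 (8 KV heads for 32 query heads) plus a 4096-token sliding window. Together: per-step KV traffic is 4× smaller than vanilla LLaMA-7B from GQA, times a further N / W once the sequence outgrows the window (16× in total at N = 16384), plus O(W) instead of O(N) attention.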
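The two reductions multiply. A sketch of the K/V bytes one decode step reads per layer, assuming fp16 and d = 128 (Mistral 7B's head dimension); the helper name is ours and the 16384-token context is an illustrative choice.

```python
def per_step_kv_bytes(N, kv_heads, d=128, window=None, bytes_per_elem=2):
    """K+V bytes one decode step reads for one layer.

    With a sliding window, only the last `window` positions are attended to.
    """
    positions = N if window is None else min(N, window)
    return positions * 2 * kv_heads * d * bytes_per_elem

N = 16384
llama = per_step_kv_bytes(N, kv_heads=32)                 # MHA, full attention
mistral = per_step_kv_bytes(N, kv_heads=8, window=4096)   # GQA-8 + window 4096
print(llama // mistral)   # 16
```

At N = 4096 or less the window never kicks in and the ratio is just the 4× from GQA.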
What we measure on Borja's roofline
For the roofline overlay (experiment 27-roofline-overlay), we plot:
- Naive MHA + naive attention: lowest intensity, on the memory ceiling.
- Naive MHA + Flash: higher intensity, partway up.
- GQA + Flash: even higher intensity (smaller bytes-per-K-tile load).
- MQA + Flash: highest single-kernel intensity.
This is a single-kernel analysis. The system-level wins from PagedAttention require server-level measurements, which are out of scope for Phase 27.
Drill problems
Solutions are published when the phase opens, in solutions/04-gqa-mqa-mla-ref.md.
- Start from MHA with 32 heads and switch to GQA with group_size=8. The K projection weight matrix shape changes from `(hidden, n_heads × d)` to `(hidden, n_groups × d)`. By what factor do the K and V projection parameters shrink?
- MQA's KV cache is `4 N d` bytes per layer (fp16). For LLaMA-7B (d=128, n_layers=32, N=4096), compute the KV cache size. Compare to MHA (n_heads=32).
- MLA's reconstruction is `K = W_K_up · c`, where `c ∈ ℝ^{d_c}`, `K ∈ ℝ^{n_heads × d}`, `W_K_up ∈ ℝ^{(n_heads d) × d_c}`. Compute the per-attention-step extra FLOPs from the reconstruction. Argue when this is negligible vs the dominant compute.
- GQA-8 (8 KV groups for 32 query heads) reduces the KV cache by 4× vs MHA. Show that in a bandwidth-bound decode regime this implies ~4× throughput improvement, ignoring other terms.
One-paragraph recap
GQA, MQA, and MLA all reduce KV cache size by sharing or compressing K and V across heads. MQA shares one K/V across all heads (max compression, max quality cost). GQA shares within groups (Pareto sweet spot). MLA compresses to a low-rank latent space (best quality at strong compression). Each composes with Flash and with PagedAttention, multiplying throughput gains. The unifying lens is the roofline: decode is bandwidth-bound by KV traffic, so shrinking KV directly raises intensity. Modern inference engines layer Flash + Paged + GQA + sliding window simultaneously, each attacking a different cost. With Phase 27's theory complete, the labs implement Flash forward in Triton and annotate vLLM's KV-cache allocator.
Next: lab/00-online-softmax.md.