05 — Flash Forward Walkthrough + the GQA KV-Cache Math¶
Two derivations back to back: (a) a step-by-step walk through the FlashAttention forward pass, with tile-loop pseudocode you can execute in your head; (b) an exact derivation of how many KV-cache bytes GQA saves over MHA as a function of n_heads, kv_heads, sequence length, and dtype. No abbreviated pseudocode and no aphorisms.
Anchors: theory/01-online-softmax.md, theory/02-flash-attention.md, theory/04-gqa-mqa-mla.md. This file is the one you should keep open when you implement the lab.
Part A — Flash forward, fully unrolled¶
Sequence length N, head dim d, tile sizes B_r rows of Q and B_c rows of K, V. Constraint: one tile's working set fits SRAM.
Memory layout of the working set¶
The SRAM-resident state during the inner loop:
```text
Q_i  : (B_r, d)    fp32   current query tile
K_j  : (B_c, d)    fp32   current key tile
V_j  : (B_c, d)    fp32   current value tile
S_ij : (B_r, B_c)  fp32   partial logits, never written to HBM
m_i  : (B_r,)      fp32   running row-max
l_i  : (B_r,)      fp32   running row-sum (after rescaling)
O_i  : (B_r, d)    fp32   running output accumulator
```
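For d=64, B_r=B_c=64 in fp32, counting S_ij, then Q_i, K_j, V_j, then m_i and l_i, then O_i: 4 × (64·64 + 3·64·64 + 64 + 64 + 64·64) = 4 × (4096 + 12288 + 128 + 4096) = 4 × 20608 = 82,432 bytes ≈ 80.5 KiB. That exceeds a Haswell-class 32 KiB L1 data cache (it fits in the 256 KiB L2), but it sits comfortably within Hopper's shared memory (up to 228 KB per SM on H100).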
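The working-set arithmetic above can be checked with a few lines of Python. This is a minimal sketch; the helper name `working_set_bytes` is made up for illustration, and it assumes every buffer uses the same element size.

```python
def working_set_bytes(B_r: int, B_c: int, d: int, itemsize: int = 4) -> int:
    """Bytes of on-chip memory held during one inner-loop step (hypothetical helper)."""
    elements = (
        B_r * d        # Q_i
        + 2 * B_c * d  # K_j and V_j
        + B_r * B_c    # S_ij
        + 2 * B_r      # m_i and l_i
        + B_r * d      # O_i
    )
    return elements * itemsize


size = working_set_bytes(B_r=64, B_c=64, d=64)
print(size, size / 1024)  # 82432 bytes, 80.5 KiB
```

Halving `B_c` shrinks the K, V, and S terms, which is the usual knob when the tile does not fit.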
The block-by-block loop (executable mental model)¶
```python
# Outer loop: each Q tile is processed once; K, V are streamed.
for i in range(ceil(N / B_r)):
    Q_i = HBM_load(Q[i*B_r : (i+1)*B_r, :])        # (B_r, d), R: B_r * d
    O_i = zeros((B_r, d))                          # in SRAM
    m_i = full((B_r,), -inf)                       # in SRAM, running max
    l_i = zeros((B_r,))                            # in SRAM, running sum

    # Inner loop: stream K_j, V_j once per (i, j).
    for j in range(ceil(N / B_c)):
        K_j = HBM_load(K[j*B_c : (j+1)*B_c, :])    # R: B_c * d
        V_j = HBM_load(V[j*B_c : (j+1)*B_c, :])    # R: B_c * d

        # Compute partial logits in SRAM (never touches HBM).
        S_ij = (Q_i @ K_j.T) / sqrt(d)             # (B_r, B_c)

        # Online softmax update, derived in theory 01.
        m_new = max(m_i, row_max(S_ij))            # (B_r,)
        alpha = exp(m_i - m_new)                   # (B_r,), rescales old O_i, l_i
        P_ij  = exp(S_ij - m_new[:, None])         # (B_r, B_c)
        l_i   = alpha * l_i + row_sum(P_ij)        # (B_r,)
        O_i   = alpha[:, None] * O_i + P_ij @ V_j  # (B_r, d)
        m_i   = m_new

    # End of inner loop: normalize and write back.
    O[i*B_r : (i+1)*B_r, :] = O_i / l_i[:, None]   # W: B_r * d
```
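To make the loop concrete, here is a minimal NumPy sketch of the same tile loop, checked against a reference softmax attention. It runs on ordinary arrays, so "SRAM" and "HBM" are only conceptual here; the function names and the small tile sizes are illustrative choices, not part of the lab code.

```python
import numpy as np


def naive_attention(Q, K, V):
    """Reference: materializes the full (N, N) score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


def flash_forward(Q, K, V, B_r=16, B_c=16):
    """Tiled forward pass using the online-softmax recurrence."""
    N, d = Q.shape
    O = np.empty_like(Q)
    for i0 in range(0, N, B_r):
        Q_i = Q[i0:i0 + B_r]                 # slicing handles a ragged last tile
        rows = Q_i.shape[0]
        O_i = np.zeros((rows, d))
        m_i = np.full(rows, -np.inf)         # running row-max
        l_i = np.zeros(rows)                 # running row-sum
        for j0 in range(0, K.shape[0], B_c):
            K_j, V_j = K[j0:j0 + B_c], V[j0:j0 + B_c]
            S_ij = Q_i @ K_j.T / np.sqrt(d)
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            alpha = np.exp(m_i - m_new)      # exp(-inf) = 0 on the first tile
            P_ij = np.exp(S_ij - m_new[:, None])
            l_i = alpha * l_i + P_ij.sum(axis=1)
            O_i = alpha[:, None] * O_i + P_ij @ V_j
            m_i = m_new
        O[i0:i0 + B_r] = O_i / l_i[:, None]
    return O


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((50, 8)) for _ in range(3))
print(np.allclose(flash_forward(Q, K, V), naive_attention(Q, K, V)))
```

N=50 is deliberately not a multiple of the tile size, so the last Q and K tiles are shorter than `B_r` and `B_c`.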
Why O_i rescales correctly (the load-bearing step)¶
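Let m_{j-1} be the running row-max after tile j-1, so that O_i = sum_{j' < j} exp(S_{ij'} - m_{j-1}) @ V_{j'}. After tile j updates the max to m_j = m_new, with alpha = exp(m_{j-1} - m_j) applied row-wise:

$$
\alpha\, O_i^{(j-1)} + \exp(S_{ij} - m_j)\, V_j
= \sum_{j' < j} \exp\big(S_{ij'} - m_{j-1} + m_{j-1} - m_j\big)\, V_{j'} + \exp(S_{ij} - m_j)\, V_j
= \sum_{j' \le j} \exp(S_{ij'} - m_j)\, V_{j'}
$$

The alpha = exp(m_i - m_new) factor is exactly what you need to "re-base" the old accumulator from m_i to m_new without recomputing the previous tiles. It is the same algebra as theory 01's update() for the running sum.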
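The re-basing step can be verified numerically on a single query row. The sketch below uses made-up scores and scalar "value rows" purely for illustration; it compares the rescaled accumulator against a direct computation with the final max.

```python
import math

scores = [0.5, 2.0, -1.0, 3.5]   # one row of logits against 4 keys (made-up numbers)
values = [1.0, -2.0, 0.5, 4.0]   # scalar stand-ins for the rows of V

# Tile 1 covers keys 0..1.
m_old = max(scores[:2])
O_old = sum(math.exp(s - m_old) * v for s, v in zip(scores[:2], values[:2]))

# Tile 2 covers keys 2..3 and raises the running max.
m_new = max(m_old, max(scores[2:]))
alpha = math.exp(m_old - m_new)
O_new = alpha * O_old + sum(
    math.exp(s - m_new) * v for s, v in zip(scores[2:], values[2:])
)

# Direct computation over all keys, based on the final max.
O_direct = sum(math.exp(s - m_new) * v for s, v in zip(scores, values))

print(math.isclose(O_new, O_direct))
```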
HBM traffic accounting¶
Per outer iteration i:
- Load Q_i: B_r · d elements.
- Load K_j and V_j for all j: 2 · N · d elements (each tile read once).
- Write O_i: B_r · d elements.
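Total over all i (there are N / B_r outer iterations):

$$
\text{HBM}_{\text{flash}} = \frac{N}{B_r}\,\big(2 B_r d + 2 N d\big) = 2Nd + \frac{2N^2 d}{B_r} = N d \left(2 + \frac{2N}{B_r}\right)
$$

For N=2048, d=64, B_r=64: 2048 · 64 · (2 + 64) ≈ 8.65 M elements = 33.0 MiB in fp32. Naive attention crosses HBM at ≈ 4 N² + 3 N d = 16.8 M + 0.4 M = 17.2 M elements = 65.5 MiB. Flash moves about half the HBM bytes, and the S matrix never touches HBM at all, which is the bigger architectural win (cache hierarchy + DRAM read amplification).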
Note for B_r → N (single outer iteration, all Q in SRAM): HBM = 2Nd + 2Nd = 4Nd. Flash collapses to a single-pass algorithm. For very small N (≤ 256) this is the regime; the tiling overhead disappears.
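The traffic accounting above is easy to tabulate. The following sketch counts elements, not bytes, and assumes N is a multiple of B_r; the function names are hypothetical, and the naive count uses the 4 N² + 3 N d approximation from the text.

```python
def flash_hbm_elements(N: int, d: int, B_r: int) -> int:
    """Elements moved across HBM by the tiled forward (assumes B_r divides N)."""
    if N % B_r != 0:
        raise ValueError("this sketch assumes N is a multiple of B_r")
    per_outer_iter = B_r * d + 2 * N * d + B_r * d  # load Q_i, stream K and V, write O_i
    return (N // B_r) * per_outer_iter


def naive_hbm_elements(N: int, d: int) -> int:
    """Approximate element count for naive attention, as used in the text."""
    return 4 * N * N + 3 * N * d


flash = flash_hbm_elements(2048, 64, 64)
naive = naive_hbm_elements(2048, 64)
print(flash, naive, flash / naive)                 # 8650752 17170432, ratio about 0.5
print(flash_hbm_elements(2048, 64, 2048))          # 524288 = 4 * N * d (single outer pass)
```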
What changes for the backward pass (preview)¶
The backward needs S and P. Two strategies: (a) recompute S from Q, K (cheap FLOPs, cheap memory — chosen by Flash); (b) stash S in HBM (expensive memory, free FLOPs — chosen by naive). Flash 2's backward picks (a) with a couple of extra rescaling factors. Out of scope for this phase; the forward derivation alone is the load-bearing one.
Part B — GQA's KV-cache math, with numbers¶
The MHA baseline¶
For multi-head attention with n_heads heads, head dim d_h = d_model / n_heads, and context length N:
- KV cache shape: (n_layers, 2, n_heads, N, d_h).
- Bytes per token in the KV cache: 2 · n_layers · n_heads · d_h · sizeof(dtype) = 2 · n_layers · d_model · sizeof(dtype).
The "2" is one slot for K, one for V. Independent of n_heads once expressed as n_heads · d_h = d_model.
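For LLaMA-7B (n_layers=32, d_model=4096) in fp16:

$$
2 \cdot 32 \cdot 4096 \cdot 2\ \text{bytes} = 524{,}288\ \text{bytes} = 512\ \text{KiB per token}
$$

At an N=4096 context: 4096 · 512 KiB = 2 GiB just for the KV cache. This is why long-context inference is hard.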
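As a quick check of the MHA baseline, the per-token formula can be written as a one-line function. The name `mha_kv_bytes_per_token` is illustrative; the model shape is the LLaMA-7B configuration quoted above.

```python
def mha_kv_bytes_per_token(n_layers: int, d_model: int, itemsize: int) -> int:
    """2 (K and V) * n_layers * d_model * bytes per element."""
    return 2 * n_layers * d_model * itemsize


per_token = mha_kv_bytes_per_token(n_layers=32, d_model=4096, itemsize=2)  # fp16
total = 4096 * per_token                                                   # N = 4096

print(per_token // 1024, "KiB per token")   # 512
print(total / 2**30, "GiB at N=4096")       # 2.0
```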
GQA: group queries, share K and V¶
GQA partitions the n_heads query heads into kv_heads groups of n_heads / kv_heads heads each, where each group of query heads shares one K, V head.
KV cache shape: (n_layers, 2, kv_heads, N, d_h).
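The grouping can be pictured as an index map from query head to KV head. The sketch below assumes contiguous groups (query heads 0..g-1 share KV head 0, and so on), which is a common convention but an assumption here; the helper name is hypothetical.

```python
def kv_head_for_query_head(q_head: int, n_heads: int, kv_heads: int) -> int:
    """Index of the shared K, V head used by a given query head (contiguous grouping assumed)."""
    if n_heads % kv_heads != 0:
        raise ValueError("n_heads must be a multiple of kv_heads")
    group_size = n_heads // kv_heads
    return q_head // group_size


mapping = [kv_head_for_query_head(h, n_heads=8, kv_heads=2) for h in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

With kv_heads = n_heads the map is the identity (MHA); with kv_heads = 1 every query head maps to KV head 0 (MQA).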
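Bytes per token:

$$
\text{bytes/token}_{\text{GQA}} = 2 \cdot \text{n\_layers} \cdot \text{kv\_heads} \cdot d_h \cdot \text{sizeof(dtype)}
$$

The savings ratio:

$$
\frac{\text{bytes/token}_{\text{MHA}}}{\text{bytes/token}_{\text{GQA}}} = \frac{\text{n\_heads}}{\text{kv\_heads}}
$$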
For LLaMA-2-7B (n_heads=32, kv_heads=32) → MHA, no savings.
For LLaMA-2-70B (n_heads=64, kv_heads=8) → GQA-8, 8× reduction. With n_layers=80 and d_h=128 in fp16, KV/token drops from 2560 KiB (MHA-equivalent) to 320 KiB.
For Mistral-7B (n_heads=32, kv_heads=8) → GQA-8, 4× reduction.
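The per-token formula and the savings ratio can be sketched in a few lines. The Mistral-7B shape used below (n_layers=32, d_h=128, fp16) is an assumption taken from the published model configuration, not from this document, and the function names are illustrative.

```python
def kv_bytes_per_token(n_layers: int, kv_heads: int, d_h: int, itemsize: int) -> int:
    """2 (K and V) * n_layers * kv_heads * d_h * bytes per element."""
    return 2 * n_layers * kv_heads * d_h * itemsize


def savings_ratio(n_heads: int, kv_heads: int) -> float:
    """How many times smaller the GQA cache is than the MHA-equivalent cache."""
    return n_heads / kv_heads


# Assumed Mistral-7B shape: 32 layers, head dim 128, fp16 (2 bytes).
mha_equiv = kv_bytes_per_token(n_layers=32, kv_heads=32, d_h=128, itemsize=2)
gqa = kv_bytes_per_token(n_layers=32, kv_heads=8, d_h=128, itemsize=2)

print(mha_equiv // 1024, gqa // 1024)   # 512 128 (KiB per token)
print(savings_ratio(32, 8))             # 4.0
print(savings_ratio(64, 8))             # 8.0 (LLaMA-2-70B head counts)
```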
Why GQA is not free in quality¶
The query heads in a group share one K, V head, so they all score against the same keys and mix the same values; only their Q projections differ. This is a real expressivity constraint. The empirical finding (Ainslie et al., 2023): on a well-tuned model, kv_heads = n_heads / 8 stays close to full MHA in quality. Pushed further (kv_heads = 1 = MQA), the quality hit is real, which is why most production models stop short of MQA (LLaMA-2-70B and Mistral-7B both keep 8 KV heads).
The cache-memory frontier¶
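For a model serving B concurrent users at context N:

$$
\text{KV bytes} = B \cdot N \cdot 2 \cdot \text{n\_layers} \cdot \text{kv\_heads} \cdot d_h \cdot \text{sizeof(dtype)}
$$

This is what GQA buys you: at fixed (n_layers, d_h, dtype), going from kv_heads = n_heads to kv_heads = n_heads / 8 lets you serve 8× more users in the same KV-cache budget. Or serve the same users at 8× longer context. Or run the same workload in one eighth of the KV-cache memory.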
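The trade-off can be turned into a small capacity calculator. This is a sketch under stated assumptions: the 16 GiB KV-cache budget is hypothetical, the model shape (32 layers, 32 query heads, d_h=128, fp16) follows the LLaMA-7B numbers used earlier, and the function name is made up.

```python
def max_concurrent_users(budget_bytes: int, N: int, n_layers: int,
                         kv_heads: int, d_h: int, itemsize: int) -> int:
    """Whole sequences of length N that fit in a given KV-cache budget."""
    per_user = N * 2 * n_layers * kv_heads * d_h * itemsize
    return budget_bytes // per_user


budget = 16 * 2**30  # hypothetical 16 GiB reserved for the KV cache

mha_users = max_concurrent_users(budget, N=4096, n_layers=32, kv_heads=32, d_h=128, itemsize=2)
gqa_users = max_concurrent_users(budget, N=4096, n_layers=32, kv_heads=4, d_h=128, itemsize=2)

print(mha_users, gqa_users)  # 8 64
```

Swapping the roles of `N` and the user count gives the "same users, 8× longer context" reading of the same budget.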
The KV cache is the practical bottleneck at production inference scale. Attention FLOPs scale with N²; KV-cache memory scales with N, but it must stay resident for every active sequence. As models go to 100K-token contexts, KV-cache memory becomes the dominant memory cost of serving. GQA is the single biggest architectural lever to manage that, short of changing what gets cached (MLA's compressed latents from theory 04, sliding-window attention, etc.).
Mini-GPT-scale numbers (so the derivation is not abstract)¶
For Mini-GPT (n_layers=2, n_heads=4, d_h=16, d_model=64, fp32) at N=64:
| Variant | kv_heads | KV bytes/token | Total KV bytes at N=64 |
|---|---|---|---|
| MHA | 4 | 2·2·4·16·4 = 1024 | 65 536 (64 KiB) |
| GQA-2 | 2 | 2·2·2·16·4 = 512 | 32 768 (32 KiB) |
| MQA | 1 | 2·2·1·16·4 = 256 | 16 384 (16 KiB) |
A 4× reduction from MHA to MQA. The absolute numbers are tiny, but the same n_heads / kv_heads mechanism saves several GiB per 4096-token sequence at LLaMA-2-70B scale in fp16.
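The table can be reproduced by allocating the caches and reading their size. This sketch assumes NumPy is available; `allocate_kv_cache` is a hypothetical helper that follows the (n_layers, 2, kv_heads, N, d_h) layout from the text.

```python
import numpy as np


def allocate_kv_cache(n_layers: int, kv_heads: int, N: int, d_h: int, dtype=np.float32):
    """Zero-initialized KV cache with layout (n_layers, 2, kv_heads, N, d_h)."""
    return np.zeros((n_layers, 2, kv_heads, N, d_h), dtype=dtype)


N = 64
sizes = {}
for name, kv_heads in [("MHA", 4), ("GQA-2", 2), ("MQA", 1)]:
    cache = allocate_kv_cache(n_layers=2, kv_heads=kv_heads, N=N, d_h=16)
    sizes[name] = cache.nbytes
    print(name, cache.nbytes // N, cache.nbytes)  # bytes/token, total bytes

# MHA 1024 65536
# GQA-2 512 32768
# MQA 256 16384
```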
Citations¶
- Dao, Fu, Ermon, Rudra, Ré. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. arXiv:2205.14135.
- Dao. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. 2023. arXiv:2307.08691.
- Ainslie, Lee-Thorp, de Jong, Zemlyanskiy, Lebrón, Sanghai. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. arXiv:2305.13245.
One-paragraph recap¶
The Flash forward is a double loop: outer Q tile, inner K, V tile, with the online-softmax recurrence in the inner step that keeps the partial logits matrix S SRAM-resident and avoids HBM round-trips. HBM traffic drops from about 4N² elements to 2Nd + 2N²d / B_r (roughly half for the N=2048, d=64, B_r=64 example), and S never crosses HBM at all. GQA, in parallel, attacks the KV-cache memory: by sharing K, V across groups of query heads, it cuts KV bytes/token by a factor of exactly n_heads / kv_heads. The two techniques compose: Flash speeds up attention; GQA shrinks the cache. Together they make long-context inference (10K-100K tokens) economically feasible.
Next: lab/04-mqa-gqa.md for the empirical measurement.