05 — Flash Forward Walkthrough + the GQA KV-Cache Math

Two derivations, back to back: (a) we walk through the FlashAttention forward pass step by step, with tile-loop pseudocode you can execute in your head; (b) we derive exactly how many bytes of KV cache GQA saves relative to MHA as a function of n_heads, kv_heads, sequence length, and dtype. No abbreviated pseudocode and no platitudes.

Anchors: theory/01-online-softmax.md, theory/02-flash-attention.md, theory/04-gqa-mqa-mla.md. This file is the one you should keep open when you implement the lab.


Part A — Flash forward, fully unrolled

Notation: sequence length N, head dim d, and tile sizes B_r (rows of Q per tile) and B_c (rows of K and V per tile). Constraint: one tile's working set must fit in SRAM.

Memory layout of the working set

The SRAM-resident state during the inner loop:

Q_i  : (B_r, d) fp32       — current query tile
K_j  : (B_c, d) fp32       — current key tile
V_j  : (B_c, d) fp32       — current value tile
S_ij : (B_r, B_c) fp32     — partial logits, never written to HBM
m_i  : (B_r,) fp32         — running row-max
l_i  : (B_r,) fp32         — running row-sum (after rescaling)
O_i  : (B_r, d) fp32       — running output accumulator

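For d = 64, B_r = B_c = 64 in fp32: 4 × (64·64 + 3·64·64 + 64 + 64 + 64·64) = 4 × (4096 + 12288 + 128 + 4096) = 4 × 20608 = 82,432 bytes ≈ 80.5 KiB (the terms are Q_i; K_j, V_j, S_ij; m_i; l_i; O_i). That overflows a Haswell-class 32 KiB L1d, so on such a CPU it lives in the 256 KiB L2. It fits comfortably in Hopper's per-SM shared memory.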
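The same count as a small Python helper. This is a sketch that assumes exactly the seven buffers listed above (no separate P_ij buffer) and 4-byte fp32 elements. `flash_working_set_bytes` is a hypothetical name:

```python
# Hypothetical helper: SRAM bytes for one (i, j) step of the tiled forward,
# counting only the buffers in the list above. bytes_per_elem=4 models fp32.
def flash_working_set_bytes(d, B_r, B_c, bytes_per_elem=4):
    q_tile = B_r * d          # Q_i
    kv_tiles = 2 * B_c * d    # K_j and V_j
    logits = B_r * B_c        # S_ij (P_ij not counted separately, as in the list)
    stats = 2 * B_r           # m_i and l_i
    out_acc = B_r * d         # O_i
    return bytes_per_elem * (q_tile + kv_tiles + logits + stats + out_acc)

total = flash_working_set_bytes(d=64, B_r=64, B_c=64)
print(total, total / 1024)  # 82432 80.5
```

Changing B_r, B_c, or bytes_per_elem (e.g. 2 for fp16 tiles) shows how quickly the working set outgrows a given cache level.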

The block-by-block loop (executable mental model)

import math
import numpy as np

def flash_forward(Q, K, V, B_r=64, B_c=64):
    """Tiled attention forward pass. Q, K, V: (N, d) float arrays standing in for HBM."""
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = np.empty_like(Q)

    # Outer loop: each Q tile is processed once; K, V are streamed.
    for i in range(0, N, B_r):
        Q_i = Q[i : i + B_r]                                # HBM load, R: B_r * d
        rows = Q_i.shape[0]                                 # last tile may be short
        O_i = np.zeros((rows, d), dtype=Q.dtype)            # SRAM: output accumulator
        m_i = np.full(rows, -np.inf, dtype=Q.dtype)         # SRAM: running row-max
        l_i = np.zeros(rows, dtype=Q.dtype)                 # SRAM: running row-sum

        # Inner loop: stream K_j, V_j once per (i, j).
        for j in range(0, N, B_c):
            K_j = K[j : j + B_c]                            # HBM load, R: B_c * d
            V_j = V[j : j + B_c]                            # HBM load, R: B_c * d

            # Partial logits stay in SRAM (never touch HBM).
            S_ij = (Q_i @ K_j.T) * scale                    # (B_r, B_c)

            # Online softmax update, derived in theory 01.
            m_new = np.maximum(m_i, S_ij.max(axis=1))       # (B_r,) elementwise max
            alpha = np.exp(m_i - m_new)                     # (B_r,) rescales old O_i, l_i
            P_ij = np.exp(S_ij - m_new[:, None])            # (B_r, B_c)
            l_i = alpha * l_i + P_ij.sum(axis=1)            # (B_r,)
            O_i = alpha[:, None] * O_i + P_ij @ V_j         # (B_r, d)
            m_i = m_new

        # End of inner loop: normalize and write back.
        O[i : i + B_r] = O_i / l_i[:, None]                 # HBM write, W: B_r * d
    return O

Why O_i rescales correctly (the load-bearing step)

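After tiles 0 through j-1, the unnormalized accumulator holds (m_i broadcast across each row):

\[ O_i^{\text{old}} = \sum_{j' < j} \exp(S_{ij'} - m_i) \, V_{j'} \]

Row by row, exp(S - m_new) = exp(m_i - m_new) · exp(S - m_i) = alpha · exp(S - m_i), so after tile j arrives:

\[ O_i^{\text{new}} = \sum_{j' \le j} \exp(S_{ij'} - m_{\text{new}}) \, V_{j'} = \operatorname{diag}(\alpha) \, O_i^{\text{old}} + \exp(S_{ij} - m_{\text{new}}) \, V_j \]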

The alpha = exp(m_i - m_new) factor is exactly what you need to "re-base" the old accumulator from m_i to m_new without recomputing the previous tiles. Identical algebra to theory 01's update() for the running sum.
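A quick numerical check of the rescaling identity. The toy sizes (B_r = 4, B_c = 8, d = 16, two K, V tiles) are arbitrary assumptions. Re-basing the tile-0 accumulator with alpha gives the same unnormalized result as evaluating both tiles directly at the new max:

```python
import numpy as np

rng = np.random.default_rng(0)
B_r, B_c, d = 4, 8, 16
Q_i = rng.standard_normal((B_r, d))
K = rng.standard_normal((2 * B_c, d))   # two key tiles stacked
V = rng.standard_normal((2 * B_c, d))   # two value tiles stacked

S = Q_i @ K.T / np.sqrt(d)
S_0, S_1 = S[:, :B_c], S[:, B_c:]       # logits for tile 0 and tile 1
V_0, V_1 = V[:B_c], V[B_c:]

# State after tile 0, based at the tile-0 row max.
m_old = S_0.max(axis=1)
O_old = np.exp(S_0 - m_old[:, None]) @ V_0

# Tile 1 arrives: re-base the old accumulator with alpha instead of revisiting tile 0.
m_new = np.maximum(m_old, S_1.max(axis=1))
alpha = np.exp(m_old - m_new)
O_rescaled = alpha[:, None] * O_old + np.exp(S_1 - m_new[:, None]) @ V_1

# Left-hand side of the identity: both tiles at once, based at m_new.
O_direct = np.exp(S - m_new[:, None]) @ V
print(np.allclose(O_rescaled, O_direct))  # True
```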

HBM traffic accounting

Per outer iteration i:

  • Load Q_i: B_r · d elements.
  • Load K_j and V_j for all j: 2 · N · d elements (every K, V tile is read once per outer iteration).
  • Write O_i: B_r · d.

Total over all i:

\[ \text{HBM elements} = \underbrace{N \cdot d}_{\text{Q, once total}} + \underbrace{(N/B_r) \cdot 2 \cdot N \cdot d}_{\text{K, V re-read per outer tile}} + \underbrace{N \cdot d}_{\text{O, once total}} = N d \cdot \left(2 + \frac{2N}{B_r}\right) \]

For N = 2048, d = 64, B_r = 64: 2048 · 64 · (2 + 64) ≈ 8.65 M elements = 33 MiB in fp32. Naive attention reads Q, K, V and writes O (4Nd), and writes then re-reads both S and P (4N²): 16.8 M + 0.5 M ≈ 17.3 M elements = 66 MiB. Flash moves half the HBM bytes, and the S matrix never touches HBM at all, which is the bigger architectural win (cache hierarchy + DRAM read amplification).
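A sketch of both traffic models as element counts. The function names are hypothetical. The naive model assumes reads of Q, K, V, a write of O, and a write plus a read of both S and P. The last line checks the B_r → N limit discussed next:

```python
import math

def flash_hbm_elements(N, d, B_r):
    # Q read once, K and V re-read once per outer Q tile, O written once.
    outer_tiles = math.ceil(N / B_r)
    return N * d + outer_tiles * 2 * N * d + N * d

def naive_hbm_elements(N, d):
    # Q, K, V reads + O write (4Nd); S and P each written and read back (4N^2).
    return 4 * N * N + 4 * N * d

MiB = 2 ** 20
N, d, B_r = 2048, 64, 64
flash = flash_hbm_elements(N, d, B_r)
naive = naive_hbm_elements(N, d)
print(flash, flash * 4 / MiB)                  # 8650752 33.0
print(naive, naive * 4 / MiB)                  # 17301504 66.0
print(flash_hbm_elements(N, d, N) == 4 * N * d)  # True: single outer tile
```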

Note for B_r → N (single outer iteration, all Q in SRAM): HBM = 2Nd + 2Nd = 4Nd. Flash collapses to a single-pass algorithm. For very small N (≤ 256) this is the regime; the tiling overhead disappears.

What changes for the backward pass (preview)

The backward needs S and P. Two strategies: (a) recompute S from Q and K (extra but cheap FLOPs, little memory; chosen by FlashAttention); (b) store S or P in HBM (O(N²) memory, no recomputation; what naive attention does). FlashAttention and FlashAttention-2 both pick (a). The forward saves only per-row softmax statistics (m and l, or their logsumexp in FlashAttention-2), which is enough to rebuild P tile by tile. Out of scope for this phase; the forward derivation alone is the load-bearing one.
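A minimal sketch of strategy (a). It assumes the forward saves a single per-row statistic, the logsumexp L = m + log(l), and that S fits in memory for the demo. A real kernel rebuilds P tile by tile. The toy sizes are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(1)
N, d = 6, 8
Q, K = rng.standard_normal((N, d)), rng.standard_normal((N, d))
S = Q @ K.T / np.sqrt(d)

# Forward pass: keep only one number per row (assumed saved statistic).
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
L = m + np.log(l)                       # row-wise logsumexp, shape (N,)

# Backward pass, strategy (a): recompute S from Q, K and rebuild P exactly.
S_re = Q @ K.T / np.sqrt(d)
P_re = np.exp(S_re - L[:, None])
P_ref = np.exp(S - m[:, None]) / l[:, None]
print(np.allclose(P_re, P_ref), np.allclose(P_re.sum(axis=1), 1.0))  # True True
```

Only N floats cross HBM between forward and backward, instead of the N × N matrix P.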


Part B — GQA's KV-cache math, with numbers

The MHA baseline

For multi-head attention with n_heads heads, head dim d_h = d_model / n_heads, and context length N:

  • KV cache shape: (n_layers, 2, n_heads, N, d_h).
  • Bytes per token in the KV cache: 2 · n_layers · n_heads · d_h · sizeof(dtype) = 2 · n_layers · d_model · sizeof(dtype).

The "2" is one slot for K, one for V. Independent of n_heads once expressed as n_heads · d_h = d_model.

For LLaMA-7B (n_layers=32, d_model=4096) in fp16:

\[ \text{bytes/token} = 2 \cdot 32 \cdot 4096 \cdot 2 = 524{,}288 \text{ bytes} = 512 \text{ KiB / token} \]

At N = 4096, a single sequence needs 4096 · 512 KiB = 2 GiB of KV cache. This is why long-context inference is hard.
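The per-token formula as a Python sketch. `kv_bytes_per_token` is a hypothetical helper, and it takes kv_heads, which equals n_heads for MHA. For LLaMA-7B it assumes n_heads = 32 and d_h = 128; only their product d_model = 4096 matters:

```python
def kv_bytes_per_token(n_layers, kv_heads, d_h, bytes_per_elem):
    # 2 = one K slot + one V slot, per layer, per KV head.
    return 2 * n_layers * kv_heads * d_h * bytes_per_elem

KiB, GiB = 2 ** 10, 2 ** 30
# LLaMA-7B under MHA: kv_heads = n_heads = 32, d_h = 4096 / 32 = 128, fp16 (2 bytes).
per_token = kv_bytes_per_token(n_layers=32, kv_heads=32, d_h=128, bytes_per_elem=2)
print(per_token, per_token / KiB)   # 524288 512.0
print(4096 * per_token / GiB)       # 2.0
```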

GQA: group queries, share K and V

GQA partitions the n_heads query heads into kv_heads groups of n_heads / kv_heads heads each. All query heads in a group share one K head and one V head.

KV cache shape: (n_layers, 2, kv_heads, N, d_h).
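A minimal NumPy sketch of the head-to-group mapping for one layer. It assumes no causal mask and that n_heads is a multiple of kv_heads. The cache keeps kv_heads heads, and `np.repeat` expands it to n_heads only at compute time, so query head h reads KV head h // (n_heads / kv_heads). `gqa_attention` is a hypothetical name:

```python
import numpy as np

def gqa_attention(Q, K_cache, V_cache):
    """Q: (n_heads, N, d_h); K_cache, V_cache: (kv_heads, N, d_h). No causal mask."""
    n_heads, N, d_h = Q.shape
    kv_heads = K_cache.shape[0]
    assert n_heads % kv_heads == 0
    group_size = n_heads // kv_heads            # query heads per shared K, V head
    # Expand at compute time only; the stored cache stays at kv_heads.
    K = np.repeat(K_cache, group_size, axis=0)  # (n_heads, N, d_h), head h -> h // group_size
    V = np.repeat(V_cache, group_size, axis=0)
    S = Q @ K.transpose(0, 2, 1) / np.sqrt(d_h) # (n_heads, N, N)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)          # row-wise softmax
    return P @ V                                # (n_heads, N, d_h)

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 5, 16))             # n_heads = 4
K_cache = rng.standard_normal((2, 5, 16))       # kv_heads = 2 (GQA-2)
V_cache = rng.standard_normal((2, 5, 16))
out = gqa_attention(Q, K_cache, V_cache)
print(out.shape)  # (4, 5, 16)
```

Query heads 0 and 1 read KV head 0, and heads 2 and 3 read KV head 1. Each head still has its own Q, so the attention patterns differ within a group.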

Bytes per token, writing k_KV for kv_heads:

\[ \text{bytes/token}_{\text{GQA}} = 2 \cdot n_{\text{layers}} \cdot \underbrace{k_{\text{KV}} \cdot d_h}_{\text{KV width}} \cdot \text{sizeof(dtype)} \]

The savings ratio:

\[ \frac{\text{bytes}_{\text{GQA}}}{\text{bytes}_{\text{MHA}}} = \frac{k_{\text{KV}}}{n_{\text{heads}}} \]

For LLaMA-2-7B (n_heads = 32, kv_heads = 32) → MHA, no savings. For LLaMA-2-70B (n_heads = 64, kv_heads = 8) → GQA-8, an 8× reduction: with 80 layers, d_h = 128 and fp16, KV/token drops from 2560 KiB (MHA-equivalent) to 320 KiB. For Mistral-7B (n_heads = 32, kv_heads = 8) → GQA-8, a 4× reduction.
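The LLaMA-2-70B and Mistral-7B numbers, checked in Python. The sketch assumes the stated shapes (80 and 32 layers, d_h = 128) and an fp16 cache:

```python
def kv_bytes_per_token(n_layers, kv_heads, d_h, bytes_per_elem):
    return 2 * n_layers * kv_heads * d_h * bytes_per_elem   # K and V, every layer

KiB = 2 ** 10
# LLaMA-2-70B: 80 layers, 64 query heads, d_h = 128, fp16.
mha = kv_bytes_per_token(80, 64, 128, 2)   # MHA-equivalent: kv_heads = n_heads
gqa = kv_bytes_per_token(80, 8, 128, 2)    # GQA-8: kv_heads = 8
print(mha / KiB, gqa / KiB, mha / gqa)     # 2560.0 320.0 8.0

# Mistral-7B: 32 layers, 32 query heads, 8 KV heads, d_h = 128.
mistral_ratio = kv_bytes_per_token(32, 32, 128, 2) / kv_bytes_per_token(32, 8, 128, 2)
print(mistral_ratio)                       # 4.0
```

The ratio depends only on n_heads / kv_heads. Layers, d_h, and dtype cancel.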

Why GQA is not free in quality

Query heads in a group share one K head and one V head. Each head still forms its own attention pattern through its own Q projection, but all of them score against the same keys and mix the same values. This is a real expressivity constraint. The empirical finding (Ainslie et al., 2023): with kv_heads = n_heads / 8, GQA comes close to full-MHA quality. Pushed further (kv_heads = 1, i.e. MQA), the quality hit is real, which is why production models such as the two GQA models above stop at 8 KV heads.

The cache-memory frontier

For a model serving B concurrent users at context N:

\[ \text{total KV cache} = B \cdot N \cdot 2 \cdot n_{\text{layers}} \cdot k_{\text{KV}} \cdot d_h \cdot \text{sizeof(dtype)} \]

This is what GQA buys you. At fixed (B, N, n_layers, d_h), going from kv_heads = n_heads to kv_heads = n_heads / 8 lets you serve 8× more users in the same KV-cache budget, or the same users at 8× longer context, or the same workload with 8× less KV-cache memory.
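A sketch of the capacity argument. It assumes a hypothetical 40 GiB KV-cache budget, the LLaMA-2-70B shape (80 layers, 64 query heads, d_h = 128), fp16, and N = 4096. `max_concurrent_users` is an illustrative helper:

```python
def max_concurrent_users(budget_bytes, N, n_layers, kv_heads, d_h, bytes_per_elem):
    # One user's cache: N tokens x (K and V) x layers x KV heads x d_h x dtype size.
    per_user = N * 2 * n_layers * kv_heads * d_h * bytes_per_elem
    return budget_bytes // per_user

GiB = 2 ** 30
budget = 40 * GiB                       # hypothetical KV-cache budget
mha_users = max_concurrent_users(budget, 4096, 80, 64, 128, 2)   # kv_heads = n_heads
gqa_users = max_concurrent_users(budget, 4096, 80, 8, 128, 2)    # kv_heads = n_heads / 8
print(mha_users, gqa_users)  # 4 32
```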

The KV cache is the practical bottleneck at production inference scale. Per sequence, attention FLOPs scale with N² while KV-cache memory scales with N, and that memory has to stay resident for every concurrent user. As models go to 100K-token contexts, KV-cache memory comes to dominate serving memory. GQA is the single biggest architectural lever this phase covers for managing that, short of changing what gets cached (MLA, see theory 04) or dropping caching entirely.

Mini-GPT-scale numbers (so the derivation is not abstract)

For Mini-GPT (n_layers=2, n_heads=4, d_h=16, d_model=64, fp32) at N=64:

| Variant | kv_heads | KV bytes/token | Total KV bytes at N=64 |
|---|---|---|---|
| MHA | 4 | 2·2·4·16·4 = 1024 | 65 536 (64 KiB) |
| GQA-2 | 2 | 2·2·2·16·4 = 512 | 32 768 (32 KiB) |
| MQA | 1 | 2·2·1·16·4 = 256 | 16 384 (16 KiB) |

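A 4× reduction from MHA to MQA is small in absolute terms, but the ratio is always kv_heads / n_heads. At LLaMA-2-70B scale (fp16, N = 4096), the same arithmetic is the difference between 10 GiB of MHA-equivalent KV cache and 1.25 GiB with GQA-8.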
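The table rows, recomputed from the Mini-GPT shape given above:

```python
n_layers, d_h, bytes_fp32, N = 2, 16, 4, 64    # Mini-GPT, n_heads = 4
rows = {}
for name, kv_heads in [("MHA", 4), ("GQA-2", 2), ("MQA", 1)]:
    per_token = 2 * n_layers * kv_heads * d_h * bytes_fp32   # K and V, every layer
    rows[name] = (per_token, per_token * N)
    print(name, per_token, per_token * N, per_token * N // 1024, "KiB")
# MHA 1024 65536 64 KiB
# GQA-2 512 32768 32 KiB
# MQA 256 16384 16 KiB
```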

Citations

  • Dao, Fu, Ermon, Rudra, Ré. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. arXiv:2205.14135.
  • Dao. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. 2023. arXiv:2307.08691.
  • Ainslie, Lee-Thorp, de Jong, Zemlyanskiy, Lebrón, Sanghai. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. arXiv:2305.13245.

One-paragraph recap

The Flash forward is a double loop: an outer loop over Q tiles and an inner loop over K, V tiles. The online-softmax recurrence in the inner step keeps the partial logits S_ij in SRAM and avoids HBM round-trips. HBM traffic drops from O(N²) to O(N²d / B_r), about half of naive in the N = 2048 example, and S never crosses HBM at all. GQA, in parallel, attacks KV-cache memory: by sharing K, V across groups of query heads, it scales KV bytes/token by exactly kv_heads / n_heads. The two techniques compose: Flash speeds up attention; GQA shrinks the cache. Together they make long-context inference (10K-100K tokens) economically feasible.

Next: lab/04-mqa-gqa.md for the empirical measurement.