00 - Why Attention Dominates Inference
Attention is not the problem because it is mathematically expensive; it is the problem because the matrix S = QKᵀ is enormous and gets materialized in HBM. Flash does not change the math; it changes which bytes cross the memory barrier. This file sets up the roofline argument for the next three.
The wall-clock breakdown of inference
A transformer forward step for one token, on a model with L layers, hidden size h, head dim d, and context length N, does roughly:
- L × MLP: two Linear operations of size (h, 4h) and (4h, h). FLOPs per layer = 2 × (2 × h × 4h) = 16 h². Bytes moved = 2 × (h × 4h × 4) = 32 h² (fp32 weights).
- L × Attention: see below.
- L × LayerNorm: O(h), negligible.
For a typical small model (L=12, h=768, d=64, N=2048):
- MLP FLOPs per layer = 16 × 768² ≈ 9.4M. Across L=12 layers: ~113M.
- Attention FLOPs per layer: 4 N² d ≈ 1.07G. Across 12 layers: ~12.9G.
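On these counts attention is roughly 100× more compute than the MLP per layer at this context length (the MLP figure is per token; the attention figure covers the full (N, N) score matrix). The scaling differs too: over a full prefill the MLP grows as N and attention as N²; for one new token against a KV cache the MLP is constant in N and attention grows as N.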
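The counts above are easy to reproduce. A minimal sketch, using only the formulas stated in this section (16 h² for the MLP, 4 N² d for attention); the variable names are ours:

```python
# FLOP counts for the small example model in the text.
L, h, d, N = 12, 768, 64, 2048  # layers, hidden size, head dim, context length

# Two Linear layers, 2 FLOPs per multiply-add: 2 * (2 * h * 4h) = 16 h^2.
mlp_flops_per_layer = 2 * (2 * h * 4 * h)
# Q K^T costs 2 N^2 d and the product with V costs another 2 N^2 d.
attn_flops_per_layer = 4 * N**2 * d

ratio = attn_flops_per_layer / mlp_flops_per_layer

print(f"MLP per layer:        {mlp_flops_per_layer / 1e6:.1f} MFLOPs")
print(f"Attention per layer:  {attn_flops_per_layer / 1e9:.2f} GFLOPs")
print(f"MLP across layers:    {L * mlp_flops_per_layer / 1e6:.0f} MFLOPs")
print(f"Attention across layers: {L * attn_flops_per_layer / 1e9:.1f} GFLOPs")
print(f"Attention / MLP:      {ratio:.0f}x")
```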
So attention dominates compute. But here's the kicker: it also dominates memory traffic, even harder.
Why attention is bandwidth-bound
Naive attention materializes the matrix S = Q K^T of shape (N, N). For N = 2048, that's 4 × 2048² = 16 MiB in fp32 (or 8 MiB in fp16). This matrix is:
- Written to HBM (after Q K^T).
- Read from HBM (for softmax).
- Written to HBM (the softmax-normalized matrix).
- Read from HBM (for S @ V).
Total HBM traffic on S alone: 64 MiB per layer at N=2048, fp32. That is more than an entire L2 cache (40 MiB on an A100), so S cannot stay on chip between steps. We blow through it.
Compute the arithmetic intensity:
- FLOPs: 4 N² d = 4 × 2048² × 64 = 1.07G FLOPs.
- Bytes moved (HBM): N² × 4 (write S) + N² × 4 (read S for softmax) + N² × 4 (write softmax(S)) + N² × 4 (read for @V) + small terms for Q, K, V, O. Total ≈ 16 N² bytes ≈ 64 MiB.
- Intensity: 1.07e9 / 6.7e7 ≈ 16 FLOPs/byte.
An A100's I_crit (compute-vs-HBM-bandwidth ratio): I_crit ≈ 312 TFLOPS / 1.55 TB/s ≈ 200 FLOPs/byte. Naive attention sits at ~16 FLOPs/byte, more than 10× below the corner. Almost all of the GPU is idle, waiting on HBM.
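The roofline arithmetic can be written out as a short sketch. It assumes the simple model used here: only the four passes over S count as HBM traffic, and attainable throughput is the smaller of peak compute and intensity times bandwidth. The hardware numbers are the A100 figures quoted above.

```python
# Roofline position of naive attention, using the numbers from the text.
N, d = 2048, 64
bytes_per_elem = 4                      # fp32

flops = 4 * N**2 * d                    # Q K^T and softmax(S) @ V
s_traffic = 4 * N**2 * bytes_per_elem   # write S, read S, write softmax(S), read it again
intensity = flops / s_traffic           # FLOPs per byte of HBM traffic

peak_flops = 312e12                     # A100 peak quoted in the text
hbm_bandwidth = 1.55e12                 # bytes/s, as quoted in the text
i_crit = peak_flops / hbm_bandwidth     # roofline corner

# Below the corner, throughput is capped by bandwidth, not by compute.
attainable = min(peak_flops, intensity * hbm_bandwidth)

print(f"intensity:  {intensity:.1f} FLOPs/byte")
print(f"I_crit:     {i_crit:.0f} FLOPs/byte")
print(f"attainable: {attainable / 1e12:.1f} TFLOPS "
      f"({100 * attainable / peak_flops:.0f}% of peak)")
```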
This is what Flash Attention solves. Not by changing the FLOPs (it doesn't), but by changing the bytes moved (it does).
What Flash actually does, in one paragraph
Flash Attention partitions Q, K, V into tiles. Inner tile dimensions are sized so that one tile's worth of intermediate state — S_tile of shape (B_r, B_c), plus running max and sum vectors — fits in on-chip SRAM (a few KiB to ~100 KiB depending on GPU). The big (N, N) matrix is never written to HBM. Only the per-output-tile (B_r, d) block and small (B_r,) running statistics flow between HBM and SRAM.
The mathematical trick that lets this work is the online softmax: a recurrence that lets you compute softmax(S) @ V incrementally as new S tiles arrive, without needing the full row of S first. Theory file 01 derives this; theory file 02 puts it inside the tiling loop.
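Before the full derivation, here is a rough NumPy sketch of the idea. It is not the Triton kernel from the labs: it streams K and V in tiles of width `block_c` but keeps all query rows together rather than tiling them in blocks of B_r, and it omits the usual 1/√d scaling to match S = Q K^T as written here. The function names are ours.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference: materializes the full (N, N) score matrix."""
    s = q @ k.T
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def streaming_attention(q, k, v, block_c=64):
    """softmax(Q K^T) @ V computed one K/V tile at a time with an online softmax."""
    n = q.shape[0]
    m = np.full(n, -np.inf)            # running row max
    l = np.zeros(n)                    # running row sum of exponentials
    acc = np.zeros((n, v.shape[1]))    # running unnormalized output
    for start in range(0, k.shape[0], block_c):
        k_tile = k[start:start + block_c]
        v_tile = v[start:start + block_c]
        s_tile = q @ k_tile.T                      # (n, block_c): the only scores alive
        m_new = np.maximum(m, s_tile.max(axis=1))
        scale = np.exp(m - m_new)                  # rescales old state; 0 on the first tile
        p_tile = np.exp(s_tile - m_new[:, None])
        l = l * scale + p_tile.sum(axis=1)
        acc = acc * scale[:, None] + p_tile @ v_tile
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 64)) for _ in range(3))
print(np.allclose(streaming_attention(q, k, v), naive_attention(q, k, v)))
```

The running max `m` and running sum `l` are the (B_r,) statistics mentioned above; whenever the max changes, the old accumulator is rescaled by `exp(m - m_new)` so earlier tiles stay consistent with later ones.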
The intensity gain: bytes moved drops from ~16 N² to roughly 4 · N · d · (3 + 2 N/B_c) (fp32, 4 bytes per element). For N=2048, d=64, B_c=64: bytes ≈ 2048 × 64 × 67 × 4 ≈ 33 MiB. Intensity ≈ 1.07e9 / 3.5e7 ≈ 30 FLOPs/byte. ~2× higher. The ratio grows slowly with N and, under this byte model, saturates as intensity approaches B_c/2, so tile size sets the ceiling. Still below the A100's 200 FLOPs/byte corner, but now we're on the steep part of the memory-ceiling slope, not at its foot.
This is why Flash is fast.
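The byte model above fits in a few lines. This is a sketch of the approximate formulas in this section only, not a measurement of any real kernel; the function names are ours.

```python
def naive_bytes(n, bytes_per_elem=4):
    """Four passes over the (N, N) matrix: write S, read S, write softmax(S), read it."""
    return 4 * n**2 * bytes_per_elem

def flash_bytes(n, d, block_c, bytes_per_elem=4):
    """Approximate tiled traffic from the text: N * d * (3 + 2 N / B_c) elements."""
    return n * d * (3 + 2 * n / block_c) * bytes_per_elem

def intensity(n, d, bytes_moved):
    """FLOPs per byte, with 4 N^2 d FLOPs for the two matmuls."""
    return 4 * n**2 * d / bytes_moved

d, block_c = 64, 64
for n in (512, 2048, 8192, 32768):
    i_naive = intensity(n, d, naive_bytes(n))
    i_flash = intensity(n, d, flash_bytes(n, d, block_c))
    print(f"N={n:6d}  naive={i_naive:5.1f}  flash={i_flash:5.1f}  gain={i_flash / i_naive:.2f}x")
```

Running it for several N shows the tiled intensity creeping toward B_c/2 while the naive intensity stays at d/4.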
Re-stating the standard "3× speedup" claim
The Flash Attention paper reported ~3× wall-clock speedup over a tuned PyTorch baseline at N=2048 on A100. The number is hardware-and-setting-specific, but the mechanism is the byte-count reduction we just computed. "Flash is 3× faster" should always be read the way an engineer in the room would read it: "Flash raises arithmetic intensity by keeping the working set in SRAM, so the kernel moves closer to the compute ceiling instead of crawling along the memory ceiling".
If Borja takes one phrase from this phase: the same FLOPs at higher intensity is the entire game. Quantization (Phase 26) attacks intensity from the byte side (smaller weights). Flash (this phase) attacks it from the algorithmic side (don't materialize what you don't need to). Both are roofline arguments; both are real.
PagedAttention is a different problem
PagedAttention (vLLM) is not a kernel optimization; it is a memory allocator optimization. The KV cache (storing all past K and V vectors for autoregressive generation) is huge and grows per token. A long-context model with batch=32, N=8192, layers=32, heads=32, head_dim=128 has KV cache size = 2 × 32 × 8192 × 32 × 32 × 128 × 2 (fp16) = 2³⁷ bytes = 128 GiB, or 4 GiB per request. Allocating this contiguously per request leads to massive fragmentation across batch members, like an OS that does malloc and never free.
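The KV cache product is worth evaluating rather than eyeballing. A small sketch of exactly the formula above (the leading 2 is for K and V; fp16 is 2 bytes per element); the helper name is ours:

```python
def kv_cache_bytes(batch, seq_len, layers, heads, head_dim, bytes_per_elem=2):
    """Bytes needed to hold K and V for every token, layer and head."""
    return 2 * batch * seq_len * layers * heads * head_dim * bytes_per_elem

total = kv_cache_bytes(batch=32, seq_len=8192, layers=32, heads=32, head_dim=128)
per_request = kv_cache_bytes(batch=1, seq_len=8192, layers=32, heads=32, head_dim=128)

print(f"whole batch: {total / 2**30:.0f} GiB")
print(f"per request: {per_request / 2**30:.0f} GiB")
```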
PagedAttention treats KV as virtual memory: small fixed-size blocks (e.g., 16 tokens worth of KV per page), a page table per request, copy-on-write for prefix caching. The attention kernel itself is modified to follow page-table indirections instead of accessing a flat K, V.
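To make the page-table idea concrete, here is a toy allocator. It is a sketch only: the class and method names are hypothetical and are not vLLM's API, and it leaves out copy-on-write and prefix sharing entirely. Blocks hold `block_size` tokens of KV and are handed out on demand.

```python
class PagedKVCache:
    """Toy block allocator with one page table per request (hypothetical names)."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.page_tables = {}                       # request id -> list of block ids
        self.lengths = {}                           # request id -> tokens stored

    def append_token(self, request_id):
        """Reserve a slot for one more token; returns (physical block, offset)."""
        table = self.page_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length == len(table) * self.block_size:  # no block yet, or last one is full
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())
        self.lengths[request_id] = length + 1
        return table[length // self.block_size], length % self.block_size

    def locate(self, request_id, token_idx):
        """Page-table lookup an attention kernel would do for each K/V access."""
        table = self.page_tables[request_id]
        return table[token_idx // self.block_size], token_idx % self.block_size

    def release(self, request_id):
        """Return all of a finished request's blocks to the free list."""
        self.free_blocks.extend(self.page_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)

cache = PagedKVCache(num_blocks=8)
for _ in range(20):                  # 20 tokens need two 16-token blocks
    cache.append_token("req-a")
print(cache.page_tables["req-a"], cache.locate("req-a", 17))
```

At most one partially filled block per request is wasted, instead of a contiguous reservation sized for the maximum context.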
Where Flash attacks per-kernel HBM traffic, PagedAttention attacks cross-request memory utilization. These compose: a deployed inference server uses both.
We cover PagedAttention as a reading exercise (theory 03, lab 03) because re-implementing it is a server-engineering job that distracts from the kernel story. The annotated vLLM read is plenty.
Three other attention variants in this phase
- Sliding-window attention. Mistral et al. use a fixed-width context window: each token attends only to the last W < N tokens. Reduces complexity from N² to N·W. Composable with Flash (the kernel just masks out positions outside the window).
- Grouped/Multi-Query Attention (GQA/MQA). Share K, V across multiple query heads. Reduces KV cache size by a factor of n_heads / n_kv_groups (typically 4×–8×). Doesn't reduce compute per token, but radically reduces per-token memory traffic during autoregressive decode.
- Multi-head Latent Attention (MLA, DeepSeek). Compresses K and V into a low-rank latent space; reconstructs K/V on the fly per attention call. Trades a small extra compute (the projection) for a much smaller KV cache.
All three are roofline arguments. Theory 04 walks through each.
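Two of these are simple enough to sketch numerically. The mask below assumes the window of W tokens includes the current token and is causal; whether a given model counts it that way is an assumption on our part. The helper names are ours.

```python
import numpy as np

def sliding_window_mask(n, window):
    """True where query i may attend to key j: causal, within the last `window` tokens."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j > i - window)

def gqa_cache_reduction(n_heads, n_kv_groups):
    """Factor by which the KV cache shrinks when query heads share K/V groups."""
    return n_heads / n_kv_groups

n, window = 2048, 256
mask = sliding_window_mask(n, window)
full_causal_pairs = n * (n + 1) // 2
windowed_pairs = int(mask.sum())

print(f"scored pairs, full causal: {full_causal_pairs}")
print(f"scored pairs, windowed:    {windowed_pairs}")
print(f"GQA, 32 query heads over 8 KV groups: {gqa_cache_reduction(32, 8):.0f}x smaller cache")
```

The windowed count grows as N·W rather than N², which is the complexity reduction named in the first bullet.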
What this phase doesn't try
We do not derive Flash backward. The backward pass uses recomputation: it re-derives S from Q and K on the fly during gradient computation, trading FLOPs for memory. The forward path's online softmax doesn't directly help — backward needs a different update rule. Out of scope for Phase 27; will return in a future phase.
One-paragraph recap
Attention dominates transformer inference both in FLOPs and in memory traffic. Naive attention sits more than 10× below the GPU's roofline corner because materializing the (N,N) matrix S = QKᵀ sends it through HBM four times. Flash Attention partitions Q/K/V into SRAM-resident tiles and uses an online softmax recurrence to avoid materializing S, roughly halving bytes moved in the worked example (N=2048, B_c=64) and raising the dot on the roofline plot about 2× toward the compute ceiling. PagedAttention attacks a different bottleneck: KV cache fragmentation across requests in a server. GQA/MQA/MLA shrink the KV cache itself. The remaining theory files derive each idea; the labs implement Flash forward in Triton and read PagedAttention in vLLM.
Next: theory/01-online-softmax.md.