02 — Flash Attention as a Roofline Optimization
Flash Attention is Phase 1 in action. Same number of FLOPs as naive attention; fewer bytes moved to HBM because the matrix S = QKᵀ is never materialized. Arithmetic intensity goes up; the roofline point moves toward the compute ceiling. It is not a trick: it is algebra (an online recurrence) plus a tile layout.
This file is the centrepiece of Phase 27. Read it once, then re-read it with theory/01-online-softmax.md open in another tab. By the end you should be able to (a) draw the tile-by-tile execution, (b) symbolically derive the byte-count delta, (c) state the roofline shift.
The naive attention algorithm
For one head with Q, K, V ∈ ℝ^{N × d}, output O ∈ ℝ^{N × d}:
1. S = Q @ K^T # (N, N) — materialized in HBM
2. P = softmax_rowwise(S) # (N, N) — materialized in HBM
3. O = P @ V # (N, d)
HBM traffic accounting (read = R, write = W; ignore O since it's the same in both algorithms):
- Step 1: R(Q) + R(K) + W(S) = Nd + Nd + N² reads/writes (fp32 → ×4 bytes).
- Step 2: R(S) + W(P) = N² + N² = 2N².
- Step 3: R(P) + R(V) = N² + Nd.
Total: Nd × 3 + N² × 4 fp32 elements = (12 Nd + 16 N²) bytes.
For N=2048, d=64: 12 × 2048 × 64 + 16 × 2048² = 1.5 MiB + 64 MiB = 65.5 MiB.
FLOPs: 2 × N² × d (Q@K^T) + 5 × N² (softmax) + 2 × N² × d (P@V) = 4 N² d + 5 N². For our numbers: 4 × 2048² × 64 + 5 × 2048² = 1.07 GF + 21 MF = 1.09 GFLOPs.
Intensity: 1.09e9 / 6.87e7 ≈ 15.9 FLOPs/byte. For an A100 with I_crit ≈ 200, this is about 13× below the corner. Memory-bound.
The Flash algorithm (forward)
Pick tile sizes:
- B_r — rows of Q per outer tile.
- B_c — rows of K (and V) per inner tile.
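Constraint: one tile's worth of intermediate state must fit in SRAM. The per-tile state is S_tile ∈ ℝ^{B_r × B_c} (the partial inner-product block), plus per-B_r-row m, ℓ vectors (scalar per query row). For SRAM budget M_sram, counting the S tile together with the loaded Q, K, V tiles in fp32 and neglecting the small m, ℓ vectors:
4 × (B_r × B_c + B_r × d + 2 × B_c × d) ≤ M_sram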
For d=64, M_sram=64 KiB, B_r = B_c = 64 works: 64×64×4 + 3×64×64×4 = 16 KiB + 48 KiB = 64 KiB. ✓
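The budget check can be scripted to try other tile shapes. This is a sketch under the same assumed accounting as the arithmetic above (one S tile plus the Q, K, V tiles, fp32, with the O accumulator and the m, ℓ vectors not counted); `tile_bytes` and `fits_in_sram` are hypothetical helper names.

```python
def tile_bytes(b_r: int, b_c: int, d: int, bytes_per_elem: int = 4) -> int:
    """SRAM bytes for one (i, j) step: S tile + Q tile + K and V tiles."""
    s_tile = b_r * b_c        # partial score block
    q_tile = b_r * d          # one tile of Q rows
    kv_tiles = 2 * b_c * d    # one tile each of K and V rows
    return (s_tile + q_tile + kv_tiles) * bytes_per_elem


def fits_in_sram(b_r: int, b_c: int, d: int, m_sram: int, bytes_per_elem: int = 4) -> bool:
    """True when the per-step working set is within the SRAM budget (in bytes)."""
    return tile_bytes(b_r, b_c, d, bytes_per_elem) <= m_sram


budget = 64 * 1024
print(tile_bytes(64, 64, 64) // 1024, "KiB; fits:", fits_in_sram(64, 64, 64, budget))
print(tile_bytes(128, 64, 64) // 1024, "KiB; fits:", fits_in_sram(128, 64, 64, budget))
```

A real kernel also has to hold the O accumulator, so the usable tile sizes are somewhat smaller than this count alone suggests.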
The algorithm:
for i = 0 .. (N / B_r) - 1:                          # outer loop: Q tiles
    Q_i = Q[i*B_r:(i+1)*B_r, :]                      # (B_r, d), load to SRAM
    O_i = zeros(B_r, d)                              # accumulator in SRAM
    m_i = full(B_r, -inf)                            # running max per row
    ℓ_i = zeros(B_r)                                 # running sum per row
    for j = 0 .. (N / B_c) - 1:                      # inner loop: K, V tiles
        K_j = K[j*B_c:(j+1)*B_c, :]                  # (B_c, d), load to SRAM
        V_j = V[j*B_c:(j+1)*B_c, :]                  # (B_c, d), load to SRAM
        S_ij = Q_i @ K_j^T / sqrt(d)                 # (B_r, B_c), in SRAM
        m_new = max(m_i, rowmax(S_ij))
        α = exp(m_i - m_new)                         # (B_r,)
        P_ij = exp(S_ij - m_new[:, None])            # (B_r, B_c)
        ℓ_i = α * ℓ_i + rowsum(P_ij)
        O_i = α[:, None] * O_i + P_ij @ V_j          # the online update
        m_i = m_new
    O[i*B_r:(i+1)*B_r, :] = O_i / ℓ_i[:, None]       # normalize, write to HBM
The online softmax recurrence from theory 01 is exactly what's inside the inner loop. The tiling is what wraps it.
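The pseudocode translates almost line for line into NumPy. The sketch below is a CPU reference for checking the algebra, not a kernel: "SRAM" and "HBM" are just ordinary arrays here, the function names are hypothetical, and the naive reference includes the same 1/sqrt(d) scaling as the tiled loop so the two are comparable.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference: materialize the full (N, N) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))   # stable softmax numerator
    return (p / p.sum(axis=1, keepdims=True)) @ v


def flash_forward(q, k, v, b_r=64, b_c=64):
    """Tiled forward pass with the online softmax recurrence (Q tiles outer)."""
    n, d = q.shape
    out = np.empty_like(q)
    for i in range(0, n, b_r):                     # outer loop: Q tiles
        q_i = q[i:i + b_r]
        rows = q_i.shape[0]
        o_i = np.zeros((rows, d), dtype=q.dtype)   # unnormalized accumulator
        m_i = np.full(rows, -np.inf, dtype=q.dtype)
        l_i = np.zeros(rows, dtype=q.dtype)
        for j in range(0, n, b_c):                 # inner loop: K, V tiles
            k_j, v_j = k[j:j + b_c], v[j:j + b_c]
            s_ij = q_i @ k_j.T / np.sqrt(d)
            m_new = np.maximum(m_i, s_ij.max(axis=1))
            alpha = np.exp(m_i - m_new)            # rescale factor for old state
            p_ij = np.exp(s_ij - m_new[:, None])
            l_i = alpha * l_i + p_ij.sum(axis=1)
            o_i = alpha[:, None] * o_i + p_ij @ v_j
            m_i = m_new
        out[i:i + b_r] = o_i / l_i[:, None]        # normalize once per Q tile
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 64)) for _ in range(3))
max_err = np.abs(flash_forward(q, k, v) - naive_attention(q, k, v)).max()
print("max abs difference:", max_err)
```

Slicing handles a final partial tile, so the tile sizes do not need to divide N in this sketch.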
Bytes moved by Flash
The big change: S never crosses HBM. It lives only in SRAM, computed and consumed per (i, j) inner step.
HBM traffic:
- Q is read once per outer tile, total Nd elements.
- K, V are each read N / B_r times (once per outer iteration), total 2 × Nd × N/B_r elements.
- O is written once at the end of each outer iteration, total Nd.
- m, ℓ are negligible (O(N) total, not N²).
Total: Nd × (2 + 2N/B_r) fp32 elements = (8 Nd × (1 + N/B_r)) bytes.
For N=2048, d=64, B_r=64: 8 × 2048 × 64 × (1 + 32) = 8 × 2048 × 64 × 33 ≈ 33 MiB.
Compare to naive's 65.5 MiB. That's about 2× less.
Wait — only 2×? The "3× faster" claim implies more.
Two answers:
- The bytes accounting here is generous to naive. A real PyTorch implementation also computes S in fp32 even with fp16 Q, K (because of softmax stability), then casts back. Real bytes moved are closer to 24 N² (1.5× our 16 N² estimate), which brings the byte ratio against Flash to about 3×.
- The roofline picture matters more than the byte count. Even if bytes were equal, Flash's tiles fit in SRAM. The relevant ceiling for Flash isn't HBM bandwidth; it's SRAM bandwidth (on A100, ~19 TB/s, 12× higher than HBM). The "memory ceiling" for Flash kernels is a higher line. Same FLOPs / fewer effective bytes (counting against the SRAM ceiling) = much higher intensity.
The cleanest one-liner: Flash trades HBM bandwidth for SRAM bandwidth. The total bytes moved per kernel might be similar, but the bytes that cross the slow boundary (HBM ↔ SRAM) are much fewer.
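The Flash byte count from this section can be checked the same way as the naive one. This is a sketch under the section's own assumptions (fp32 by default, Q read once, K and V re-read once per Q tile, O written once, m and ℓ ignored); `flash_hbm_bytes` is a hypothetical name.

```python
def flash_hbm_bytes(n: int, d: int, b_r: int, bytes_per_elem: int = 4) -> int:
    """HBM bytes for the tiled forward pass with Q tiles in the outer loop."""
    outer_iters = -(-n // b_r)            # ceil(N / B_r)
    q_reads = n * d                       # each Q tile is loaded once
    kv_reads = 2 * n * d * outer_iters    # K and V are streamed once per Q tile
    o_writes = n * d                      # each O tile is stored once
    return (q_reads + kv_reads + o_writes) * bytes_per_elem


for b_r in (64, 128):
    mib = flash_hbm_bytes(2048, 64, b_r) / 2**20
    print(f"B_r={b_r}: {mib:.0f} MiB")
```

The K, V term dominates, so HBM traffic falls roughly in proportion to 1/B_r: doubling the Q tile height nearly halves the bytes, as long as the larger tile still fits in SRAM.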
Re-stating against the Phase 1 roofline
From docs/phase-01-hardware-substrate/theory/03-roofline-model.md, the roofline equation is perf = min(π, I × β). Two regimes: memory-bound (below I_crit = π/β) and compute-bound (above).
For naive attention on A100: I ≈ 16 FLOPs/byte, far below I_crit ≈ 200. Performance ceiling: 16 × 1.55 TB/s = 25 TFLOPS. Out of 312 peak — 8% utilization.
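For Flash on A100: HBM bytes moved drops. The fp32, B_r = 64 count above gives 1.09e9 / 3.46e7 ≈ 32 FLOPs/byte; because K, V re-reads dominate the traffic, intensity scales roughly as 2 × B_r / (bytes per element), so fp16 operands with B_r ≥ 128 give I_effective ≈ 100+ FLOPs/byte when only HBM traffic is counted. Performance ceiling: 100 × 1.55 TB/s = 155 TFLOPS. Half of peak.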
The dot moved 6× up the slope. That's the "3× faster" claim, re-derived from first principles. (It's 6× on the roofline, but in practice the realized speedup is smaller because the kernel can't perfectly saturate SRAM bandwidth and has other overheads.)
This is the roofline argument Borja should have at fingertips. Quantization (Phase 26) cuts bytes by reducing per-element size. Flash cuts bytes by avoiding intermediate materialization. Both push the dot up.
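The two ceilings quoted in this section follow directly from perf = min(π, I × β). A minimal sketch, using the A100 figures from the text as defaults (312 TFLOPS peak, 1.55 TB/s HBM bandwidth); the function name is hypothetical.

```python
def roofline_tflops(intensity: float, peak_tflops: float = 312.0,
                    bandwidth_tb_s: float = 1.55) -> float:
    """Attainable TFLOPS = min(pi, I * beta); (FLOPs/byte) * (TB/s) = TFLOPS."""
    return min(peak_tflops, intensity * bandwidth_tb_s)


i_crit = 312.0 / 1.55   # ridge point: below it memory-bound, above it compute-bound
print(f"I_crit = {i_crit:.0f} FLOPs/byte")
for label, intensity in [("naive", 16.0), ("flash", 100.0)]:
    ceiling = roofline_tflops(intensity)
    print(f"{label}: {ceiling:.1f} TFLOPS ({ceiling / 312.0:.0%} of peak)")
```

Any intensity above I_crit returns the flat 312 TFLOPS ceiling, which is why pushing intensity further past the ridge point buys nothing on this model.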
Why Flash is exact, not approximate
A common misconception: "Flash is an approximation because it processes tiles." False.
The online softmax recurrence (theory 01) is an identity, not an approximation. Each tile-by-tile update produces the same final O/ℓ as the all-at-once softmax up to floating-point round-off. The round-off is no worse than naive — in fact often slightly better, because Flash's running rescaling tends to keep numbers in a tight range.
Empirically: O_flash - O_naive has max abs error ~1e-6 at fp32, ~1e-3 at fp16, on standard test inputs. This is the same order as the round-off naive attention itself accumulates.
This is what the DoD threshold (1e-3 at fp16) checks in lab 02.
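The identity behind the exactness claim can be written out for a single query row. The notation here is an assumption introduced for this derivation: A is the set of key indices already processed, B is the set in the next tile, s_k are that row's scores, v_k are rows of V, and the state before the update is m = max over A of s_k, ℓ = Σ over A of exp(s_k − m), o = Σ over A of exp(s_k − m) v_k.

```latex
% Assumed notation: A = keys seen so far, B = keys in the next tile.
\begin{aligned}
m'    &= \max\bigl(m,\ \max_{k \in B} s_k\bigr), \\
\ell' &= e^{m - m'}\,\ell + \sum_{k \in B} e^{s_k - m'}
       = \sum_{k \in A} e^{s_k - m'} + \sum_{k \in B} e^{s_k - m'}
       = \sum_{k \in A \cup B} e^{s_k - m'}, \\
o'    &= e^{m - m'}\,o + \sum_{k \in B} e^{s_k - m'}\, v_k
       = \sum_{k \in A \cup B} e^{s_k - m'}\, v_k, \\
\frac{o'}{\ell'} &= \sum_{k \in A \cup B}
       \frac{e^{s_k}}{\sum_{k' \in A \cup B} e^{s_{k'}}}\; v_k .
\end{aligned}
```

The factor exp(m − m′) converts every old term from the old max to the new one, so after the last tile o/ℓ is the ordinary softmax-weighted sum over all keys. No term is dropped or estimated; only the order of floating-point operations differs.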
What Flash forward doesn't do
- Doesn't help training memory on its own. Storing S was a memory cost (N² per layer per batch). Flash avoids that storage, but the backward pass then has to recompute S from Q, K, trading memory for FLOPs. Out of scope for this phase (forward only).
- Doesn't accelerate softmax-free attention. Linear attention, kernel attention, etc., don't have a softmax, so Flash's mechanism doesn't directly apply.
- Doesn't help short sequences. For N ≤ 256, the (N,N) matrix fits in SRAM trivially. Flash's tiling overhead can outweigh the win. The PyTorch heuristic to use Flash only for N ≥ 512 reflects this.
- Doesn't optimize attention with very large head dim. For d=128, the per-tile SRAM budget gets tight; B_r, B_c must shrink, reducing arithmetic intensity within tiles. Flash 2 (the follow-up paper) re-balances the tile dims for large d.
A note on Flash 1 vs Flash 2
The original Flash Attention paper (Dao et al., 2022) had an outer loop over K, V tiles and an inner loop over Q tiles, the opposite of what we wrote above. Flash 2 (Dao, 2023) swapped these: with Q outer, each output tile stays on-chip until it is complete, and the Q tiles can be processed in parallel across the sequence. Flash 2 also reduces non-matmul FLOPs.
For Phase 27 we implement Flash 2 (Q-outer), but the algebraic content is identical. The kernel code differs only in which loop is outer. Lab 02 specifies Flash 2.
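To see that loop order does not change the algebra, here is a minimal NumPy sketch with the K, V loop on the outside. It is a simplified illustration, not the original paper's exact bookkeeping: `flash_kv_outer` is a hypothetical name, the accumulator is kept unnormalized until the end, and the full-size `o`, `m`, `l` arrays stand in for state that a K-outer kernel would have to keep in HBM between outer iterations.

```python
import numpy as np


def flash_kv_outer(q, k, v, b_r=64, b_c=64):
    """Same online-softmax update as the Q-outer loop, with K/V tiles outermost."""
    n, d = q.shape
    # Per-row state must survive across outer iterations, so it is full-size.
    o = np.zeros((n, d))
    m = np.full(n, -np.inf)
    l = np.zeros(n)
    for j in range(0, n, b_c):                 # outer loop: K, V tiles
        k_j, v_j = k[j:j + b_c], v[j:j + b_c]
        for i in range(0, n, b_r):             # inner loop: Q tiles
            rows = slice(i, i + b_r)
            s_ij = q[rows] @ k_j.T / np.sqrt(d)
            m_new = np.maximum(m[rows], s_ij.max(axis=1))
            alpha = np.exp(m[rows] - m_new)
            p_ij = np.exp(s_ij - m_new[:, None])
            # Read-modify-write of the per-row state on every inner step.
            l[rows] = alpha * l[rows] + p_ij.sum(axis=1)
            o[rows] = alpha[:, None] * o[rows] + p_ij @ v_j
            m[rows] = m_new
    return o / l[:, None]


rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((192, 32)) for _ in range(3))
s = q @ k.T / np.sqrt(32)
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ v
out = flash_kv_outer(q, k, v)
print("max abs difference:", np.abs(out - ref).max())
```

The read-modify-write of `o`, `m`, `l` inside the inner loop is the visible cost of this ordering: with Q outer, that state lives in one tile-sized accumulator instead.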
Drill problems
Solutions are published at phase open in solutions/02-flash-attention-ref.md. Reason, don't run.
- Compute the HBM bytes moved by Flash for N=8192, d=128, B_r=64, B_c=64 in fp16. Compare to naive's HBM bytes. State the speedup ratio (purely from bytes).
- The SRAM budget on Hopper H100 is ~228 KB per SM. Pick B_r, B_c for d=128 fp16 such that 4 tiles (Q, K, V, S) fit. What's the maximum B_r × B_c you can afford?
- Sliding-window attention with window W=512 on a sequence of N=8192. How does the inner loop change? How many tiles of K, V does each Q tile need to read? Compare to dense Flash.
- Show that for B_c → 1 (one K/V row per inner tile), Flash degenerates into computing softmax serially over N terms with running max/sum. Why is this not useful? (Hint: think tensor-core utilization.)
One-paragraph recap
Flash Attention tiles Q, K, V into SRAM-resident blocks and uses the online softmax recurrence to avoid materializing the (N, N) matrix S = QKᵀ in HBM. The FLOPs are identical to naive attention. The bytes that cross the slow HBM↔SRAM boundary are 2–10× fewer (depending on tile sizes and head dim), and the working set per tile lives on a much faster (SRAM) bandwidth ceiling. On the Phase 1 roofline, this translates to a dot that's much closer to the compute ceiling — the source of Flash's 3–10× wall-clock speedup. The algorithm is exact up to floating-point round-off; it's not an approximation. The next theory file extends the byte-count framing to PagedAttention (a different layer of the stack — KV cache, not the kernel itself).
Next: theory/03-paged-and-sliding.md.