
02 — Flash Attention as a Roofline Optimization

Flash Attention is Phase 1 in action. It does the same number of FLOPs as naive attention but moves fewer bytes to and from HBM, because the matrix S = QKᵀ is never materialized. Arithmetic intensity rises, and the point on the roofline moves toward the compute ceiling. It isn't a trick: it's algebra (the online recurrence) plus a tile layout.

This file is the centrepiece of Phase 27. Read it once, then re-read it with theory/01-online-softmax.md open in another tab. By the end you should be able to (a) draw the tile-by-tile execution, (b) symbolically derive the byte-count delta, and (c) state the roofline shift.


The naive attention algorithm

For one head with Q, K, V ∈ ℝ^{N × d}, output O ∈ ℝ^{N × d}:

1. S = Q @ K^T              # (N, N) — materialized in HBM
2. P = softmax_rowwise(S)   # (N, N) — materialized in HBM
3. O = P @ V                # (N, d)

HBM traffic accounting (read = R, write = W; ignore O since it's the same in both algorithms):

  • Step 1: R(Q) + R(K) + W(S) = Nd + Nd + N² reads/writes (fp32 → ×4 bytes).
  • Step 2: R(S) + W(P) = N² + N² = 2N².
  • Step 3: R(P) + R(V) = N² + Nd.

Total: Nd × 3 + N² × 4 fp32 elements = (12 Nd + 16 N²) bytes.
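The three steps and their traffic can be checked with a small NumPy sketch. `naive_attention_with_traffic` is a hypothetical helper. It assumes that every named array (Q, K, V, S, P) lives in HBM, that each step reads its inputs once and writes its output once, and that temporaries inside a step stay on chip. Like the algorithm above, it omits the 1/√d scale.

```python
import numpy as np

def naive_attention_with_traffic(Q, K, V):
    """Naive attention that tallies HBM element traffic per step.

    Assumption: Q, K, V, S, P all live in HBM; temporaries inside a step
    (the row max, the exponentials) are treated as on-chip and not counted.
    The final write of O is excluded, as in the accounting above.
    """
    traffic = 0
    # Step 1: read Q and K, write the (N, N) score matrix S.
    S = Q @ K.T
    traffic += Q.size + K.size + S.size
    # Step 2: read S, write P (row-wise softmax, max-subtracted for stability).
    E = np.exp(S - S.max(axis=1, keepdims=True))
    P = E / E.sum(axis=1, keepdims=True)
    traffic += S.size + P.size
    # Step 3: read P and V; O's write is not counted.
    O = P @ V
    traffic += P.size + V.size
    return O, traffic

rng = np.random.default_rng(0)
N, d = 256, 32
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, elems = naive_attention_with_traffic(Q, K, V)
print(elems == 3 * N * d + 4 * N * N)  # → True
```

Multiplying `elems` by 4 gives the fp32 byte count 12 Nd + 16 N².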

For N=2048, d=64: 12 × 2048 × 64 + 16 × 2048² = 1.5 MiB + 64 MiB = 65.5 MiB (about 6.87 × 10⁷ bytes).

FLOPs: 2 × N² × d (Q@K^T) + 5 × N² (softmax) + 2 × N² × d (P@V) = 4 N² d + 5 N². For our numbers: 4 × 2048² × 64 + 5 × 2048² = 1.07 GF + 21 MF = 1.09 GFLOPs.

Intensity: 1.09e9 / 6.87e7 ≈ 15.9 FLOPs/byte. For an A100 with I_crit ≈ 200, this is about 13× below the corner. Memory-bound.
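The same arithmetic fits in a few lines of Python, which makes it easy to re-run for other (N, d) pairs. `naive_roofline_numbers` is a hypothetical helper, and the 5 FLOPs per element for softmax is the text's estimate.

```python
def naive_roofline_numbers(N, d, bytes_per_elem=4):
    """HBM bytes, FLOPs and arithmetic intensity for naive attention (one head)."""
    elems = 3 * N * d + 4 * N * N          # element traffic from steps 1 to 3
    hbm_bytes = bytes_per_elem * elems      # 12 Nd + 16 N^2 bytes for fp32
    flops = 4 * N * N * d + 5 * N * N       # two matmuls plus ~5 FLOPs/element of softmax
    return hbm_bytes, flops, flops / hbm_bytes

hbm_bytes, flops, intensity = naive_roofline_numbers(2048, 64)
MiB = 2 ** 20
print(f"{hbm_bytes / MiB:.1f} MiB, {flops / 1e9:.2f} GFLOPs, {intensity:.1f} FLOPs/byte")
# → 65.5 MiB, 1.09 GFLOPs, 15.9 FLOPs/byte
```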

The Flash algorithm (forward)

Pick tile sizes:

  • B_r: rows of Q per outer tile.
  • B_c: rows of K (and V) per inner tile.

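Constraint: one tile's worth of intermediate state must fit in SRAM. The per-tile state is S_tile ∈ ℝ^{B_r × B_c} (the partial inner-product block), plus the vectors m, ℓ ∈ ℝ^{B_r} (one scalar each per query row). For an SRAM budget M_sram, the inequality below counts S_tile and the Q, K, V tiles. It leaves out the B_r × d accumulator O_i and the small m, ℓ vectors, implicitly assuming they are held in registers: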

\[ B_r \times B_c \times 4 \, (\text{fp32}) + B_r \times d \times 4 \, (\text{Q tile}) + B_c \times d \times 4 \, (\text{K tile}) + B_c \times d \times 4 \, (\text{V tile}) \leq M_{\text{sram}} \]

For d=64, M_sram=64 KiB, B_r = B_c = 64 works: 64×64×4 + 3×64×64×4 = 16 KiB + 48 KiB = 64 KiB. ✓
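A minimal sketch of that inequality as a function. `tile_sram_bytes` is a hypothetical helper, and, like the inequality, it does not count the O_i accumulator or the m, ℓ vectors.

```python
def tile_sram_bytes(B_r, B_c, d, bytes_per_elem=4):
    """SRAM bytes for one (i, j) step: the S tile plus the Q, K, V tiles."""
    s_tile = B_r * B_c          # partial inner-product block
    q_tile = B_r * d
    kv_tiles = 2 * B_c * d      # K_j and V_j
    return bytes_per_elem * (s_tile + q_tile + kv_tiles)

KiB = 1024
M_sram = 64 * KiB  # the budget assumed in the worked example
print(tile_sram_bytes(64, 64, 64) // KiB)        # → 64
print(tile_sram_bytes(64, 64, 64) <= M_sram)     # → True
print(tile_sram_bytes(128, 64, 64) <= M_sram)    # → False
```

Doubling B_r to 128 overflows the budget, so with these assumptions B_r = B_c = 64 is the largest power-of-two square tile for d = 64.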

The algorithm:

for i = 0 .. (N / B_r) - 1:                # outer loop: Q tiles
    Q_i = Q[i*B_r:(i+1)*B_r, :]           # (B_r, d), load to SRAM
    O_i = zeros(B_r, d)                    # accumulator in SRAM
    m_i = full(B_r, -inf)                  # running max per row
    ℓ_i = zeros(B_r)                       # running sum per row
    for j = 0 .. (N / B_c) - 1:           # inner loop: K, V tiles
        K_j = K[j*B_c:(j+1)*B_c, :]       # (B_c, d), load to SRAM
        V_j = V[j*B_c:(j+1)*B_c, :]       # (B_c, d), load to SRAM
        S_ij = Q_i @ K_j^T / sqrt(d)      # (B_r, B_c), in SRAM
        m_new = max(m_i, rowmax(S_ij))
        α = exp(m_i - m_new)               # (B_r,)
        P_ij = exp(S_ij - m_new[:, None])  # (B_r, B_c)
        ℓ_i = α * ℓ_i + rowsum(P_ij)
        O_i = α[:, None] * O_i + P_ij @ V_j   # the online update
        m_i = m_new
    O[i*B_r:(i+1)*B_r, :] = O_i / ℓ_i[:, None]   # normalize, write to HBM

The online softmax recurrence from theory 01 is exactly what's inside the inner loop. The tiling is what wraps it.
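The pseudocode can be transcribed into runnable NumPy as a sketch. This version makes several assumptions: tile sizes are small so it runs quickly on a CPU, `flash_forward` and `naive_attention` are hypothetical names, and N is assumed to be a multiple of both tile sizes. It mirrors only the control flow and the recurrence. The local arrays stand in for SRAM and say nothing about real on-chip placement.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materialize S and P, with the same 1/sqrt(d) scale as below."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def flash_forward(Q, K, V, B_r=16, B_c=16):
    """Q-outer tiled attention with the online softmax update (CPU sketch)."""
    N, d = Q.shape
    assert N % B_r == 0 and N % B_c == 0, "sketch assumes tiles divide N"
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(Q)
    for i in range(N // B_r):                      # outer loop: Q tiles
        rows = slice(i * B_r, (i + 1) * B_r)
        Q_i = Q[rows]                              # (B_r, d)
        O_i = np.zeros((B_r, d))                   # unnormalized accumulator
        m_i = np.full(B_r, -np.inf)                # running max per row
        l_i = np.zeros(B_r)                        # running sum per row
        for j in range(N // B_c):                  # inner loop: K, V tiles
            cols = slice(j * B_c, (j + 1) * B_c)
            K_j, V_j = K[cols], V[cols]            # (B_c, d) each
            S_ij = Q_i @ K_j.T * scale             # (B_r, B_c), never stored globally
            m_new = np.maximum(m_i, S_ij.max(axis=1))
            alpha = np.exp(m_i - m_new)            # rescales old state; 0 on the first step
            P_ij = np.exp(S_ij - m_new[:, None])
            l_i = alpha * l_i + P_ij.sum(axis=1)
            O_i = alpha[:, None] * O_i + P_ij @ V_j
            m_i = m_new
        O[rows] = O_i / l_i[:, None]               # normalize once, write the tile
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 8)) for _ in range(3))
print(np.max(np.abs(flash_forward(Q, K, V) - naive_attention(Q, K, V))))
```

The printed difference should sit at float64 round-off level for any tile sizes that divide N.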

Bytes moved by Flash

The big change: S never crosses HBM. It lives only in SRAM, computed and consumed per (i, j) inner step.

HBM traffic:

  • Q is read once per outer tile, total Nd elements.
  • K, V are each read N / B_r times (once per outer iteration), total 2 × Nd × N/B_r elements.
  • O is written once at the end of each outer iteration, total Nd.
  • m, ℓ are negligible (O(N) total, not O(N²)).

Total: Nd × (2 + 2N/B_r) fp32 elements = (8 Nd × (1 + N/B_r)) bytes.

For N=2048, d=64, B_r=64: 8 × 2048 × 64 × (1 + 32) = 8 × 2048 × 64 × 33 ≈ 33 MiB.

Compare to naive's 65.5 MiB: Flash moves about half as many bytes.
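This sketch compares the two byte counts. The helpers are hypothetical and default to fp32. The last line shows the lever that matters: the K, V re-read term scales as N/B_r, so doubling B_r nearly halves Flash's traffic.

```python
def naive_bytes(N, d, bytes_per_elem=4):
    """HBM bytes for naive attention: 3Nd + 4N^2 elements."""
    return bytes_per_elem * (3 * N * d + 4 * N * N)

def flash_bytes(N, d, B_r, bytes_per_elem=4):
    """HBM bytes for Flash: Q read once, O written once,
    K and V each re-read once per Q tile (N / B_r times)."""
    return bytes_per_elem * N * d * (2 + 2 * (N // B_r))

MiB = 2 ** 20
N, d = 2048, 64
print(flash_bytes(N, d, 64) / MiB)                           # → 33.0
print(round(naive_bytes(N, d) / flash_bytes(N, d, 64), 2))   # → 1.98
print(flash_bytes(N, d, 128) / MiB)                          # → 17.0
```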

Wait, only 2×? The commonly quoted "3× faster" figure implies more.

Two answers:

  1. The byte accounting here is generous to naive. A real PyTorch implementation also computes S in fp32 even when Q and K are fp16 (for softmax stability), then casts back. The real byte count is closer to ~24 N² (1.5× our 16 N² estimate).
  2. The roofline picture matters more than the byte count. Even if bytes were equal, Flash's tiles fit in SRAM, so the ceiling that governs its inner loop isn't HBM bandwidth but SRAM bandwidth (on A100, ~19 TB/s, about 12× higher than HBM). The "memory ceiling" for Flash kernels is therefore a higher line. Dividing the same FLOPs by fewer bytes charged against the slow HBM ceiling gives a much higher effective intensity.

The cleanest one-liner: Flash trades HBM bandwidth for SRAM bandwidth. The total bytes moved per kernel might be similar, but the bytes that cross the slow boundary (HBM ↔ SRAM) are much fewer.

Re-stating against the Phase 1 roofline

From docs/phase-01-hardware-substrate/theory/03-roofline-model.md, the roofline equation is perf = min(π, I × β). Two regimes: memory-bound (below I_crit = π/β) and compute-bound (above).

For naive attention on A100: I ≈ 16 FLOPs/byte, far below I_crit ≈ 200. Performance ceiling: 16 × 1.55 TB/s = 25 TFLOPS. Out of 312 peak — 8% utilization.

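For Flash on A100, the HBM byte count derived above (fp32, B_r = 64) gives I = 1.09e9 / 3.46e7 ≈ 32 FLOPs/byte, already 2× the naive value. The K, V term dominates and scales as 1/B_r, so larger tiles and 2-byte (fp16) storage cut HBM bytes further. That is how I_effective reaches ≈ 100+ FLOPs/byte. For example, fp16 with B_r = 128 moves about 8.9e6 bytes, giving I ≈ 120. Performance ceiling at 100 FLOPs/byte: 100 × 1.55 TB/s = 155 TFLOPS, half of peak.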

The dot moved 6× up the slope. That's the "3× faster" claim, re-derived from first principles. (It's 6× on the roofline, but in practice the realized speedup is smaller because the kernel can't perfectly saturate SRAM bandwidth and has other overheads.)
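The roofline equation from Phase 1 is a one-liner, shown here as a sketch. `roofline` is a hypothetical helper. The 312 TFLOPS and 1.55 TB/s values are the A100 figures used above, and the Flash intensity of 100 FLOPs/byte is the working figure from the text, not something the function derives.

```python
def roofline(intensity, peak_tflops, bw_tb_s):
    """Attainable TFLOPS = min(pi, I * beta); I in FLOPs/byte, beta in TB/s."""
    return min(peak_tflops, intensity * bw_tb_s)

PEAK, HBM_BW = 312.0, 1.55          # A100 figures used in the text
I_crit = PEAK / HBM_BW              # corner of the roofline
naive = roofline(16, PEAK, HBM_BW)  # memory-bound
flash = roofline(100, PEAK, HBM_BW) # still memory-bound, but much higher
print(f"I_crit={I_crit:.0f}  naive={naive:.1f}  flash={flash:.1f}  ratio={flash / naive:.2f}")
# → I_crit=201  naive=24.8  flash=155.0  ratio=6.25
```

The 6.25 ratio is the "6× up the slope". Any intensity above I_crit is clipped to the 312 TFLOPS peak.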

This is the roofline argument Borja should have at hand. Quantization (Phase 26) cuts bytes by reducing per-element size. Flash cuts bytes by avoiding intermediate materialization. Both push the dot up.

Why Flash is exact, not approximate

A common misconception: "Flash is an approximation because it processes tiles." False.

The online softmax recurrence (theory 01) is an identity, not an approximation. Each tile-by-tile update produces the same final O/ℓ as the all-at-once softmax up to floating-point round-off. The round-off is no worse than naive — in fact often slightly better, because Flash's running rescaling tends to keep numbers in a tight range.

Empirically: O_flash - O_naive has max abs error ~1e-6 at fp32 and ~1e-3 at fp16 on standard test inputs. This is the same order as the round-off that naive attention itself accumulates.

This is what the DoD threshold (1e-3 at fp16) checks in lab 02.
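The identity is easiest to see on a single query row. The following is a pure-Python sketch in float64, and its function names are hypothetical. The streaming version keeps a running max m, a running sum ℓ (spelled `l`), and an unnormalized accumulator, as the inner loop does. For every chunk size, it agrees with the all-at-once softmax up to round-off.

```python
import math
import random

def softmax_weighted_sum(scores, values):
    """All at once: sum_k softmax(scores)_k * values_k."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wk * vk for wk, vk in zip(w, values)) / sum(w)

def online_softmax_weighted_sum(scores, values, chunk):
    """Chunk by chunk, with running max m, running sum l, unnormalized acc o."""
    m, l, o = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), chunk):
        s_blk = scores[start:start + chunk]
        v_blk = values[start:start + chunk]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)        # exp(-inf) == 0.0 on the first chunk
        p = [math.exp(s - m_new) for s in s_blk]
        l = alpha * l + sum(p)
        o = alpha * o + sum(pk * vk for pk, vk in zip(p, v_blk))
        m = m_new
    return o / l                           # normalize once at the end

random.seed(0)
scores = [random.gauss(0, 3) for _ in range(1000)]
values = [random.gauss(0, 1) for _ in range(1000)]
full = softmax_weighted_sum(scores, values)
for chunk in (7, 64, 1000):
    print(chunk, abs(online_softmax_weighted_sum(scores, values, chunk) - full))
```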

What Flash forward doesn't do

  1. Doesn't, on its own, deliver the training-memory win. Naive training stores S for the backward pass, an O(N²) memory cost per layer per batch. Flash avoids that storage, but then the backward pass must recompute S from Q and K, trading memory for extra FLOPs. That backward kernel is out of scope for this phase (forward only).
  2. Doesn't accelerate softmax-free attention. Linear attention, kernel attention, etc., don't have a softmax — Flash's mechanism doesn't directly apply.
  3. Doesn't help short sequences. For N ≤ 256 the (N, N) matrix is small (at N = 256 it is 256 KiB in fp32), so materializing it in HBM costs little, and Flash's tiling overhead can outweigh the win. The PyTorch heuristic to use Flash only for N ≥ 512 reflects this.
  4. Doesn't optimize attention with very large head dim. For d=128, the per-tile SRAM budget gets tight; B_r, B_c must shrink, reducing arithmetic intensity within tiles. Flash 2 (the follow-up paper) re-balances the tile dims for large d.
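Item 4 can be made concrete with the SRAM inequality from earlier. This sketch uses simplifying assumptions: square power-of-two tiles (B_r = B_c = B), fp32 elements, and the same 64 KiB budget as the worked example. The helper names are hypothetical.

```python
def tile_sram_bytes(B, d, bytes_per_elem=4):
    """S tile (B x B) plus Q, K, V tiles (B x d each) for square tiles."""
    return bytes_per_elem * (B * B + 3 * B * d)

def largest_pow2_tile(d, budget_bytes, bytes_per_elem=4):
    """Largest power-of-two B whose tiles fit, assuming B = 1 fits."""
    B = 1
    while tile_sram_bytes(2 * B, d, bytes_per_elem) <= budget_bytes:
        B *= 2
    return B

budget = 64 * 1024   # same assumed 64 KiB budget as the worked example
for d in (64, 128):
    print(d, largest_pow2_tile(d, budget))
# → 64 64
# → 128 32
```

Doubling d halves the affordable tile edge here, which is the "B_r, B_c must shrink" effect.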

A note on Flash 1 vs Flash 2

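The original Flash Attention paper (Dao et al., 2022) had an outer loop over K, V tiles and an inner loop over Q tiles, the opposite of what we wrote above. Flash 2 (Dao, 2023) swapped them. With Q outer, each Q tile's O_i, m_i and ℓ_i stay on chip for the whole inner loop and are written to HBM once, and different Q tiles can run in parallel thread blocks. Flash 2 also reduced non-matmul FLOPs, for example by deferring the division by ℓ to the end of the inner loop, as the pseudocode above does.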

For Phase 27 we implement Flash 2 (Q-outer), but the algebraic content is identical. The kernel code differs only in which loop is outer. Lab 02 specifies Flash 2.
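The loop order changes only where the state lives, not the result. This NumPy sketch shows both orders sharing one online-softmax update. It assumes square tiles of size B that divide N, and the function names are hypothetical. Both variants also defer the division by ℓ to the end, like the pseudocode above, whereas the original Flash 1 algorithm renormalized O at every step.

```python
import numpy as np

def _tile_update(S, V_j, m, l, O):
    """One online-softmax step for a block of query rows; returns new (m, l, O)."""
    m_new = np.maximum(m, S.max(axis=1))
    alpha = np.exp(m - m_new)                 # 0 on the first step (m = -inf)
    P = np.exp(S - m_new[:, None])
    return m_new, alpha * l + P.sum(axis=1), alpha[:, None] * O + P @ V_j

def attention_q_outer(Q, K, V, B):
    """Flash 2 order: a Q tile's (m, l, O) stays local for the whole inner loop."""
    N, d = Q.shape
    out = np.empty_like(Q)
    for r in range(0, N, B):                  # outer: Q tiles
        m, l, O = np.full(B, -np.inf), np.zeros(B), np.zeros((B, d))
        for c in range(0, N, B):              # inner: K, V tiles
            S = Q[r:r + B] @ K[c:c + B].T / np.sqrt(d)
            m, l, O = _tile_update(S, V[c:c + B], m, l, O)
        out[r:r + B] = O / l[:, None]         # written once per Q tile
    return out

def attention_kv_outer(Q, K, V, B):
    """Flash 1 order: state for all N rows is revisited once per K, V tile.
    In a real kernel that state would travel to and from HBM each outer step."""
    N, d = Q.shape
    m, l, O = np.full(N, -np.inf), np.zeros(N), np.zeros((N, d))
    for c in range(0, N, B):                  # outer: K, V tiles
        for r in range(0, N, B):              # inner: Q tiles
            rows = slice(r, r + B)
            S = Q[rows] @ K[c:c + B].T / np.sqrt(d)
            m[rows], l[rows], O[rows] = _tile_update(
                S, V[c:c + B], m[rows], l[rows], O[rows])
    return O / l[:, None]

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((64, 8)) for _ in range(3))
print(np.allclose(attention_q_outer(Q, K, V, 16), attention_kv_outer(Q, K, V, 16)))
```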

Drill problems

Solutions at phase open in solutions/02-flash-attention-ref.md. Reason, don't run.

  1. Compute the HBM bytes moved by Flash for N=8192, d=128, B_r=64, B_c=64 in fp16. Compare to naive's HBM bytes. State the speedup ratio (purely from bytes).
  2. The SRAM budget on Hopper H100 is ~228 KB per SM. Pick B_r, B_c for d=128 fp16 such that 4 tiles (Q, K, V, S) fit. What's the maximum B_r × B_c you can afford?
  3. Sliding-window attention with window W=512 on a sequence of N=8192. How does the inner loop change? How many tiles of K, V does each Q tile need to read? Compare to dense Flash.
  4. Show that for B_c → 1 (one K/V row per inner tile), Flash degenerates into computing softmax serially over N terms with running max/sum. Why is this not useful? (Hint: think tensor-core utilization.)

One-paragraph recap

Flash Attention tiles Q, K, V into SRAM-resident blocks and uses the online softmax recurrence to avoid materializing the (N, N) matrix S = QKᵀ in HBM. The FLOPs are identical to naive attention. The bytes that cross the slow HBM↔SRAM boundary are 2–10× fewer (depending on tile sizes and head dim), and the working set per tile lives on a much faster (SRAM) bandwidth ceiling. On the Phase 1 roofline, this translates to a dot that's much closer to the compute ceiling — the source of Flash's 3–10× wall-clock speedup. The algorithm is exact up to floating-point round-off; it's not an approximation. The next theory file extends the byte-count framing to PagedAttention (a different layer of the stack — KV cache, not the kernel itself).

Next: theory/03-paged-and-sliding.md.