Lab 01: The Bytes-Moved Delta of Flash vs Naive
Goal: derive symbolically and measure empirically the bytes moved by naive vs Flash attention. Plot bytes vs N. Compute the roofline intensity ratio.
Estimated time: 3–4 hours.
Prereq: theory 02 read; online softmax lab committed.
What you produce
A directory experiments/27-flash-vs-naive-bytes/ containing:
- derive_bytes.py: closed-form bytes-moved formulas (just printed; this is a worked-derivation script, not a heavy compute job).
- measure_bytes.py: empirical bytes-moved using a profiler or instrumented reference implementation.
- bytes_vs_n.png: log-log plot of bytes-moved vs N for both schemes.
- intensity_ratio.png: derived intensity ratio Flash/naive vs N.
- manifest.json.
- README.md.
TODOs
Block A: derive symbolic bytes-moved
Implement Python functions returning closed-form byte counts:
- bytes_naive(N, d, dtype_bytes=2) → returns (12*N*d + 16*N**2) * (dtype_bytes / 4) (the formula from theory 02; scale by dtype). Note: this counts HBM bytes and ignores the O write (same in both schemes).
- bytes_flash(N, d, B_r, B_c, dtype_bytes=2) → returns 8 * N * d * (1 + N / B_r) * (dtype_bytes / 4).
- Print both for several (N, d) combinations: (64, 64) (the verb-corpus sequence length; Flash's win should be ~zero here), (1024, 64), (2048, 64), (4096, 64), (8192, 128), (32768, 128).
- Note in the script output: at N=64, the materialized N × N score matrix of the naive scheme has N²=4096 entries and fits in 16 KiB at fp32, which is inside L1. Flash buys nothing. The whole point of running this for the verb sequence is to show that Flash is only a win once the score matrix (N · N · dtype_bytes) no longer fits in L2.
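As a starting point, the two formulas above can be sketched as follows. This is a minimal sketch, not a reference solution: the helper name report, the SHAPES constant, and the choice of B_r = B_c = 64 for printing are assumptions made for illustration (64 is the tile size used in Block D).

```python
def bytes_naive(N, d, dtype_bytes=2):
    """HBM bytes for naive attention, per the theory 02 formula (O write excluded)."""
    return (12 * N * d + 16 * N**2) * (dtype_bytes / 4)


def bytes_flash(N, d, B_r, B_c, dtype_bytes=2):
    """HBM bytes for tiled attention. B_c is accepted but unused by this formula."""
    return 8 * N * d * (1 + N / B_r) * (dtype_bytes / 4)


# (N, d) combinations listed in Block A.
SHAPES = [(64, 64), (1024, 64), (2048, 64), (4096, 64), (8192, 128), (32768, 128)]


def report(shapes=SHAPES, B_r=64, B_c=64, dtype_bytes=2):
    """Return one (N, d, naive_bytes, flash_bytes) row per shape."""
    return [
        (N, d, bytes_naive(N, d, dtype_bytes), bytes_flash(N, d, B_r, B_c, dtype_bytes))
        for N, d in shapes
    ]


if __name__ == "__main__":
    print(f"{'N':>6} {'d':>4} {'naive bytes':>16} {'flash bytes':>16}")
    for N, d, b_naive, b_flash in report():
        print(f"{N:>6} {d:>4} {b_naive:>16.0f} {b_flash:>16.0f}")
    # Footprint of the N x N score matrix at N=64 in fp32 (4 bytes per entry).
    print(f"N=64 score matrix at fp32: {64 * 64 * 4 / 1024:.0f} KiB")
```

Keeping the formulas as plain functions makes them easy to import from measure_bytes.py when comparing against measurements.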
Block B: symbolic intensity
- FLOPs = 4 * N * N * d (the dominant Q@K^T + P@V cost).
- Compute the arithmetic intensity (FLOPs per byte moved) for both schemes. Print as a table.
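The intensity table can be built directly from the FLOP count and the Block A byte formulas. The sketch below assumes intensity means FLOPs divided by HBM bytes; the function names flops and intensity_table are hypothetical, and the byte formulas are repeated only so the snippet runs on its own.

```python
def flops(N, d):
    # 2*N*N*d for Q@K^T plus 2*N*N*d for P@V.
    return 4 * N * N * d


def bytes_naive(N, d, dtype_bytes=2):
    # Same closed form as Block A.
    return (12 * N * d + 16 * N**2) * (dtype_bytes / 4)


def bytes_flash(N, d, B_r, B_c, dtype_bytes=2):
    # Same closed form as Block A; B_c does not appear in it.
    return 8 * N * d * (1 + N / B_r) * (dtype_bytes / 4)


def intensity_table(shapes, B_r=64, B_c=64, dtype_bytes=2):
    """Return rows of (N, d, naive FLOPs/byte, flash FLOPs/byte, flash/naive ratio)."""
    rows = []
    for N, d in shapes:
        f = flops(N, d)
        i_naive = f / bytes_naive(N, d, dtype_bytes)
        i_flash = f / bytes_flash(N, d, B_r, B_c, dtype_bytes)
        rows.append((N, d, i_naive, i_flash, i_flash / i_naive))
    return rows


if __name__ == "__main__":
    print(f"{'N':>6} {'d':>4} {'I_naive':>10} {'I_flash':>10} {'ratio':>8}")
    for N, d, i_n, i_f, r in intensity_table([(64, 64), (1024, 64), (2048, 64)]):
        print(f"{N:>6} {d:>4} {i_n:>10.2f} {i_f:>10.2f} {r:>8.3f}")
```

Because both schemes perform the same FLOPs under this model, the intensity ratio equals the inverse of the bytes ratio.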
Block C: measure empirically (CPU)
Since Borja's local hardware is CPU, don't try to measure on GPU here; that's lab 02's job. Use a CPU implementation:
- Implement attn_naive_cpu(Q, K, V) in PyTorch using explicit matmuls and softmax. Use torch.profiler to count bytes moved. Note that torch.profiler with profile_memory=True records tensor allocations rather than memory traffic, so treat its numbers as a proxy or instrument the implementation yourself.
- Implement attn_flash_reference_cpu(Q, K, V, B_r, B_c) in pure PyTorch (no Triton, just the tiling loop). Profile bytes.
- Compare measured bytes to symbolic predictions. They should agree within ~30% (the gap comes from Python overhead, memory layout, etc.).
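To make the tiling loop and the byte instrumentation concrete, here is a dependency-free sketch on plain Python lists. It is not the PyTorch implementation the lab asks for: the names attn_naive_py and attn_flash_reference_py are hypothetical stand-ins, and the counter only tallies Q, K and V tile loads from a modelled slow memory, which is one possible accounting convention and may differ from theory 02's.

```python
import math
import random


def attn_naive_py(Q, K, V):
    """Reference attention, softmax(Q K^T / sqrt(d)) V, one query row at a time."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z for c in range(d)])
    return out


def attn_flash_reference_py(Q, K, V, B_r, B_c, dtype_bytes=4):
    """Tiled attention with online softmax. Returns (O, bytes_loaded)."""
    N, d = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d)
    loaded = 0  # elements read from the modelled slow memory
    O = []
    for i0 in range(0, N, B_r):
        Qi = Q[i0:i0 + B_r]
        loaded += len(Qi) * d                  # Q tile is loaded once per row block
        m = [-math.inf] * len(Qi)              # running row maximum
        l = [0.0] * len(Qi)                    # running softmax denominator
        acc = [[0.0] * d for _ in Qi]          # running unnormalised output
        for j0 in range(0, N, B_c):
            Kj, Vj = K[j0:j0 + B_c], V[j0:j0 + B_c]
            loaded += 2 * len(Kj) * d          # K and V tiles are reloaded per row block
            for r, q in enumerate(Qi):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kj]
                m_new = max(m[r], max(s))
                alpha = math.exp(m[r] - m_new)  # rescales statistics from earlier tiles
                p = [math.exp(x - m_new) for x in s]
                l[r] = alpha * l[r] + sum(p)
                for c in range(d):
                    acc[r][c] = alpha * acc[r][c] + sum(
                        p[t] * Vj[t][c] for t in range(len(Kj))
                    )
                m[r] = m_new
        O.extend([[a / l[r] for a in acc[r]] for r in range(len(Qi))])
    return O, loaded * dtype_bytes


if __name__ == "__main__":
    random.seed(0)
    N, d = 8, 4
    Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(N)] for _ in range(3))
    O_tiled, nbytes = attn_flash_reference_py(Q, K, V, B_r=4, B_c=4)
    O_ref = attn_naive_py(Q, K, V)
    err = max(abs(a - b) for ra, rb in zip(O_tiled, O_ref) for a, b in zip(ra, rb))
    print(f"max abs error vs naive: {err:.2e}, bytes loaded: {nbytes}")
```

The same structure carries over to PyTorch by replacing the inner list comprehensions with tile-level matmuls and keeping the loaded counter next to each tile slice.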
Block D: plot
- Log-log plot, x-axis N from 256 to 16384 doubling, y-axis bytes moved. Two lines (naive, Flash with B_r=B_c=64).
- Annotate where Flash becomes "much better" than naive (typically N ≥ 1024).
Block E: roofline overlay (preview for lab 02)
- Compute the intensity for both at N=2048, d=64 on Borja's machine. Use Phase 1's measured roofline ceilings.
- Predict the speedup ratio. Save as a note in README.md to compare against lab 02's measured speedup on GPU.
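One way to turn the two intensities into a predicted speedup is the standard roofline bound, where runtime is limited by the slower of compute and memory traffic. The sketch below assumes that model; roofline_time and predicted_speedup are hypothetical names, and the ceilings in the usage section are placeholders to be replaced with Phase 1's measured values.

```python
def roofline_time(flops, bytes_moved, peak_flops, peak_bw):
    """Lower-bound runtime in seconds: the slower of the compute and memory terms."""
    return max(flops / peak_flops, bytes_moved / peak_bw)


def predicted_speedup(N, d, B_r, peak_flops, peak_bw, dtype_bytes=2):
    """Roofline-limited speedup of Flash over naive, using the Block A and B formulas."""
    f = 4 * N * N * d
    b_naive = (12 * N * d + 16 * N**2) * (dtype_bytes / 4)
    b_flash = 8 * N * d * (1 + N / B_r) * (dtype_bytes / 4)
    t_naive = roofline_time(f, b_naive, peak_flops, peak_bw)
    t_flash = roofline_time(f, b_flash, peak_flops, peak_bw)
    return t_naive / t_flash


if __name__ == "__main__":
    # PLACEHOLDER ceilings, not measurements: substitute Phase 1's roofline numbers.
    PEAK_FLOPS = 1e11  # FLOP/s (placeholder)
    PEAK_BW = 2e10     # bytes/s (placeholder)
    s = predicted_speedup(2048, 64, B_r=64, peak_flops=PEAK_FLOPS, peak_bw=PEAK_BW)
    print(f"predicted roofline speedup at N=2048, d=64: {s:.2f}x")
```

When both schemes are memory-bound the prediction reduces to the bytes ratio, and when both are compute-bound it collapses to 1, which is a useful sanity check before writing the number into README.md.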
Block F: interpret in README.md
Three questions:
- At what N does Flash start to win in HBM bytes-moved? Below some N, the tile-overhead may dominate. Plot should show a crossover.
- By the symbolic formula, what's the intensity ratio Flash/naive for N=8192, d=128, B_r=B_c=64? Show your work.
- The Flash paper claims 3× speedup on A100 at N=2048. By your intensity ratio, the theoretical roofline-limited speedup might be 5–10×. Why is the realized speedup lower? (Hint: SRAM bandwidth saturation, non-matmul FLOPs, kernel launch overhead.)
Stop conditions
- All six files committed.
- Symbolic and measured bytes agree to within 30%.
- Plots are committed and labelled.
- README answers all three Block F questions.
Pitfalls
- Bytes accounting includes O write. Both schemes write O once; this is the same. Subtract it out for clean comparison.
- fp16 vs fp32 confusion. Naive PyTorch's softmax internally casts to fp32 then back. The bytes-moved measurement should reflect this (fp32 intermediate dominates).
- Reference Flash on CPU is slow. That's OK; you're measuring bytes, not seconds. Use small N (≤ 512) if it's painfully slow.
Connection to lab 02
This lab gives you the prediction. Lab 02 (Triton kernel on GPU) gives you the measurement. Save your intensity ratio prediction here; check it there.
When to consult solutions/
After all six files are committed and your predictions agree with the measured intensity ratio to within 20%.
Next lab: lab/02-flash-triton.md.