Lab 01 — The Bytes-Moved Delta of Flash vs Naive

Goal: derive symbolically and measure empirically the bytes moved by naive vs Flash attention. Plot bytes vs N. Compute the roofline intensity ratio.

Estimated time: 3–4 hours.

Prereq: theory 02 read; online softmax lab committed.


What you produce

A directory experiments/27-flash-vs-naive-bytes/ containing:

  • derive_bytes.py — closed-form bytes-moved formulas (just printed; this is a worked-derivation script, not a heavy compute job).
  • measure_bytes.py — empirical bytes-moved using a profiler or instrumented reference implementation.
  • bytes_vs_n.png — log-log plot of bytes-moved vs N for both schemes.
  • intensity_ratio.png — derived intensity ratio Flash/naive vs N.
  • manifest.json.
  • README.md.

TODOs

Block A — derive symbolic bytes-moved

Implement Python functions returning closed-form byte counts:

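  • bytes_naive(N, d, dtype_bytes=2) → returns (12*N*d + 16*N**2) * (dtype_bytes / 4) (the formula from theory 02, written for 4-byte fp32 elements; the dtype_bytes / 4 factor rescales it to other dtypes). Note: this counts HBM bytes and ignores the O write, which is the same in both schemes.
  • bytes_flash(N, d, B_r, B_c, dtype_bytes=2) → returns 8 * N * d * (1 + N / B_r) * (dtype_bytes / 4). B_c is part of the signature but does not appear in this formula.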
  • Print both for several (N, d) combinations: (64, 64) (the verb-corpus sequence length — Flash's win should be ~zero here), (1024, 64), (2048, 64), (4096, 64), (8192, 128), (32768, 128).
  • Note in the script output: at N=64, the naive N²=4096 materialized matrix fits in 16 KiB at fp32, which is inside L1, so Flash buys nothing. The whole point of running this for the verb sequence is to show that Flash is only a win once the N × N score matrix (N · N · dtype_bytes) no longer fits in L2.
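A minimal sketch of derive_bytes.py, transcribing the two formulas above exactly as given. The `report` helper and the `SHAPES` name are illustrative assumptions, and the tile sizes default to B_r = B_c = 64 only for the printout.

```python
def bytes_naive(N, d, dtype_bytes=2):
    """HBM bytes for naive attention (Block A formula; O write excluded)."""
    return (12 * N * d + 16 * N**2) * (dtype_bytes / 4)


def bytes_flash(N, d, B_r, B_c, dtype_bytes=2):
    """HBM bytes for Flash attention (Block A formula; O write excluded).

    B_c is accepted to match the lab's signature but is unused by the formula.
    """
    return 8 * N * d * (1 + N / B_r) * (dtype_bytes / 4)


# (N, d) combinations listed in Block A.
SHAPES = [(64, 64), (1024, 64), (2048, 64), (4096, 64), (8192, 128), (32768, 128)]


def report(shapes=SHAPES, B_r=64, B_c=64, dtype_bytes=2):
    """Return one (N, d, naive_bytes, flash_bytes, naive/flash) row per shape."""
    rows = []
    for N, d in shapes:
        naive = bytes_naive(N, d, dtype_bytes)
        flash = bytes_flash(N, d, B_r, B_c, dtype_bytes)
        rows.append((N, d, naive, flash, naive / flash))
    return rows


if __name__ == "__main__":
    print(f"{'N':>6} {'d':>4} {'naive B':>14} {'flash B':>14} {'naive/flash':>12}")
    for N, d, naive, flash, ratio in report():
        print(f"{N:>6} {d:>4} {naive:>14.0f} {flash:>14.0f} {ratio:>12.3f}")
    # Cache note required by Block A: the fp32 score matrix at N=64.
    print(f"N=64: N*N*4 = {64 * 64 * 4} bytes (16 KiB), fits in L1; Flash buys nothing.")
```

The formulas count HBM traffic only, so the printed ratio at N=64 says nothing about cache residency; the final printed note carries that caveat.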

Block B — symbolic intensity

  • FLOPs = 4 * N * N * d (the dominant Q@K^T + P@V cost).
  • Compute the arithmetic intensity (FLOPs / bytes moved) for both schemes. Print the results as a table.
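Block B can be sketched as follows, assuming arithmetic intensity means FLOPs divided by bytes moved and reusing the Block A formulas verbatim. The helper names `flops_attention` and `intensity_table` are illustrative, not part of the lab spec, and `bytes_flash` drops the unused B_c argument here for brevity.

```python
def bytes_naive(N, d, dtype_bytes=2):
    # Block A formula for naive attention.
    return (12 * N * d + 16 * N**2) * (dtype_bytes / 4)


def bytes_flash(N, d, B_r, dtype_bytes=2):
    # Block A formula for Flash attention (B_c does not appear in it).
    return 8 * N * d * (1 + N / B_r) * (dtype_bytes / 4)


def flops_attention(N, d):
    # Dominant cost: Q@K^T and P@V, each 2*N*N*d FLOPs.
    return 4 * N * N * d


def intensity_table(shapes, B_r=64, dtype_bytes=2):
    """Return (N, d, I_naive, I_flash, I_flash / I_naive) for each shape."""
    rows = []
    for N, d in shapes:
        flops = flops_attention(N, d)
        i_naive = flops / bytes_naive(N, d, dtype_bytes)
        i_flash = flops / bytes_flash(N, d, B_r, dtype_bytes)
        rows.append((N, d, i_naive, i_flash, i_flash / i_naive))
    return rows


if __name__ == "__main__":
    print(f"{'N':>6} {'d':>4} {'I_naive':>9} {'I_flash':>9} {'ratio':>6}")
    for N, d, i_n, i_f, r in intensity_table([(1024, 64), (2048, 64), (4096, 64)]):
        print(f"{N:>6} {d:>4} {i_n:>9.2f} {i_f:>9.2f} {r:>6.2f}")
```

Because both schemes perform the same FLOPs, the intensity ratio is simply the inverse of the bytes ratio.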

Block C — measure empirically (CPU)

Since Borja's local hardware is CPU, don't try to measure on GPU here; that's lab 02's job. Use a CPU implementation:

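  • Implement attn_naive_cpu(Q, K, V) in PyTorch using explicit matmuls and softmax. Use torch.profiler to estimate bytes moved. Note that profile_memory=True reports tensor allocations rather than memory traffic, so treat it as a proxy or fall back to an instrumented reference implementation.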
  • Implement attn_flash_reference_cpu(Q, K, V, B_r, B_c) in pure PyTorch (no Triton — just the tiling loop). Profile bytes.
  • Compare measured bytes to the symbolic predictions. They should agree to within ~30% (the gap comes from Python overhead, memory layout, etc.).
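The tiling loop and its byte accounting can be illustrated without any dependencies. The sketch below is a plain-Python stand-in for the PyTorch version the lab asks for, not a replacement for it. It assumes a row-block-outer loop order, online-softmax rescaling per row, 4-byte elements by default, and a counter that tallies each Q, K and V tile load once (O write excluded). The names `attn_naive` and `attn_flash_reference` and the `dtype_bytes` parameter are hypothetical.

```python
import math


def attn_naive(Q, K, V):
    """Reference attention on lists of rows: softmax(Q K^T / sqrt(d)) V."""
    d, dv = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / z for c in range(dv)])
    return out


def attn_flash_reference(Q, K, V, B_r, B_c, dtype_bytes=4):
    """Tiled attention with online softmax; returns (output, bytes_loaded)."""
    N, d, dv = len(Q), len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    bytes_loaded = 0
    out = []
    for i0 in range(0, N, B_r):
        Qi = Q[i0:i0 + B_r]
        bytes_loaded += len(Qi) * d * dtype_bytes            # load one Q tile
        m = [-math.inf] * len(Qi)                            # running row max
        l = [0.0] * len(Qi)                                  # running row sum
        acc = [[0.0] * dv for _ in Qi]                       # unnormalized output
        for j0 in range(0, N, B_c):
            Kj, Vj = K[j0:j0 + B_c], V[j0:j0 + B_c]
            bytes_loaded += len(Kj) * (d + dv) * dtype_bytes  # load K and V tiles
            for r, q in enumerate(Qi):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kj]
                m_new = max(m[r], max(s))
                alpha = math.exp(m[r] - m_new)               # 0.0 on the first tile
                p = [math.exp(x - m_new) for x in s]
                l[r] = alpha * l[r] + sum(p)
                acc[r] = [alpha * acc[r][c] + sum(p[j] * Vj[j][c] for j in range(len(Vj)))
                          for c in range(dv)]
                m[r] = m_new
        out.extend([a / l[r] for a in acc[r]] for r in range(len(Qi)))
    return out, bytes_loaded
```

Under these assumptions the tally for square d = dv inputs works out to 4·N·d + 8·N²·d / B_r bytes at fp32 when the tiles divide N evenly, which differs from the Block A closed form only in the lower-order N·d term. On a real CPU, caches will absorb some of these loads, so the counter is an upper bound on memory traffic rather than a hardware measurement.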

Block D — plot

  • Log-log plot: x-axis N from 256 to 16384 in powers of two, y-axis bytes moved. Two lines (naive, and Flash with B_r=B_c=64).
  • Annotate where Flash becomes "much better" than naive (typically N ≥ 1024).

Block E — roofline overlay (preview for lab 02)

  • Compute the intensity for both at N=2048, d=64 on Borja's machine. Use Phase 1's measured roofline ceilings.
  • Predict the speedup ratio. Save as a note in README.md to compare against lab 02's measured speedup on GPU.
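One way to turn the two intensities into a predicted speedup is the standard roofline bound, attainable throughput = min(peak FLOP/s, intensity × memory bandwidth). The sketch below assumes that model. The ceiling values in the usage section are hypothetical placeholders to be replaced with the Phase 1 measurements, and the function names are illustrative.

```python
def attainable_flops(intensity, peak_flops, mem_bw):
    """Roofline bound: throughput is capped by compute or by memory traffic."""
    return min(peak_flops, intensity * mem_bw)


def predicted_speedup(i_naive, i_flash, peak_flops, mem_bw):
    """Both schemes do the same FLOPs, so speedup is the throughput ratio."""
    return (attainable_flops(i_flash, peak_flops, mem_bw)
            / attainable_flops(i_naive, peak_flops, mem_bw))


if __name__ == "__main__":
    # HYPOTHETICAL ceilings; substitute the Phase 1 measured values.
    PEAK_FLOPS = 1.0e11   # FLOP/s
    MEM_BW = 2.0e10       # bytes/s

    # Intensities from the Block A/B formulas at N=2048, d=64, B_r=64, fp16.
    N, d, B_r, dtype_bytes = 2048, 64, 64, 2
    flops = 4 * N * N * d
    i_naive = flops / ((12 * N * d + 16 * N**2) * (dtype_bytes / 4))
    i_flash = flops / (8 * N * d * (1 + N / B_r) * (dtype_bytes / 4))

    print(f"ridge point: {PEAK_FLOPS / MEM_BW:.2f} FLOP/byte")
    print(f"I_naive={i_naive:.2f}  I_flash={i_flash:.2f}")
    print(f"predicted speedup: {predicted_speedup(i_naive, i_flash, PEAK_FLOPS, MEM_BW):.2f}x")
```

The predicted speedup equals the intensity ratio only when both schemes sit below the ridge point (peak_flops / mem_bw). If both are above it, the model predicts no speedup at all, which is worth recording in README.md alongside the number.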

Block F — interpret in README.md

Three questions:

  1. At what N does Flash start to win in HBM bytes-moved? Below some N, the tile overhead may dominate; if it does, the plot should show a crossover.
  2. By the symbolic formula, what's the intensity ratio Flash/naive for N=8192, d=128, B_r=B_c=64? Show your work.
  3. The Flash paper claims 3× speedup on A100 at N=2048. By your intensity ratio, the theoretical roofline-limited speedup might be 5–10×. Why is the realized speedup lower? (Hint: SRAM bandwidth saturation, non-matmul FLOPs, kernel launch overhead.)

Stop conditions

  • All six files committed.
  • Symbolic and measured bytes agree to within 30%.
  • Plots are committed and labelled.
  • README answers all three Block F questions.

Pitfalls

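  • Bytes accounting includes the O write. Measured bytes include it, while the Block A formulas leave it out. Both schemes write O once, so the cost is identical; subtract it from the measurements for a clean comparison.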
  • fp16 vs fp32 confusion. Naive PyTorch's softmax internally casts to fp32 then back. The bytes-moved measurement should reflect this (fp32 intermediate dominates).
  • Reference Flash on CPU is slow. That's OK; you're measuring bytes, not seconds. Use small N (≤ 512) if it's painfully slow.

Connection to lab 02

This lab gives you the prediction. Lab 02 (Triton kernel on GPU) gives you the measurement. Save your intensity ratio prediction here; check it there.

When to consult solutions/

After all six files are committed and your predictions agree with the measured intensity ratio to within 20%.


Next lab: lab/02-flash-triton.md.