Lab 02: Flash Attention Forward in Triton
Goal: implement an educational Flash Attention forward kernel in Triton. Verify that it matches PyTorch's reference attention to within 1e-3 max-abs-error at fp16.
Estimated time: 12–20 hours (this is by far the largest lab in Phase 27).
Prereq: theory 01 and 02 internalized; labs 00 and 01 committed; Phase 24's Triton vector-add kernel re-skimmed; cloud GPU access.
What you produce
src/minimodel/attention_flash.py — the Triton kernel and a Python wrapper (extends src/minimodel/attention.py from Phase 15; this lab does NOT create a new top-level module).
experiments/27-flash-attn-triton/ containing:
  - bench.py: runs the kernel against PyTorch reference attention; reports max-abs-error and (optionally) wall-clock speedup.
  - results.json: accuracy and timing.
  - manifest.json.
  - README.md: interpretation; commentary on what the kernel taught you about Triton.
Tests in tests/test_flash_attn.py (Claude scaffolds failing).
The kernel structure
(See src/minimodel/README.md (extended in Phase 27) for the full API. Brief here.)
import triton
import triton.language as tl

@triton.jit
def flash_attn_fwd(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    L_ptr,  # output: log-sum-exp per row (for backward; we store but don't use)
    sm_scale,
    stride_q_b, stride_q_h, stride_q_n, stride_q_d,  # strides
    stride_k_b, stride_k_h, stride_k_n, stride_k_d,
    stride_v_b, stride_v_h, stride_v_n, stride_v_d,
    stride_o_b, stride_o_h, stride_o_n, stride_o_d,
    N: tl.constexpr, d: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,  # B_r and B_c from theory
    IS_CAUSAL: tl.constexpr,
):
    # one program instance per (batch, head, query-tile)
    pid_bh = tl.program_id(0)  # batch × head index
    pid_m = tl.program_id(1)   # query tile index
    ...
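To make the two program ids concrete, here is a rough plain-Python sketch of the index arithmetic one program instance might perform. The helper name `program_offsets` is hypothetical, and the sketch assumes that axis 0 enumerates batch × head in row-major order (`pid_bh = b * H + h`). The signature above does not take `H`, so a real kernel would need it passed in or would need some other way to derive the base offset.

```python
def program_offsets(pid_bh, pid_m, H, stride_b, stride_h, block_m):
    """Hypothetical helper: decode (pid_bh, pid_m) the way a kernel instance might.

    Assumes pid_bh = b * H + h (row-major over batch and head).
    """
    b, h = divmod(pid_bh, H)
    bh_off = b * stride_b + h * stride_h          # base offset of this (batch, head) slice
    rows = range(pid_m * block_m, (pid_m + 1) * block_m)  # query rows of this tile
    return b, h, bh_off, rows


# Contiguous (B, H, N, d) = (2, 4, 1024, 64); strides are counted in elements
B, H, N, d = 2, 4, 1024, 64
stride_b, stride_h = H * N * d, N * d
b, h, bh_off, rows = program_offsets(pid_bh=5, pid_m=3, H=H,
                                     stride_b=stride_b, stride_h=stride_h, block_m=64)
print(b, h, bh_off, rows[0], rows[-1])
```

In the Triton kernel the same arithmetic runs on `tl.program_id` values, and `rows` corresponds to `pid_m * BLOCK_M + tl.arange(0, BLOCK_M)`.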
TODOs
Block A: set up the wrapper
- src/minimodel/attention_flash.py has a flash_attn_forward(Q, K, V, causal=True) -> O Python function.
- Q, K, V are (B, H, N, d) tensors in fp16 on CUDA.
- The wrapper picks BLOCK_M=64, BLOCK_N=64 for d=64 (or appropriate for the head dim). It validates shapes and dtypes, then launches the Triton kernel.
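The pre-launch bookkeeping in the wrapper can be sketched without a GPU. The following is a minimal illustration, not the lab's reference wrapper: `contiguous_strides` and `launch_plan` are hypothetical names, shapes are plain tuples standing in for `tensor.shape`, and the dtype is a string standing in for `tensor.dtype`.

```python
import math


def contiguous_strides(shape):
    """Element strides of a contiguous row-major tensor with this shape."""
    strides, running = [], 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def launch_plan(q_shape, k_shape, v_shape, dtype="float16", block_m=64, block_n=64):
    """Hypothetical pre-launch check: validate inputs and derive the launch grid."""
    if len(q_shape) != 4:
        raise ValueError("expected a 4-D (B, H, N, d) tensor")
    if not (q_shape == k_shape == v_shape):
        raise ValueError("Q, K, V must share the shape (B, H, N, d)")
    if dtype != "float16":
        raise TypeError("this lab's kernel is fp16 only")
    B, H, N, d = q_shape
    grid = (B * H, math.ceil(N / block_m))   # axis 0: batch x head, axis 1: query tiles
    return {
        "grid": grid,
        "strides": contiguous_strides(q_shape),
        "sm_scale": 1.0 / math.sqrt(d),
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
    }


plan = launch_plan((2, 4, 1024, 64), (2, 4, 1024, 64), (2, 4, 1024, 64))
print(plan["grid"], plan["strides"])
```

With real tensors, the strides would come from `Q.stride()` after `Q.contiguous()`, and the grid tuple is what goes inside the square brackets when launching the kernel.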
Block B: implement the kernel
Skeleton (you fill in the body):
# Load Q tile into SRAM
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, d)
q_ptrs = Q_ptr + bh_off + offs_m[:, None] * stride_q_n + offs_d[None, :] * stride_q_d
q = tl.load(q_ptrs)

# Initialize accumulators
o = tl.zeros([BLOCK_M, d], dtype=tl.float32)
m = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
ell = tl.zeros([BLOCK_M], dtype=tl.float32)

# Inner loop over K, V tiles
for start_n in range(0, N, BLOCK_N):
    offs_n = start_n + tl.arange(0, BLOCK_N)
    # Load K tile
    k = tl.load(...)
    # Load V tile
    v = tl.load(...)
    # Compute S_ij = q @ k^T * sm_scale
    s = tl.dot(q, tl.trans(k)) * sm_scale
    # Apply causal mask if needed
    if IS_CAUSAL:
        s = tl.where(offs_m[:, None] >= offs_n[None, :], s, -float('inf'))
    # Online softmax update
    m_new = tl.maximum(m, tl.max(s, 1))
    alpha = tl.exp(m - m_new)
    p = tl.exp(s - m_new[:, None])
    ell = alpha * ell + tl.sum(p, 1)
    o = alpha[:, None] * o + tl.dot(p.to(v.dtype), v)
    m = m_new

# Normalize and store
o = o / ell[:, None]
tl.store(o_ptrs, o.to(tl.float16))
tl.store(L_ptrs, m + tl.log(ell))  # log-sum-exp (for backward; unused this lab)
You fill in pointer arithmetic, masking for non-tile-aligned tails, fp32 accumulators where needed.
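Before debugging the update rule inside Triton, it can help to check it in plain Python on a single query row. The sketch below is an illustration under simplifying assumptions (scalar values instead of a `d`-wide V row, hypothetical function names); it applies the same `m`, `ell`, `o` recurrence as the skeleton, one tile at a time, and compares against an ordinary softmax. The input length is deliberately not a multiple of the tile size, so the last tile is ragged.

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: full softmax over `scores`, then a weighted sum of `values`."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)


def online_softmax_weighted_sum(scores, values, block_n=4):
    """Same result, visiting `scores` one tile of `block_n` entries at a time."""
    m, ell, o = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), block_n):
        # Slicing stops at the end of the list, which plays the role of the tail mask
        s_tile = scores[start:start + block_n]
        v_tile = values[start:start + block_n]
        m_new = max(m, max(s_tile))
        alpha = math.exp(m - m_new)           # rescales everything accumulated so far
        p = [math.exp(s - m_new) for s in s_tile]
        ell = alpha * ell + sum(p)
        o = alpha * o + sum(pi * vi for pi, vi in zip(p, v_tile))
        m = m_new
    return o / ell, m + math.log(ell)         # normalized output and log-sum-exp


scores = [0.5, -1.0, 3.0, 2.0, 0.0, 4.5, -2.0, 1.0, 0.25, 2.5]   # length 10, tile size 4
values = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 1.5, 0.0, 2.5, -0.5]
out, lse = online_softmax_weighted_sum(scores, values)
assert abs(out - softmax_weighted_sum(scores, values)) < 1e-12
```

In the kernel the ragged tail cannot be handled by slicing; out-of-range columns have to be masked explicitly, both when loading K and V and when forming `s`.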
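Block C: implement reference
import math
import torch

def reference_attn(Q, K, V, causal=True):
    N, d = Q.shape[-2], Q.shape[-1]
    sm_scale = 1.0 / math.sqrt(d)
    S = (Q @ K.transpose(-1, -2)) * sm_scale
    if causal:
        # True strictly above the diagonal: position i must not attend to j > i
        causal_mask = torch.ones(N, N, dtype=torch.bool, device=Q.device).triu(1)
        S = S.masked_fill(causal_mask, float('-inf'))
    P = torch.softmax(S, dim=-1)
    return P @ V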
This is the "naive" attention in PyTorch — slow but obviously correct.
Block D: verify correctness
- Random (B, H, N, d) = (2, 4, 1024, 64) fp16 on CUDA.
- Compute O_ref = reference_attn(Q, K, V).
- Compute O_flash = flash_attn_forward(Q, K, V).
- max_abs_error = (O_ref - O_flash).abs().max(). Assert < 1e-3.
- Repeat for N ∈ {256, 1024, 4096} and causal ∈ {True, False}.
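The sweep over sequence lengths and masking modes can be driven by a small loop. This is a sketch only: `run_sweep` is a hypothetical name, and `max_abs_error_fn` is assumed to be a function you write that builds random fp16 inputs, runs both implementations, and returns `(O_ref - O_flash).abs().max()` as a float. A stand-in is used here so the sketch runs without a GPU.

```python
import itertools


def run_sweep(max_abs_error_fn, seq_lens=(256, 1024, 4096), tol=1e-3):
    """Hypothetical driver: call `max_abs_error_fn(N, causal)` for every case."""
    results = {}
    for n, causal in itertools.product(seq_lens, (True, False)):
        err = float(max_abs_error_fn(n, causal))
        results[(n, causal)] = err
        # `err < tol` is False for NaN, so a NaN output also fails here
        assert err < tol, f"N={n}, causal={causal}: max_abs_error={err:.3e} >= {tol}"
    return results


# Stand-in error function; replace with one that runs the real kernels
fake_errors = run_sweep(lambda n, causal: 2e-4)
print(len(fake_errors))
```

Writing the comparison as `err < tol` rather than `not (err >= tol)` matters: the first form rejects NaN, the second would let it through.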
Block E: measure speedup (optional but encouraged)
- Wall-clock both with torch.cuda.synchronize() and time.perf_counter().
- Report tokens/sec or matmul-FLOPs/sec.
- Expected: 2–6× speedup at N=4096 fp16 on consumer GPU (3090 / 4090 / A10). On Hopper (H100), maybe more.
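One possible shape for the timing code is sketched below. The helper names `attention_flops` and `time_fn` are hypothetical. The FLOP count covers only the two matmuls (QKᵀ and PV, each 2·N·N·d per head), and halving it for the causal case is an approximation. The `sync` argument defaults to a no-op so the sketch runs on CPU; for GPU work it is meant to be `torch.cuda.synchronize`.

```python
import time


def attention_flops(B, H, N, d, causal=False):
    """Matmul FLOPs of one forward pass: QK^T and PV are each 2*N*N*d per head."""
    flops = 4 * B * H * N * N * d
    return flops // 2 if causal else flops    # causal skips roughly half the work


def time_fn(fn, sync=lambda: None, warmup=3, iters=10):
    """Mean seconds per call. Pass sync=torch.cuda.synchronize for GPU work."""
    for _ in range(warmup):                   # warm-up absorbs JIT compilation and caches
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()                                    # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters


seconds = time_fn(lambda: sum(range(1000)))   # CPU stand-in for the attention call
print(attention_flops(2, 4, 4096, 64) / seconds)
```

The warm-up calls matter for Triton in particular, since the first launch includes compilation time that would otherwise dominate the measurement.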
Block F: interpret in README.md
Four questions:
- What max-abs-error did you achieve? Did the result stay under 1e-3 for all N you tested?
- Where did Triton trip you up? Pointer arithmetic? tl.dot shapes? Causal masking? Be specific; this is the bulk of the README.
- Compared to lab 01's prediction, what speedup did you measure? If lower, why? (SRAM saturation, kernel launch overhead, autotune not invoked, etc.)
- What would you change for Flash 2 (instead of Flash 1 as written above)? (Hint: swap which loop is outer: Q outer instead of KV outer. Why does this matter for tensor cores?)
Constraints
- fp16 only. bf16 is an optional add-on; fp32 path adds complexity.
- One head dim per kernel. Don't try to make d a runtime variable; use tl.constexpr and recompile per d.
- No autotune in the first pass. Pick BLOCK_M, BLOCK_N = 64, 64 and ship. Adding @triton.autotune is a polish step.
- No backward. Forward only this phase.
Stop conditions
- Kernel implemented in src/minimodel/attention_flash.py.
- Tests pass; max-abs-error < 1e-3 at fp16 for N ∈ {64, 256, 1024, 4096}, causal and non-causal (N=64 is the verb-corpus sequence length).
- README answers all four questions.
- All five experiment files committed.
Pitfalls
- NaN in output. Most common cause: softmax of a row of all -inf (causal mask hides everything). Triton's tl.exp(-inf) is 0, so ℓ = 0, and you divide by zero. Add ell = tl.maximum(ell, 1e-30) before normalization. Note also that while a row's running max m is still -inf, m - m_new evaluates to -inf - (-inf) = NaN, so the guard on ℓ alone does not cover a row that is fully masked in the first tile it sees. Real Flash kernels handle this with care; for the educational version, the guard is fine.
- Wrong layout/strides. PyTorch's (B, H, N, d) layout has strides that depend on whether the tensor is contiguous. Always call Q = Q.contiguous() before passing it in (likewise for K and V), and recompute strides on each call.
- tl.dot shape mismatch. tl.dot requires that both operands' shapes match Triton's tensor-core convention (typically (BLOCK_M, d) @ (d, BLOCK_N)). If you pass tl.trans(k), it should be (d, BLOCK_N); verify.
- Causal mask off-by-one. The mask uses >=, not >: position i attends to itself.
- fp16 accumulators overflow. Even after subtracting the max, ℓ can grow large. Always accumulate m, ℓ, o in fp32 (cast only for output).
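The NaN and off-by-one pitfalls are related, which a few lines of plain Python can show. This is an illustrative sketch with hypothetical helper names: it builds one causal-masked score row (all raw scores set to 0 for simplicity) and computes the softmax denominator ℓ for it, once with the inclusive `>=` rule and once with the strict `>` rule.

```python
import math


def visible(i, j, inclusive=True):
    """Causal rule: may query i attend to key j? `>=` keeps the diagonal."""
    return i >= j if inclusive else i > j


def row_softmax_denominator(i, n, inclusive=True):
    """ell for query row i, with all raw scores equal to 0, after causal masking."""
    scores = [0.0 if visible(i, j, inclusive) else -math.inf for j in range(n)]
    m = max(scores)
    if m == -math.inf:
        return 0.0                 # fully masked row; s - m would be nan here
    return sum(math.exp(s - m) for s in scores)


print(row_softmax_denominator(0, 4, inclusive=True))    # row 0 sees itself
print(row_softmax_denominator(0, 4, inclusive=False))   # row 0 sees nothing
print(math.isnan(-math.inf - (-math.inf)))              # why the running max needs care
```

With the strict rule, row 0 has no visible key at all, so its denominator is zero and the normalization step divides by zero. With the inclusive rule every row of a square causal problem sees at least its own position.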
When to consult solutions/
After tests pass and the DoD threshold is met. The reference at solutions/02-flash-triton-ref.md (phase open) goes through the pointer arithmetic line-by-line.
Next lab: lab/03-paged-attn-reading.md.