
Lab 02 — Flash Attention Forward in Triton

Goal: implement an educational Flash Attention forward kernel in Triton, and verify that its output matches PyTorch reference attention to within 1e-3 max-abs-error at fp16.

Estimated time: 12–20 hours (this is by far the largest lab in Phase 27).

Prereq: theory 01 and 02 internalized; labs 00 and 01 committed; Phase 24's Triton vector-add kernel re-skimmed; cloud GPU access.


What you produce

src/minimodel/attention_flash.py — the Triton kernel and a Python wrapper (extends src/minimodel/attention.py from Phase 15; this lab does NOT create a new top-level module).

experiments/27-flash-attn-triton/ containing:

  • bench.py — runs the kernel against PyTorch reference attention; reports max-abs-error and (optionally) wall-clock speedup.
  • results.json — accuracy and timing.
  • manifest.json.
  • README.md — interpretation; commentary on what the kernel taught you about Triton.

Tests live in tests/test_flash_attn.py (Claude scaffolds them as failing tests).

The kernel structure

See src/minimodel/README.md (extended in Phase 27) for the full API; only a brief sketch follows.

import triton
import triton.language as tl

@triton.jit
def flash_attn_fwd(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    L_ptr,                       # output: log-sum-exp per row (for backward; we store but don't use)
    sm_scale,
    stride_q_b, stride_q_h, stride_q_n, stride_q_d,   # strides
    stride_k_b, stride_k_h, stride_k_n, stride_k_d,
    stride_v_b, stride_v_h, stride_v_n, stride_v_d,
    stride_o_b, stride_o_h, stride_o_n, stride_o_d,
    N: tl.constexpr, d: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,    # B_r and B_c from theory
    IS_CAUSAL: tl.constexpr,
):
    # one program instance per (batch, head, query-tile)
    pid_bh = tl.program_id(0)        # flattened (batch, head) index: b * H + h
    pid_m  = tl.program_id(1)        # query tile index
    ...
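The grid implied by the two program ids can be checked without a GPU. The sketch below is plain Python and assumes contiguous (B, H, N, d) inputs launched on a grid of (B * H, cdiv(N, BLOCK_M)); cdiv, contiguous_strides and program_work are hypothetical helpers, not part of the lab's API.

```python
def cdiv(a, b):
    """Ceiling division: number of BLOCK_M-sized query tiles needed to cover N rows."""
    return (a + b - 1) // b


def contiguous_strides(shape):
    """Element strides of a contiguous, row-major tensor with the given shape."""
    strides = []
    acc = 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))


def program_work(pid_bh, pid_m, H, BLOCK_M, strides):
    """Map (pid_bh, pid_m) to its (batch, head), element offset, and first query row."""
    stride_b, stride_h, _, _ = strides
    b, h = divmod(pid_bh, H)              # pid_bh = b * H + h
    bh_off = b * stride_b + h * stride_h  # equals pid_bh * stride_h when contiguous
    return b, h, bh_off, pid_m * BLOCK_M


B, H, N, d, BLOCK_M = 2, 4, 1024, 64, 64
strides = contiguous_strides((B, H, N, d))
grid = (B * H, cdiv(N, BLOCK_M))          # (programs over batch*head, query tiles)
```

With contiguous inputs, stride_b == H * stride_h, which is why a kernel can use pid_bh * stride_q_h as its (batch, head) offset without knowing H; a non-contiguous input breaks that shortcut.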

TODOs

Block A — set up the wrapper

  • src/minimodel/attention_flash.py exposes a Python function flash_attn_forward(Q, K, V, causal=True) -> O.
  • Q, K, V are (B, H, N, d) tensors in fp16 on CUDA.
  • The wrapper picks BLOCK_M=64, BLOCK_N=64 for d=64 (or values appropriate for other head dims), validates shapes and dtypes, then launches the Triton kernel.
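The wrapper's pre-launch checks can be sketched in plain Python. In the real wrapper they would read Q.shape, Q.dtype == torch.float16 and Q.is_cuda; here shapes are tuples and the dtype is a string so the sketch runs without a GPU. pick_blocks, check_inputs and SUPPORTED_HEAD_DIMS are hypothetical; limiting d to powers of two of at least 16 is an assumption, motivated by tl.arange needing a power-of-two length and tl.dot needing dimensions of at least 16.

```python
SUPPORTED_HEAD_DIMS = (16, 32, 64, 128)  # assumption: powers of two >= 16


def pick_blocks(d):
    """Return (BLOCK_M, BLOCK_N); the lab's first pass uses 64 x 64 for every supported d."""
    if d not in SUPPORTED_HEAD_DIMS:
        raise ValueError(f"unsupported head dim {d}")
    return 64, 64


def check_inputs(q_shape, k_shape, v_shape, dtype):
    """Validate (B, H, N, d) shapes and the dtype before launching; return the shape."""
    if len(q_shape) != 4:
        raise ValueError("expected (B, H, N, d) tensors")
    if not (tuple(q_shape) == tuple(k_shape) == tuple(v_shape)):
        raise ValueError("Q, K, V must have identical shapes")
    if dtype != "float16":
        raise ValueError("fp16 only in this lab")
    return tuple(q_shape)
```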

Block B — implement the kernel

Skeleton (you fill in the body):

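# Load Q tile into SRAM
bh_off = pid_bh * stride_q_h     # assumes contiguous (B, H, N, d): stride_q_b == H * stride_q_h
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, d)
q_ptrs = Q_ptr + bh_off + offs_m[:, None] * stride_q_n + offs_d[None, :] * stride_q_d
q = tl.load(q_ptrs)              # add mask=offs_m[:, None] < N when N % BLOCK_M != 0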

# Initialize accumulators
o = tl.zeros([BLOCK_M, d], dtype=tl.float32)
m = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
ell = tl.zeros([BLOCK_M], dtype=tl.float32)

# Inner loop over K, V tiles
for start_n in range(0, N, BLOCK_N):
    offs_n = start_n + tl.arange(0, BLOCK_N)
    # Load K tile
    k = tl.load(...)
    # Load V tile
    v = tl.load(...)
    # Compute S_ij = q @ k^T * sm_scale
    s = tl.dot(q, tl.trans(k)) * sm_scale
    # Apply causal mask if needed
    if IS_CAUSAL:
        s = tl.where(offs_m[:, None] >= offs_n[None, :], s, -float('inf'))
    # Online softmax update
    m_new = tl.maximum(m, tl.max(s, 1))
    alpha = tl.exp(m - m_new)
    p = tl.exp(s - m_new[:, None])
    ell = alpha * ell + tl.sum(p, 1)
    o = alpha[:, None] * o + tl.dot(p.to(v.dtype), v)
    m = m_new

# Normalize and store (o_ptrs and L_ptrs are built like q_ptrs; yours to fill in)
o = o / ell[:, None]
tl.store(o_ptrs, o.to(tl.float16))
tl.store(L_ptrs, m + tl.log(ell))   # log-sum-exp (for backward; unused this lab)

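You fill in the remaining pointer arithmetic (K, V, O and L), the masking for non-tile-aligned tails (N not a multiple of BLOCK_M or BLOCK_N), and fp32 accumulators wherever they are still needed. For out-of-range key columns, setting s to -inf is what matters: loading them as 0 is not enough, because exp(0 - m_new) would still add weight to the softmax.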
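The inner-loop update can be checked in plain Python before touching Triton. This sketch mirrors the skeleton's m, ell and o updates for a single query row, without masking, and compares the result to a direct softmax followed by a weighted sum of V rows; online_softmax_row and direct_softmax_row are hypothetical names.

```python
import math


def online_softmax_row(scores, values, block_n):
    """Flash-style pass over one query row; scores[j] is s_j, values[j] is row j of V."""
    m = -math.inf                  # running max of the scores seen so far
    ell = 0.0                      # running sum of exp(s - m)
    o = [0.0] * len(values[0])     # running unnormalized output row
    for start in range(0, len(scores), block_n):
        s = scores[start:start + block_n]
        v = values[start:start + block_n]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)            # rescales the old state to the new max
        p = [math.exp(x - m_new) for x in s]
        ell = alpha * ell + sum(p)
        o = [alpha * oi + sum(pj * vj[c] for pj, vj in zip(p, v)) for c, oi in enumerate(o)]
        m = m_new
    return [oi / ell for oi in o]              # deferred normalization


def direct_softmax_row(scores, values):
    """Naive reference: full softmax over the row, then the weighted sum of V rows."""
    mx = max(scores)
    w = [math.exp(x - mx) for x in scores]
    total = sum(w)
    return [sum(wj * vj[c] for wj, vj in zip(w, values)) / total for c in range(len(values[0]))]
```

The result should not depend on block_n: any tiling of the keys gives the same row, up to floating-point rounding.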

Block C — implement reference

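import math
import torch

def reference_attn(Q, K, V, causal=True):
    d = Q.shape[-1]
    sm_scale = 1.0 / math.sqrt(d)
    S = (Q @ K.transpose(-1, -2)) * sm_scale
    if causal:
        N = Q.shape[-2]
        # True strictly above the diagonal: key j lies in the future of query i
        causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool, device=Q.device), diagonal=1)
        S = S.masked_fill(causal_mask, float('-inf'))
    P = torch.softmax(S, dim=-1)
    return P @ V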

This is the "naive" attention in PyTorch — slow but obviously correct.
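The reference and the kernel express the causal mask in opposite senses: the reference fills -inf where j > i, while the kernel's tl.where keeps s where offs_m >= offs_n. The plain-Python sketch below (hypothetical helper names) checks that the two are complements, that the diagonal is kept, and how whole tiles behave.

```python
def reference_fill_mask(n):
    """True where the reference fills -inf: key j is in the future of query i (triu, diagonal=1)."""
    return [[j > i for j in range(n)] for i in range(n)]


def flash_keep_mask(offs_m, offs_n):
    """True where the kernel keeps s: offs_m[:, None] >= offs_n[None, :]."""
    return [[i >= j for j in offs_n] for i in offs_m]
```

A tile that lies entirely below the diagonal (for example query rows 64..127 against keys 0..63) is fully visible, and a tile entirely above it is fully masked, which is why a causal kernel can skip the latter.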

Block D — verify correctness

  • Random (B, H, N, d) = (2, 4, 1024, 64) fp16 on CUDA.
  • Compute O_ref = reference_attn(Q, K, V).
  • Compute O_flash = flash_attn_forward(Q, K, V).
  • max_abs_error = (O_ref - O_flash).abs().max(). Assert < 1e-3.
  • Repeat for N ∈ {64, 256, 1024, 4096} and causal ∈ {True, False}.
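The correctness sweep in bench.py can be kept separate from the GPU code. A minimal sketch, assuming a caller-supplied max_abs_error(N, causal) callable that builds random (2, 4, N, 64) fp16 CUDA tensors, runs both implementations and returns (O_ref - O_flash).abs().max().item(). correctness_sweep, results_json and the {"correctness": [...]} layout are hypothetical, not a required results.json schema.

```python
import itertools
import json


def correctness_sweep(max_abs_error, Ns=(64, 256, 1024, 4096), causals=(True, False), tol=1e-3):
    """Evaluate every (N, causal) config; return JSON-serializable records."""
    records = []
    for N, causal in itertools.product(Ns, causals):
        err = float(max_abs_error(N, causal))
        records.append({"N": N, "causal": causal, "max_abs_error": err, "passed": err < tol})
    return records


def results_json(records):
    """Serialize the sweep for results.json."""
    return json.dumps({"correctness": records}, indent=2)
```

Injecting the error function keeps the loop testable with a stand-in off-GPU; in bench.py it would wrap the real kernel and reference.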

Block E — measure speedup (optional but encouraged)

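  • Time both implementations with time.perf_counter(), calling torch.cuda.synchronize() before each clock read (CUDA launches are asynchronous), and run a few warm-up calls first so Triton's JIT compilation is not timed.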
  • Report tokens/sec or matmul-FLOPs/sec.
  • Expected: a 2–6× speedup over the naive reference at N=4096, fp16, on a consumer or mid-range GPU (3090 / 4090 / A10). On Hopper (H100), possibly more.
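A minimal timing and FLOP-count sketch in plain Python. It assumes the common convention of 4·B·H·N²·d matmul FLOPs for the forward pass (QKᵀ and PV each cost 2·N²·d per head), halved for causal attention; attn_fwd_flops and time_fn are hypothetical helpers. On a GPU, pass sync=torch.cuda.synchronize so the clock is read only after the queued kernels finish.

```python
import time


def attn_fwd_flops(B, H, N, d, causal):
    """Matmul FLOPs of the forward pass: QK^T and PV each cost 2*N*N*d per (batch, head)."""
    flops = 4 * B * H * N * N * d
    return flops // 2 if causal else flops  # convention: causal counts only the unmasked half


def time_fn(fn, sync=lambda: None, warmup=3, iters=10):
    """Mean seconds per call of fn, synchronizing before each clock read."""
    for _ in range(warmup):   # warm-up also absorbs Triton's JIT compile on the first call
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) / iters
```

Throughput is then attn_fwd_flops(...) / time_fn(...) / 1e12 in TFLOP/s, computed the same way for the kernel and the reference so the speedup ratio is like for like.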

Block F — interpret in README.md

Four questions:

  1. What max-abs-error did you achieve? Did the result stay under 1e-3 for all N you tested?
  2. Where did Triton trip you up? Pointer arithmetic? tl.dot shapes? Causal masking? Be specific — this is the bulk of the README.
  3. Compared to lab 01's prediction, what speedup did you measure? If lower, why? (SRAM saturation, kernel launch overhead, autotune not invoked, etc.)
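  4. How does the kernel above relate to Flash 1 and Flash 2? (Hint: Flash 1's algorithm puts the KV loop outer and the Q loop inner; this kernel already follows the Flash 2 order, with one program per Q tile and an inner loop over KV tiles, and it defers the division by ℓ to the end. Why do these choices matter for tensor-core utilization and for parallelism across query tiles?)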

Constraints

  • fp16 only. bf16 is an optional add-on; skip an fp32 input path, which adds complexity.
  • One head dim per kernel. Don't try to make d a runtime variable; use tl.constexpr and recompile per d.
  • No autotune in the first pass. Pick BLOCK_M, BLOCK_N = 64, 64 and ship. Adding @triton.autotune is a polish step.
  • No backward. Forward only this phase.

Stop conditions

  • Kernel implemented in src/minimodel/attention_flash.py.
  • Tests pass; max-abs-error < 1e-3 at fp16 for N ∈ {64, 256, 1024, 4096} causal and non-causal (N=64 is the verb-corpus sequence length).
  • README answers all four questions.
  • All experiment files listed under "What you produce" committed.

Pitfalls

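  • NaN in output. Most common cause: a row whose scores are all -inf (the causal or tail mask hides every key seen so far). Then m_new = -inf, and m - m_new = (-inf) - (-inf) = NaN, which poisons alpha, p and everything downstream; if the row stays fully masked, ℓ = 0 and the final division is 0/0. Guard both: subtract 0 instead of m_new while m_new is still -inf (e.g. tl.where(m_new == -float('inf'), 0.0, m_new)), and clamp ℓ = tl.maximum(ℓ, 1e-30) before normalization. Real Flash kernels handle this with care; for the educational version, these guards are fine.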
  • Wrong layout/strides. PyTorch's (B, H, N, d) layout has strides that depend on whether the tensor is contiguous. Always Q = Q.contiguous() before passing in, and recompute strides each call.
  • tl.dot shape mismatch. tl.dot computes (M, K) @ (K, N): the inner dimensions must match, and Triton requires each dimension to be at least 16. For the score matrix that means (BLOCK_M, d) @ (d, BLOCK_N), so tl.trans(k) must be (d, BLOCK_N); verify.
  • Causal mask off-by-one. The mask uses >=, not >: position i attends to itself.
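  • fp16 accumulators overflow. Even after subtracting the row max, ℓ sums up to N terms and o accumulates up to N weighted V rows, so fp16 (max 65504) can overflow or lose precision. Always accumulate m, ℓ and o in fp32 and cast only the final output.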
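Both NaN failure points can be reproduced in plain Python. The sketch below (hypothetical helper names) shows the unguarded rescale factor turning into NaN when the running max and the tile max are both -inf, the guarded version that subtracts 0 instead, and the ℓ clamp that turns a fully masked row's 0/0 into 0. In Triton the same guard is a tl.where on m_new.

```python
import math

NEG_INF = -math.inf


def naive_alpha(m, tile_max):
    """Unguarded alpha = exp(m - m_new); (-inf) - (-inf) is NaN."""
    m_new = max(m, tile_max)
    return math.exp(m - m_new)


def guarded_alpha(m, tile_max):
    """Subtract a finite stand-in (0.0) while the running max is still -inf."""
    m_new = max(m, tile_max)
    m_sub = 0.0 if m_new == NEG_INF else m_new
    return math.exp(m - m_sub)


def normalize(o, ell):
    """Final o / ell with ell clamped, so a fully masked row yields zeros instead of NaN."""
    return [oi / max(ell, 1e-30) for oi in o]
```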

When to consult solutions/

After tests pass and the DoD threshold is met. The reference at solutions/02-flash-triton-ref.md (phase open) goes through the pointer arithmetic line-by-line.


Next lab: lab/03-paged-attn-reading.md.