English · Español
Lab 02 — Forward de flash attention en Triton¶
Objetivo: implementar un kernel educativo de forward de flash attention en Triton. Verificar que coincide con la attention de referencia de PyTorch a
1e-3en fp16.Tiempo estimado: 12–20 horas (es con diferencia el lab más grande de la Fase 27).
Prerrequisito: theory 01 y 02 interiorizadas; labs 00 y 01 commiteados; el kernel vector-add de Triton de la Fase 24 reojeado; acceso a GPU en cloud.
Lo que produces¶
src/minimodel/attention_flash.py — el kernel Triton y un wrapper Python (extiende src/minimodel/attention.py de la Fase 15; este lab NO crea un nuevo módulo de primer nivel).
experiments/27-flash-attn-triton/ que contenga:
bench.py— ejecuta el kernel contra la attention de referencia de PyTorch; reporta max-abs-error y (opcionalmente) speedup wall-clock.results.json— precisión y tiempos.manifest.json.README.md— interpretación; comentario sobre lo que el kernel te enseñó de Triton.
Tests en tests/test_flash_attn.py (Claude scaffoldea los que fallan).
La estructura del kernel¶
(Ver src/minimodel/README.md (extendido en la Fase 27) para la API completa. Aquí en breve.)
@triton.jit
def flash_attn_fwd(
Q_ptr, K_ptr, V_ptr, O_ptr,
L_ptr, # output: log-sum-exp per row (for backward; we store but don't use)
sm_scale,
stride_q_b, stride_q_h, stride_q_n, stride_q_d, # strides
stride_k_b, stride_k_h, stride_k_n, stride_k_d,
stride_v_b, stride_v_h, stride_v_n, stride_v_d,
stride_o_b, stride_o_h, stride_o_n, stride_o_d,
N: tl.constexpr, d: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # B_r and B_c from theory
IS_CAUSAL: tl.constexpr,
):
# one program instance per (batch, head, query-tile)
pid_bh = tl.program_id(0) # batch × head index
pid_m = tl.program_id(1) # query tile index
...
TODOs¶
Bloque A — montar el wrapper¶
-
src/minimodel/attention_flash.pytiene una función Pythonflash_attn_forward(Q, K, V, causal=True) -> O. - Q, K, V son tensores
(B, H, N, d)en fp16 sobre CUDA. - El wrapper elige
BLOCK_M=64, BLOCK_N=64parad=64(o el apropiado para la dimensión de cabeza). Valida formas y dtypes, luego lanza el kernel Triton.
Bloque B — implementar el kernel¶
Esqueleto (rellena el cuerpo):
# Load Q tile into SRAM
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, d)
q_ptrs = Q_ptr + bh_off + offs_m[:, None] * stride_q_n + offs_d[None, :] * stride_q_d
q = tl.load(q_ptrs)
# Initialize accumulators
o = tl.zeros([BLOCK_M, d], dtype=tl.float32)
m = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
ell = tl.zeros([BLOCK_M], dtype=tl.float32)
# Inner loop over K, V tiles
for start_n in range(0, N, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
# Load K tile
k = tl.load(...)
# Load V tile
v = tl.load(...)
# Compute S_ij = q @ k^T * sm_scale
s = tl.dot(q, tl.trans(k)) * sm_scale
# Apply causal mask if needed
if IS_CAUSAL:
s = tl.where(offs_m[:, None] >= offs_n[None, :], s, -float('inf'))
# Online softmax update
m_new = tl.maximum(m, tl.max(s, 1))
alpha = tl.exp(m - m_new)
p = tl.exp(s - m_new[:, None])
ell = alpha * ell + tl.sum(p, 1)
o = alpha[:, None] * o + tl.dot(p.to(v.dtype), v)
m = m_new
# Normalize and store
o = o / ell[:, None]
tl.store(o_ptrs, o.to(tl.float16))
tl.store(L_ptrs, m + tl.log(ell)) # log-sum-exp (for backward; unused this lab)
Tú rellenas la aritmética de punteros, el enmascarado para colas no alineadas y los acumuladores fp32 donde haga falta.
Bloque C — implementar la referencia¶
reference_attn(Q, K, V, causal=True):
sm_scale = 1.0 / math.sqrt(d)
S = (Q @ K.transpose(-1, -2)) * sm_scale
if causal:
S = S.masked_fill(causal_mask, float('-inf'))
P = torch.softmax(S, dim=-1)
return P @ V
Esta es la attention "naive" en PyTorch — lenta pero obviamente correcta.
Bloque D — verificar la corrección¶
- Aleatorio
(B, H, N, d) = (2, 4, 1024, 64)fp16 en CUDA. - Calcula
O_ref = reference_attn(Q, K, V). - Calcula
O_flash = flash_attn_forward(Q, K, V). -
max_abs_error = (O_ref - O_flash).abs().max(). Asegura <1e-3. - Repite para
N ∈ {256, 1024, 4096}y causal ∈ {True, False}.
Bloque E — medir el speedup (opcional pero recomendado)¶
- Wall-clock con
torch.cuda.synchronize()ytime.perf_counter(). - Reporta tokens/s o matmul-FLOPs/s.
- Esperado: 2–6× de speedup a N=4096 fp16 en GPU de consumo (3090 / 4090 / A10). En Hopper (H100), quizá más.
Bloque F — interpretar en README.md¶
Cuatro preguntas:
- ¿Qué max-abs-error conseguiste? ¿Se mantuvo bajo
1e-3para todos los N que probaste? - ¿Dónde te tropezó Triton? ¿Aritmética de punteros? ¿Formas de
tl.dot? ¿Enmascarado causal? Sé específico — esta es la parte gruesa del README. - Comparado con la predicción del lab 01, ¿qué speedup mediste? Si fue menor, ¿por qué? (Saturación de SRAM, sobrecarga de lanzamiento del kernel, autotune sin invocar, etc.)
- ¿Qué cambiarías para flash 2 (en lugar del flash 1 escrito arriba)? (Pista: intercambiar qué bucle es exterior — Q exterior en lugar de KV exterior. ¿Por qué importa para los tensor cores?)
Restricciones¶
- Sólo fp16. bf16 es un añadido opcional; el camino fp32 añade complejidad.
- Una dimensión de cabeza por kernel. No intentes hacer
duna variable en tiempo de ejecución; usatl.constexpry recompila pord. - Sin autotune en la primera pasada. Elige
BLOCK_M, BLOCK_N = 64, 64y entrega. Añadir@triton.autotunees un paso de pulido. - Sin backward. Sólo forward esta fase.
Condiciones de parada¶
- Kernel implementado en
src/minimodel/attention_flash.py. - Los tests pasan; max-abs-error <
1e-3en fp16 paraN ∈ {64, 256, 1024, 4096}causal y no-causal (N=64 es la longitud de secuencia del corpus de verbos). - El README responde las cuatro preguntas.
- Los cinco archivos del experimento commiteados.
Errores típicos¶
- NaN en la salida. Causa más común: softmax de una fila de todos
-inf(la máscara causal lo oculta todo).tl.exp(-inf)en Triton es0, así queℓ = 0, y divides por cero. Añadeℓ = tl.maximum(ℓ, 1e-30)antes de normalizar. Los kernels flash reales manejan esto con cuidado; para la versión educativa, la guarda basta. - Layout/strides incorrectos. El layout
(B, H, N, d)de PyTorch tiene strides que dependen de si el tensor es contiguo. SiempreQ = Q.contiguous()antes de pasarlo, y recalcula los strides en cada llamada. - Forma de
tl.dotque no concuerda.tl.dotrequiere que las formas de ambos operandos sigan la convención de tensor-core de Triton (típicamente(BLOCK_M, d) @ (d, BLOCK_N)). Si pasastl.trans(k)debería ser(d, BLOCK_N)— verifícalo. - Máscara causal off-by-one. La máscara usa
>=, no>: la posiciónise atiende a sí misma. - Desbordamiento de acumuladores fp16. Incluso tras restar el máximo,
ℓpuede crecer. Acumula siemprem, ℓ, oen fp32 (haz cast sólo para la salida).
Cuándo consultar solutions/¶
Tras pasar los tests y cumplir el umbral del DoD. La referencia en solutions/02-flash-triton-ref.md (apertura de fase) recorre la aritmética de punteros línea a línea.
Siguiente lab: lab/03-paged-attn-reading.md.