Skip to content

English · Español

Lab 02 — Forward de flash attention en Triton

Objetivo: implementar un kernel educativo de forward de flash attention en Triton. Verificar que coincide con la attention de referencia de PyTorch a 1e-3 en fp16.

Tiempo estimado: 12–20 horas (es con diferencia el lab más grande de la Fase 27).

Prerrequisito: theory 01 y 02 interiorizadas; labs 00 y 01 commiteados; el kernel vector-add de Triton de la Fase 24 reojeado; acceso a GPU en cloud.


Lo que produces

src/minimodel/attention_flash.py — el kernel Triton y un wrapper Python (extiende src/minimodel/attention.py de la Fase 15; este lab NO crea un nuevo módulo de primer nivel).

experiments/27-flash-attn-triton/ que contenga:

  • bench.py — ejecuta el kernel contra la attention de referencia de PyTorch; reporta max-abs-error y (opcionalmente) speedup wall-clock.
  • results.json — precisión y tiempos.
  • manifest.json.
  • README.md — interpretación; comentario sobre lo que el kernel te enseñó de Triton.

Tests en tests/test_flash_attn.py (Claude scaffoldea los que fallan).

La estructura del kernel

(Ver src/minimodel/README.md (extendido en la Fase 27) para la API completa. Aquí en breve.)

@triton.jit
def flash_attn_fwd(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    L_ptr,                       # output: log-sum-exp per row (for backward; we store but don't use)
    sm_scale,
    stride_q_b, stride_q_h, stride_q_n, stride_q_d,   # strides
    stride_k_b, stride_k_h, stride_k_n, stride_k_d,
    stride_v_b, stride_v_h, stride_v_n, stride_v_d,
    stride_o_b, stride_o_h, stride_o_n, stride_o_d,
    N: tl.constexpr, d: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,    # B_r and B_c from theory
    IS_CAUSAL: tl.constexpr,
):
    # one program instance per (batch, head, query-tile)
    pid_bh = tl.program_id(0)        # batch × head index
    pid_m  = tl.program_id(1)        # query tile index
    ...

TODOs

Bloque A — montar el wrapper

  • src/minimodel/attention_flash.py tiene una función Python flash_attn_forward(Q, K, V, causal=True) -> O.
  • Q, K, V son tensores (B, H, N, d) en fp16 sobre CUDA.
  • El wrapper elige BLOCK_M=64, BLOCK_N=64 para d=64 (o el apropiado para la dimensión de cabeza). Valida formas y dtypes, luego lanza el kernel Triton.

Bloque B — implementar el kernel

Esqueleto (rellena el cuerpo):

# Load Q tile into SRAM
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, d)
q_ptrs = Q_ptr + bh_off + offs_m[:, None] * stride_q_n + offs_d[None, :] * stride_q_d
q = tl.load(q_ptrs)

# Initialize accumulators
o = tl.zeros([BLOCK_M, d], dtype=tl.float32)
m = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
ell = tl.zeros([BLOCK_M], dtype=tl.float32)

# Inner loop over K, V tiles
for start_n in range(0, N, BLOCK_N):
    offs_n = start_n + tl.arange(0, BLOCK_N)
    # Load K tile
    k = tl.load(...)
    # Load V tile
    v = tl.load(...)
    # Compute S_ij = q @ k^T * sm_scale
    s = tl.dot(q, tl.trans(k)) * sm_scale
    # Apply causal mask if needed
    if IS_CAUSAL:
        s = tl.where(offs_m[:, None] >= offs_n[None, :], s, -float('inf'))
    # Online softmax update
    m_new = tl.maximum(m, tl.max(s, 1))
    alpha = tl.exp(m - m_new)
    p = tl.exp(s - m_new[:, None])
    ell = alpha * ell + tl.sum(p, 1)
    o = alpha[:, None] * o + tl.dot(p.to(v.dtype), v)
    m = m_new

# Normalize and store
o = o / ell[:, None]
tl.store(o_ptrs, o.to(tl.float16))
tl.store(L_ptrs, m + tl.log(ell))   # log-sum-exp (for backward; unused this lab)

Tú rellenas la aritmética de punteros, el enmascarado para colas no alineadas y los acumuladores fp32 donde haga falta.

Bloque C — implementar la referencia

reference_attn(Q, K, V, causal=True):

sm_scale = 1.0 / math.sqrt(d)
S = (Q @ K.transpose(-1, -2)) * sm_scale
if causal:
    S = S.masked_fill(causal_mask, float('-inf'))
P = torch.softmax(S, dim=-1)
return P @ V

Esta es la attention "naive" en PyTorch — lenta pero obviamente correcta.

Bloque D — verificar la corrección

  • Aleatorio (B, H, N, d) = (2, 4, 1024, 64) fp16 en CUDA.
  • Calcula O_ref = reference_attn(Q, K, V).
  • Calcula O_flash = flash_attn_forward(Q, K, V).
  • max_abs_error = (O_ref - O_flash).abs().max(). Asegura < 1e-3.
  • Repite para N ∈ {256, 1024, 4096} y causal ∈ {True, False}.

Bloque E — medir el speedup (opcional pero recomendado)

  • Wall-clock con torch.cuda.synchronize() y time.perf_counter().
  • Reporta tokens/s o matmul-FLOPs/s.
  • Esperado: 2–6× de speedup a N=4096 fp16 en GPU de consumo (3090 / 4090 / A10). En Hopper (H100), quizá más.

Bloque F — interpretar en README.md

Cuatro preguntas:

  1. ¿Qué max-abs-error conseguiste? ¿Se mantuvo bajo 1e-3 para todos los N que probaste?
  2. ¿Dónde te tropezó Triton? ¿Aritmética de punteros? ¿Formas de tl.dot? ¿Enmascarado causal? Sé específico — esta es la parte gruesa del README.
  3. Comparado con la predicción del lab 01, ¿qué speedup mediste? Si fue menor, ¿por qué? (Saturación de SRAM, sobrecarga de lanzamiento del kernel, autotune sin invocar, etc.)
  4. ¿Qué cambiarías para flash 2 (en lugar del flash 1 escrito arriba)? (Pista: intercambiar qué bucle es exterior — Q exterior en lugar de KV exterior. ¿Por qué importa para los tensor cores?)

Restricciones

  • Sólo fp16. bf16 es un añadido opcional; el camino fp32 añade complejidad.
  • Una dimensión de cabeza por kernel. No intentes hacer d una variable en tiempo de ejecución; usa tl.constexpr y recompila por d.
  • Sin autotune en la primera pasada. Elige BLOCK_M, BLOCK_N = 64, 64 y entrega. Añadir @triton.autotune es un paso de pulido.
  • Sin backward. Sólo forward esta fase.

Condiciones de parada

  • Kernel implementado en src/minimodel/attention_flash.py.
  • Los tests pasan; max-abs-error < 1e-3 en fp16 para N ∈ {64, 256, 1024, 4096} causal y no-causal (N=64 es la longitud de secuencia del corpus de verbos).
  • El README responde las cuatro preguntas.
  • Los cinco archivos del experimento commiteados.

Errores típicos

  • NaN en la salida. Causa más común: softmax de una fila de todos -inf (la máscara causal lo oculta todo). tl.exp(-inf) en Triton es 0, así que ℓ = 0, y divides por cero. Añade ℓ = tl.maximum(ℓ, 1e-30) antes de normalizar. Los kernels flash reales manejan esto con cuidado; para la versión educativa, la guarda basta.
  • Layout/strides incorrectos. El layout (B, H, N, d) de PyTorch tiene strides que dependen de si el tensor es contiguo. Siempre Q = Q.contiguous() antes de pasarlo, y recalcula los strides en cada llamada.
  • Forma de tl.dot que no concuerda. tl.dot requiere que las formas de ambos operandos sigan la convención de tensor-core de Triton (típicamente (BLOCK_M, d) @ (d, BLOCK_N)). Si pasas tl.trans(k) debería ser (d, BLOCK_N) — verifícalo.
  • Máscara causal off-by-one. La máscara usa >=, no >: la posición i se atiende a sí misma.
  • Desbordamiento de acumuladores fp16. Incluso tras restar el máximo, puede crecer. Acumula siempre m, ℓ, o en fp32 (haz cast sólo para la salida).

Cuándo consultar solutions/

Tras pasar los tests y cumplir el umbral del DoD. La referencia en solutions/02-flash-triton-ref.md (apertura de fase) recorre la aritmética de punteros línea a línea.


Siguiente lab: lab/03-paged-attn-reading.md.