Skip to content

English · Español

02 — FlashAttention como una optimización de roofline

🇪🇸 FlashAttention es Fase 1 en acción. Mismo número de FLOPs que la atención naive; menos bytes movidos a HBM porque la matriz S=QKᵀ nunca se materializa. La intensidad aritmética sube; el punto del roofline se mueve hacia el techo de cómputo. No es un truco — es álgebra (recurrencia online) más un layout de tiles.

Este archivo es la pieza central de la Fase 27. Léelo una vez, luego re-léelo con theory/01-online-softmax.md abierto en otra pestaña. Al final deberías poder (a) dibujar la ejecución tile-a-tile, (b) derivar simbólicamente el delta del byte-count, © enunciar el desplazamiento del roofline.


El algoritmo naive de attention

Para una cabeza con Q, K, V ∈ ℝ^{N × d}, salida O ∈ ℝ^{N × d}:

1. S = Q @ K^T              # (N, N) — materializada en HBM
2. P = softmax_rowwise(S)   # (N, N) — materializada en HBM
3. O = P @ V                # (N, d)

Contabilidad de tráfico HBM (lectura = R, escritura = W; ignora O puesto que es la misma en ambos algoritmos):

  • Paso 1: R(Q) + R(K) + W(S) = Nd + Nd + N² lecturas/escrituras (fp32 → ×4 bytes).
  • Paso 2: R(S) + W(P) = N² + N² = 2N².
  • Paso 3: R(P) + R(V) = N² + Nd.

Total: Nd × 3 + N² × 4 elementos fp32 = (12 Nd + 16 N²) bytes.

Para N=2048, d=64: 12 × 2048 × 64 + 16 × 2048² = 1.6 MiB + 64 MiB = 65.6 MiB.

FLOPs: 2 × N² × d (Q@K^T) + 5 × N² (softmax) + 2 × N² × d (P@V) = 4 N² d + 5 N². Para nuestros números: 4 × 2048² × 64 + 5 × 2048² = 1.07 GF + 21 MF = 1.09 GFLOPs.

Intensidad: 1.09e9 / 6.88e7 = 15.8 FLOPs/byte. Para una A100 con I_crit ≈ 200, esto es 13× por debajo de la esquina. Memory-bound.

El algoritmo Flash (forward)

Elige tamaños de tile: - B_r — filas de Q por tile externo. - B_c — filas de K (y V) por tile interno.

Restricción: el estado intermedio de un tile debe caber en SRAM. El estado por tile es S_tile ∈ ℝ^{B_r × B_c} (el bloque parcial de producto interno), más vectores m, ℓ por fila de B_r (escalares por fila de query). Para presupuesto SRAM M_sram:

\[ B_r \times B_c \times 4 \, (\text{fp32}) + B_r \times d \times 4 \, (\text{Q tile}) + B_c \times d \times 4 \, (\text{K tile}) + B_c \times d \times 4 \, (\text{V tile}) \leq M_{\text{sram}} \]

Para d=64, M_sram=64 KiB, B_r = B_c = 64 funciona: 64×64×4 + 3×64×64×4 = 16 KiB + 48 KiB = 64 KiB. ✓

El algoritmo:

for i = 0 .. (N / B_r) - 1:                # bucle externo: tiles de Q
    Q_i = Q[i*B_r:(i+1)*B_r, :]           # (B_r, d), cargar a SRAM
    O_i = zeros(B_r, d)                    # acumulador en SRAM
    m_i = full(B_r, -inf)                  # máximo corriente por fila
    ℓ_i = zeros(B_r)                       # suma corriente por fila
    for j = 0 .. (N / B_c) - 1:           # bucle interno: tiles de K, V
        K_j = K[j*B_c:(j+1)*B_c, :]       # (B_c, d), cargar a SRAM
        V_j = V[j*B_c:(j+1)*B_c, :]       # (B_c, d), cargar a SRAM
        S_ij = Q_i @ K_j^T / sqrt(d)      # (B_r, B_c), en SRAM
        m_new = max(m_i, rowmax(S_ij))
        α = exp(m_i - m_new)               # (B_r,)
        P_ij = exp(S_ij - m_new[:, None])  # (B_r, B_c)
        ℓ_i = α * ℓ_i + rowsum(P_ij)
        O_i = α[:, None] * O_i + P_ij @ V_j   # la actualización online
        m_i = m_new
    O[i*B_r:(i+1)*B_r, :] = O_i / ℓ_i[:, None]   # normalizar, escribir a HBM

La recurrencia del online softmax de la teoría 01 es exactamente lo que hay dentro del bucle interno. El tiling es lo que la envuelve.

Bytes movidos por Flash

El gran cambio: S nunca cruza HBM. Vive sólo en SRAM, calculada y consumida por paso interno (i, j).

Tráfico HBM:

  • Q se lee una vez por tile externo, total Nd elementos.
  • K, V cada uno se lee N / B_r veces (una por iteración externa), total 2 × Nd × N/B_r elementos.
  • O se escribe una vez al final de cada iteración externa, total Nd.
  • m, ℓ son despreciables (O(N) total, no ).

Total: Nd × (2 + 2N/B_r) elementos fp32 = (8 Nd × (1 + N/B_r)) bytes.

Para N=2048, d=64, B_r=64: 8 × 2048 × 64 × (1 + 32) = 8 × 2048 × 64 × 33 ≈ 33 MiB.

Compara con los 65.6 MiB del naive. Eso es 2× menos.

Espera — ¿sólo 2×? El "3× más rápido" implica más.

Dos respuestas:

  1. La contabilidad de bytes aquí es generosa con el naive. Una implementación real de PyTorch también calcula S en fp32 incluso con Q, K fp16 (por estabilidad del softmax), luego castea de vuelta. El byte-count real movido está más cerca de ~24 N² (3× nuestra estimación de 16 N²).
  2. La foto del roofline importa más que el byte-count. Incluso si los bytes fueran iguales, los tiles de Flash caben en SRAM. El techo relevante para Flash no es el bandwidth HBM — es el bandwidth SRAM (en A100, ~19 TB/s, 12× más alto que HBM). El "techo de memoria" para los kernels de Flash es una línea más alta. Mismos FLOPs / menos bytes efectivos (contando contra el techo SRAM) = mucha mayor intensidad.

El one-liner más limpio: Flash cambia bandwidth HBM por bandwidth SRAM. Los bytes totales movidos por kernel podrían ser similares, pero los bytes que cruzan la frontera lenta (HBM ↔ SRAM) son muchos menos.

Re-formulando contra el roofline de la Fase 1

De docs/phase-01-hardware-substrate/theory/03-roofline-model.md, la ecuación del roofline es perf = min(π, I × β). Dos regímenes: memory-bound (por debajo de I_crit = π/β) y compute-bound (por encima).

Para attention naive en A100: I ≈ 16 FLOPs/byte, muy por debajo de I_crit ≈ 200. Techo de rendimiento: 16 × 1.55 TB/s = 25 TFLOPS. De 312 pico — 8% de utilización.

Para Flash en A100: los bytes HBM movidos caen; el β relevante si contamos sólo tráfico HBM da I_effective ≈ 100+ FLOPs/byte. Techo de rendimiento: 100 × 1.55 TB/s = 155 TFLOPS. Mitad del pico.

El punto se movió 6× pendiente arriba. Esa es la afirmación del "3× más rápido", re-derivada desde primeros principios. (Es 6× en el roofline, pero en la práctica el speedup realizado es menor porque el kernel no puede saturar el bandwidth SRAM perfectamente y tiene otros overheads.)

Este es el argumento del roofline que Borja debería tener a flor de dedo. Cuantización (Fase 26) recorta bytes reduciendo el tamaño por elemento. Flash recorta bytes evitando materialización intermedia. Ambos empujan el punto hacia arriba.

Por qué Flash es exacto, no aproximado

Un malentendido común: "Flash es una aproximación porque procesa tiles." Falso.

La recurrencia del online softmax (teoría 01) es una identidad, no una aproximación. Cada actualización tile-a-tile produce el mismo O/ℓ final que el softmax todo-a-la-vez salvo redondeo de coma flotante. El redondeo no es peor que el naive — de hecho, a menudo ligeramente mejor, porque la reescala corriente de Flash tiende a mantener los números en un rango estrecho.

Empíricamente: O_flash - O_naive tiene error abs máximo ~1e-6 en fp32, ~1e-3 en fp16, en inputs de prueba estándar. Es el mismo orden que el redondeo que el propio Naive acumula.

Esto es lo que el umbral del DoD (1e-3 en fp16) comprueba en el lab 02.

Lo que Flash forward no hace

  1. No ayuda con la memoria de entrenamiento. Almacenar S era un coste de memoria ( por capa por batch). Flash evita ese almacenamiento. Pero para el backward, necesitamos recomputar S desde Q, K — ahorrando almacenamiento en memoria a costa de FLOPs. Fuera de alcance para esta fase (sólo forward).
  2. No acelera attention sin softmax. Linear attention, kernel attention, etc., no tienen softmax — el mecanismo de Flash no aplica directamente.
  3. No ayuda con secuencias cortas. Para N ≤ 256, la matriz (N,N) cabe en SRAM trivialmente. El overhead de tiling de Flash puede superar la ganancia. La heurística de PyTorch de usar Flash sólo para N ≥ 512 refleja esto.
  4. No optimiza attention con head dim muy grande. Para d=128, el presupuesto SRAM por tile se vuelve estrecho; B_r, B_c deben encoger, reduciendo la intensidad aritmética dentro de los tiles. Flash 2 (el paper de seguimiento) re-balancea las dimensiones del tile para d grande.

Una nota sobre Flash 1 vs Flash 2

El paper original de FlashAttention (Dao et al., 2022) tenía un bucle externo sobre los tiles de K, V y un bucle interno sobre los tiles de Q — al revés de lo que escribimos arriba. Flash 2 (Dao, 2023) los intercambió porque Q-outer reduce los FLOPs no-matmul y encaja mejor con los tensor cores de Hopper.

Para la Fase 27 implementamos Flash 2 (Q-outer), pero el contenido algebraico es idéntico. El código del kernel difiere sólo en qué bucle es el externo. El lab 02 especifica Flash 2.

Problemas de práctica

Soluciones al abrir la fase en solutions/02-flash-attention-ref.md. Razona, no ejecutes.

  1. Calcula los bytes HBM movidos por Flash para N=8192, d=128, B_r=64, B_c=64 en fp16. Compara con los bytes HBM del naive. Enuncia la ratio de speedup (puramente desde bytes).
  2. El presupuesto SRAM en Hopper H100 es ~228 KB por SM. Elige B_r, B_c para d=128 fp16 tal que 4 tiles (Q, K, V, S) quepan. ¿Cuál es el máximo B_r × B_c que puedes permitirte?
  3. Sliding window attention con ventana W=512 en una secuencia de N=8192. ¿Cómo cambia el bucle interno? ¿Cuántos tiles de K, V necesita leer cada tile de Q? Compara con Flash denso.
  4. Muestra que para B_c → 1 (una fila K/V por tile interno), Flash degenera a calcular softmax en serie sobre N términos con máximo/suma corrientes. ¿Por qué no es esto útil? (Pista: piensa en utilización de tensor cores.)

Recap de un párrafo

FlashAttention tilea Q, K, V en bloques residentes en SRAM y usa la recurrencia del online softmax para evitar materializar la matriz (N, N) S = QKᵀ en HBM. Los FLOPs son idénticos a la attention naive. Los bytes que cruzan la frontera lenta HBM↔SRAM son 2–10× menos (dependiendo de los tamaños de tile y la head dim), y el working set por tile vive en un techo de bandwidth mucho más rápido (SRAM). En el roofline de la Fase 1, esto se traduce en un punto mucho más cerca del techo de cómputo — la fuente del speedup de reloj de 3–10× de Flash. El algoritmo es exacto salvo redondeo de coma flotante; no es una aproximación. El siguiente archivo de teoría extiende el framing del byte-count a PagedAttention (una capa distinta del stack — KV cache, no el kernel mismo).

Siguiente: theory/03-paged-and-sliding.md.