Skip to content

English · Español

00 — Por qué la atención domina la inferencia

🇪🇸 La atención no es el problema porque sea matemáticamente cara — lo es porque la matriz S = QKᵀ es enorme y se materializa en HBM. Flash no cambia la matemática; cambia qué bytes cruzan la barrera de memoria. Este archivo prepara el argumento del roofline para los siguientes tres.


El desglose del tiempo de reloj de la inferencia

Un paso forward de un transformer para un token, en un modelo con L capas, tamaño oculto h, dimensión de cabeza d y longitud de contexto N, hace aproximadamente:

  • L × MLP: dos operaciones Linear de tamaño (h, 4h) y (4h, h). FLOPs por capa = 2 × (2 × h × 4h) = 16 h². Bytes movidos = 2 × (h × 4h × 4) = 32 h² (pesos fp32).
  • L × Attention: ver abajo.
  • L × LayerNorm: O(h) — despreciable.

Para un modelo pequeño típico (L=12, h=768, d=64, N=2048):

  • FLOPs de MLP por capa = 16 × 768² ≈ 9.4M. A lo largo de L=12 capas: ~113M.
  • FLOPs de attention por capa: 4 N² d ≈ 1.05G. A lo largo de 12 capas: ~12.6G.

La attention es 100× más compute que MLP por capa a esta longitud de contexto. Y escala mientras MLP escala N (para un token nuevo; para prefill completo es N-ish para MLP, para attention de todos modos).

Así que la attention domina el compute. Pero aquí está la clave: también domina el tráfico de memoria, aún más.

Por qué la attention es bandwidth-bound

La attention naive materializa la matriz S = Q K^T de forma (N, N). Para N = 2048, son 4 × 2048² = 16 MiB en fp32 (u 8 MiB en fp16). Esta matriz se:

  1. Escribe a HBM (tras Q K^T).
  2. Lee de HBM (para softmax).
  3. Escribe a HBM (la matriz normalizada por softmax).
  4. Lee de HBM (para S @ V).

Tráfico total a HBM sólo en S: 64 MiB por capa en N=2048, fp32. Es el tamaño de una caché L2; el L2 de la GPU podría ser 40 MiB en una A100. Lo desbordamos.

Calcula la intensidad aritmética:

  • FLOPs: 4 N² d = 4 × 2048² × 64 = 1.07G FLOPs.
  • Bytes movidos (HBM): N² × 4 (escribir S) + N² × 4 (leer S para softmax) + N² × 4 (escribir softmax(S)) + N² × 4 (leer para @V) + términos pequeños para Q, K, V, O. Total ≈ 16 N² bytes ≈ 64 MiB.
  • Intensidad: 1.07e9 / 6.7e7 ≈ 16 FLOPs/byte.

El I_crit de una A100 (ratio compute-vs-ancho-de-banda-HBM): I_crit ≈ 312 TFLOPS / 1.55 TB/s ≈ 200 FLOPs/byte. La attention naive se sitúa en ~16 FLOPs/byte — >10× por debajo de la esquina. Casi toda la GPU está ociosa, esperando a HBM.

Esto es lo que FlashAttention resuelve. No cambiando los FLOPs (no lo hace), sino cambiando los bytes movidos (sí lo hace).

Lo que Flash hace de verdad, en un párrafo

FlashAttention particiona Q, K, V en tiles. Las dimensiones del tile interno se dimensionan para que el estado intermedio de un tile — S_tile de forma (B_r, B_c), más vectores corrientes de máximo y suma — quepa en SRAM on-chip (unos pocos KiB a ~100 KiB dependiendo de la GPU). La gran matriz (N, N) nunca se escribe a HBM. Sólo el bloque (B_r, d) por tile de salida y pequeñas estadísticas corrientes (B_r,) fluyen entre HBM y SRAM.

El truco matemático que permite que esto funcione es el online softmax: una recurrencia que te permite calcular softmax(S) @ V incrementalmente conforme llegan nuevos tiles de S, sin necesitar la fila completa de S primero. El archivo de teoría 01 deriva esto; el archivo de teoría 02 lo pone dentro del bucle de tiling.

La ganancia de intensidad: bytes movidos cae de ~16 N² a aproximadamente ~N · d · (3 + 2 N/B_c) (fp32). Para N=2048, d=64, B_c=64: bytes ≈ 2048 × 64 × 67 × 4 ≈ 33 MiB. Intensidad ≈ 1.07e9 / 3.5e7 ≈ 30 FLOPs/byte. ~2× más alta, y el ratio crece con N. Aún por debajo de la esquina de 200 FLOPs/byte de la A100, pero ahora estamos en la parte empinada de la pendiente del techo de memoria, no en su pie.

Esta es la razón por la que Flash es rápido.

Re-formulando el clásico "3× speedup"

El paper de FlashAttention reportó ~3× speedup de reloj sobre una baseline de PyTorch optimizada en N=2048 en A100. El número es específico al hardware y al ajuste, pero el mecanismo es la reducción del byte-count que acabamos de calcular. "Flash es 3× más rápido" debería leerse siempre en la sala de un ingeniero que sabe que realmente significa "Flash sube la intensidad aritmética manteniendo el working set en SRAM, así que el kernel se mueve más cerca del techo de cómputo en vez de arrastrarse por el techo de memoria".

Si Borja se lleva una frase de esta fase: los mismos FLOPs a mayor intensidad es todo el juego. La cuantización (Fase 26) ataca la intensidad por el lado de los bytes (pesos más pequeños). Flash (esta fase) la ataca por el lado algorítmico (no materializar lo que no necesitas). Ambos son argumentos del roofline; ambos son reales.

PagedAttention es un problema distinto

PagedAttention (vLLM) no es una optimización de kernel — es una optimización de asignador de memoria. El KV cache (almacenando todos los vectores K y V pasados para la generación autoregresiva) es enorme y crece por token. Un modelo de contexto largo con batch=32, N=8192, capas=32, cabezas=32, head_dim=128 tiene tamaño de KV cache = 2 × 32 × 8192 × 32 × 32 × 128 × 2 (fp16) ≈ 17 GiB. Asignar esto contiguamente por petición lleva a fragmentación masiva entre miembros del batch — como un SO que hace malloc y nunca free.

PagedAttention trata el KV como memoria virtual: pequeños bloques de tamaño fijo (p. ej., 16 tokens de KV por página), una tabla de páginas por petición, copy-on-write para cachear prefijos. El propio kernel de attention se modifica para seguir las indirecciones de la tabla de páginas en vez de acceder a un K, V plano.

Donde Flash ataca el tráfico HBM por kernel, PagedAttention ataca la utilización de memoria entre peticiones. Estos componen: un servidor de inferencia desplegado usa ambos.

Cubrimos PagedAttention como ejercicio de lectura (teoría 03, lab 03) porque re-implementarlo es un trabajo de ingeniería de servidor que distrae de la historia del kernel. La lectura anotada de vLLM es suficiente.

Otras tres variantes de attention en esta fase

  • Sliding window attention. Mistral et al. usan una ventana de contexto de ancho fijo: cada token atiende sólo a los últimos W < N tokens. Reduce la complejidad de a N·W. Componible con Flash (el kernel sólo enmascara posiciones fuera de la ventana).
  • Grouped/Multi-Query Attention (GQA/MQA). Comparten K, V entre múltiples cabezas de query. Reducen el tamaño del KV cache en n_kv_groups / n_heads (típicamente 4×–8×). No reducen el compute por token, pero reducen radicalmente el tráfico de memoria por token durante el decode autoregresivo.
  • Multi-Latent Attention (MLA, DeepSeek). Comprime K y V en un espacio latente de bajo rango; reconstruye K/V al vuelo por llamada de attention. Cambia un pequeño extra de compute (la proyección) por un KV cache mucho más pequeño.

Los tres son argumentos del roofline. La teoría 04 recorre cada uno.

Lo que esta fase no intenta

No derivamos Flash backward. El paso backward usa recomputación: re-deriva S desde Q y K al vuelo durante el cómputo del gradiente, cambiando FLOPs por memoria. El online softmax del camino forward no ayuda directamente — el backward necesita una regla de actualización distinta. Fuera de alcance para la Fase 27; volverá en una fase futura.

Recap de un párrafo

La attention domina la inferencia del transformer tanto en FLOPs como en tráfico de memoria. La attention naive se sitúa ~50× por debajo de la esquina del roofline de la GPU porque materializar la matriz (N,N) S = QKᵀ desborda HBM. FlashAttention particiona Q/K/V en tiles residentes en SRAM y usa una recurrencia de online softmax para evitar materializar S, recortando los bytes movidos en un orden de magnitud y elevando el punto 5–30× hacia el techo de cómputo. PagedAttention ataca un cuello de botella distinto — fragmentación del KV cache entre peticiones en un servidor. GQA/MQA/MLA reducen el propio KV cache. Los archivos de teoría restantes derivan cada idea; los labs implementan Flash forward en Triton y leen PagedAttention en vLLM.

Siguiente: theory/01-online-softmax.md.