English · Español
03 — El decode está limitado por memoria (y por qué eso reconfigura todo)¶
🇪🇸 La intensidad aritmética del paso de decode con caché es ~0.5 FLOPs/byte. Eso lo sitúa muy por debajo del machine balance de cualquier GPU moderna (típicamente 50–300 FLOPs/byte). Conclusión: el cuello de botella no es cómputo, es leer la caché de memoria. Todo lo demás que oirás sobre inferencia eficiente sigue de ahí.
Esta página hace por el paso de decode lo que phase-01/theory/03-roofline-model.md hizo por el matmul: sitúa el operador en el roofline y lee el régimen.
Planteamiento¶
Un solo paso de decode, una sola secuencia, una sola capa, las \(H\) cabezas a la vez. El caché tiene actualmente longitud \(S\).
El trabajo de atención por cabeza es:
- \(q \in \mathbb{R}^{1 \times d_h}\) (query del token actual para esta cabeza).
- \(K \in \mathbb{R}^{S \times d_h}\), \(V \in \mathbb{R}^{S \times d_h}\) (claves y valores cacheados para esta cabeza).
- \(\text{scores} = q K^\top / \sqrt{d_h} \in \mathbb{R}^{1 \times S}\).
- \(\text{weights} = \text{softmax}(\text{scores}) \in \mathbb{R}^{1 \times S}\).
- \(\text{out} = \text{weights} \cdot V \in \mathbb{R}^{1 \times d_h}\).
Repetir para las \(H\) cabezas independientemente.
FLOPs¶
Por cabeza: - \(q K^\top\): \(S \cdot d_h\) multiplicaciones + sumas = \(2 S d_h\) FLOPs. - softmax: ~\(5 S\) FLOPs (exp + suma + división + restar máximo), pero factor constante pequeño. Ignorar para asintótica. - \(\text{weights} \cdot V\): \(2 S d_h\) FLOPs.
Por cabeza, por capa, por paso: \(\approx 4 S d_h\) FLOPs (redondeando escalado y overhead del softmax).
Todas las cabezas, por capa, por paso: \(4 S \cdot d_h \cdot H = 4 S d\) FLOPs.
Todas las capas: \(4 L S d\) FLOPs.
Bytes movidos¶
Este es el cálculo crítico. Por cabeza, el operador debe leer:
- \(q\): \(d_h\) elementos. Diminuto.
- \(K\): \(S \cdot d_h\) elementos. Lineal en \(S\).
- \(V\): \(S \cdot d_h\) elementos. Lineal en \(S\).
- Escrituras: \(\text{out}\) con \(d_h\), más \(k_\text{new}, v_\text{new}\) con \(d_h\) cada uno. Diminuto.
Lecturas por cabeza: \(\approx 2 S d_h\) elementos, dominadas por K y V.
A lo largo de \(H\) cabezas en una capa: \(2 S d_h H = 2 S d\) elementos.
A lo largo de \(L\) capas: \(2 L S d\) elementos. A \(s\) bytes/elemento: \(\boxed{B_\text{decode-cache} = 2 L S d s}\) bytes.
(Ignoramos aquí los bytes de los pesos porque ya los contabilizamos en 01-prefill-vs-decode.md y dominan un operador distinto — el FFN. Aquí nos centramos en el operador de atención solo, para aislar por qué el caché es el cuello de botella.)
Intensidad aritmética¶
| dtype | \(s\) | \(I_\text{attn-decode}\) |
|---|---|---|
| fp32 | 4 | 0.5 FLOPs/byte |
| fp16 | 2 | 1.0 FLOPs/byte |
| int8 | 1 | 2.0 FLOPs/byte |
Fíjate en lo que no está en la fórmula. No hay \(S\). No hay \(L\). No hay \(d\). La intensidad aritmética del operador de atención de decode es una constante — determinada solo por el dtype del caché.
Esa constante es 0.5–2.0. Para comparar, una A100 tiene machine balance \(\pi / \beta \approx 312 / 2 \approx 156\) FLOPs/byte. Una H100 ronda 280. Un i5-8250U (números de la Fase 1) ronda 10.
En cualquier GPU moderna, la atención de decode corre al 0.3–1% del pico de FLOPS, con las FPUs ociosas esperando HBM. En el i5-8250U, corre a quizá el 5% del pico. El decoder está limitado por memoria en todas partes.
El hecho compañero de "lectura de pesos"¶
El operador de atención de arriba es solo la parte del caché. Pero cada paso de decode también relee los pesos del modelo para hacer el FFN, la proyección QKV, y la proyección de salida. Las matrices de pesos son las mismas a lo largo de todas las posiciones de secuencia:
- Proyección QKV: \(3 d^2\) elementos por capa.
- FFN: \(8 d^2\) elementos por capa.
- Proyección de salida: \(d^2\) elementos por capa.
Por capa por paso: \(\approx 12 d^2\) elementos de pesos. A lo largo de \(L\) capas: \(12 L d^2\) elementos. A \(s\) bytes/elemento: \(12 L d^2 s\) bytes — por paso de decode.
Para Llama-2-7B, fp16: \(12 \cdot 32 \cdot 4096^2 \cdot 2 \approx 12.9 \cdot 10^9\) bytes ≈ 13 GiB leídos por token.
En una A100 con 2 TB/s de HBM, eso son \(13 / 2000 \approx 6.5\) ms de tiempo de memoria puro por token, cota inferior. Los números reales son 10–15 ms/token — cerca de la cota. El modelo está decodificando tan rápido como el HBM puede escupir pesos, sin importar cuántas FPUs estén sentadas alrededor.
El diagrama del roofline del decode¶
Dibuja el operador de atención de decode en un roofline de GPU:
GFLOPS (log)
^
π ┤ ╭───────────── ← compute ceiling (e.g. 312 TFLOPS A100)
│ ╱
│ ╱
│ ╱
│ ╱
│ • decode attn (fp16) ╱ ← machine balance at ~156 FLOPs/byte
│ I=1.0, far left ╱
│ ╱
│ • decode attn (fp32)╱
│ I=0.5 ╱
│ ↓ ╱
│ on the slope ╱
└─────────────────┴─────────────────────────→
I_crit arithmetic intensity (log)
Los dos puntos de atención de decode están profundamente en la pendiente memory-bound, muy por debajo de la esquina. Para moverlos a la derecha (subir la intensidad):
- Cuantizar el caché. fp16 → int8 duplica la intensidad. int8 → int4 la duplica otra vez. (Fase 26.)
- Reducir filas del caché. GQA: compartir K, V entre grupos de cabezas, reduciendo los bytes de \(K, V\) leídos por cabeza de query. La matemática cambia de \(2 S d\) a \(2 S d \cdot (H_{KV} / H)\). (Fase 27.)
- Reducir columnas del caché. Atención de ventana deslizante: leer solo los últimos \(W \ll S\) tokens. Los bytes caen a \(2 L W d s\), constantes en \(S\) una vez \(S > W\). (Fase 27.)
- Reestructurar para leer el caché una vez por grupo de queries. Speculative decoding: \(k\) tokens candidatos validados en paralelo usan la misma lectura de caché. La intensidad efectiva sube \(k\) veces. (Fase 36.)
- Reestructurar para fusionar softmax con mat-mul (Flash-decoding). Hacer streaming del caché por SRAM, nunca materializar el softmax intermedio. La intensidad no cambia pero \(\beta\) sube efectivamente porque te quedas en SRAM más tiempo. (Fase 24.)
Cada una de ellas es un ataque distinto al mismo diagnóstico: el punto está demasiado a la izquierda en el roofline. Relee esta lista cuando leas los nombres de las optimizaciones en fases posteriores — ninguna es misteriosa una vez las ves como movimientos en este gráfico.
La vía de escape del batching¶
El batching tiene un papel especial en la historia de la memoria del decode.
Imagina que \(B\) usuarios decodifican concurrentemente. Los pesos del modelo se comparten entre usuarios — leídos una vez por paso, usados por \(B\) queries. El caché es por usuario (sigue siendo \(2 L S d s\) por usuario; total \(2 L B S d s\)).
Bytes de pesos por paso: \(12 L d^2 s\) (sin cambios — leídos una vez, aplicados a \(B\) queries).
FLOPs totales por paso (el FFN domina): \(\approx 24 L B d^2\).
Intensidad de pesos: \(24 L B d^2 / (12 L d^2 s) = 2 B / s\).
Así que agrupar por \(B\) eleva la intensidad aritmética del FFN por un factor de \(B\). A \(B = 16\) fp16: \(I = 16\). A \(B = 64\) fp16: \(I = 64\). Eso lleva el punto del FFN hacia la esquina. El punto de cache-attention, sin embargo, no se ve afectado por el batching — cada usuario lee su propio caché.
Por eso los sistemas de serving empujan duro los tamaños de batch (el FFN es el gasto pesado y se beneficia) pero PagedAttention es necesario (el caché por usuario es lo que bloquea el batching con \(B\) alto con secuencias de longitud variable).
Problemas de práctica¶
- Roofline del i5-8250U de la Fase 1 (de
docs/phase-01.../theory/03-roofline-model.md): \(\pi \approx 200\) GFLOPS, \(\beta \approx 20\) GB/s. Atención de decode en fp32: \(I = 0.5\). ¿Rendimiento alcanzable? ¿En fp16 (\(I = 1.0\))? - En una H100 (\(\pi \approx 990\) TFLOPS fp16, \(\beta \approx 3.35\) TB/s): ¿machine balance? ¿Dónde se sitúa decode-attention-fp16? ¿Qué fracción del pico?
- ¿Por qué bajar \(d_h\) (cabezas más pequeñas) no ayuda a la intensidad? (Pista: tanto FLOPs como bytes escalan linealmente en \(d_h\) — el cociente es invariante.)
- El MiniGPT de gramática decodifica
"Yesterday I worked"en fp32 en el i5-8250U de Borja. Predice la latencia por token de la atención de decode. A nuestra escala esto es ruido inmedible — el overhead del intérprete Python lo eclipsa. Explica por qué la fórmula sigue siendo la herramienta correcta: ¿cómo extrapolas de mediciones a escala MiniGPT para predecir la latencia de decode de Llama-2-7B?
Por qué esta página es la página para el resto del currículo¶
La Fase 22 es la primera vez que medirás un operador que resulta estar limitado por memoria. No será la última. De la Fase 22 en adelante, cada vez que alguien proponga "acelerar la inferencia", la pregunta que hacer es:
¿Este operador está actualmente limitado por cómputo o por memoria? Si por memoria, ¿qué le hace esta optimización a su intensidad aritmética?
Esa pregunta es a veces vergonzosamente esclarecedora. Muchas "kernel fusions" propuestas resultan ser no-ops porque no cambian los bytes movidos. Muchos "wins de cuantización" resultan componer exactamente porque suben directamente la intensidad. Una vez afilado el diagnóstico, el campo se vuelve legible.
Lo que esta página NO cubre¶
- El reordenamiento HBM↔SRAM que Flash-Attention/Flash-Decoding hace realmente. Nombrado aquí; el mecanismo es Fase 24 / 27.
- El asignador de bloques de PagedAttention. Esbozado en
theory/04-toward-paged-attention.md; derivado en la Fase 27. - Aritmética del caché cuantizado. Fase 26.
- Matemática de speculative decoding. Fase 36.
- Medición a nivel CUDA. Esta página argumenta desde la intensidad; la Fase 24 mide desde kernels.
Siguiente: theory/04-toward-paged-attention.md — una vista previa corta de por qué los cachés densos pre-asignados se rompen bajo workloads de serving realistas.