English · Español
01 — Prefill vs Decode: dos fases de un mismo forward¶
🇪🇸 Inferencia autorregresiva = una pasada paralela inicial (prefill) sobre el prompt entero, seguida de muchas pasadas seriales (decode), una por token nuevo. La asimetría entre las dos es lo que organiza todo el resto de la infraestructura de servir LLMs.
Esta página formaliza qué son realmente el "prefill" y el "decode" a nivel de operador — mismo modelo, distinta forma de cómputo. Una vez nombrada la forma, los costes siguen algebraicamente.
Planteamiento¶
Fija un transformer causal con \(L\) capas, \(H\) cabezas de atención por capa, dim de cabeza \(d_h\), dim del modelo \(d = H \cdot d_h\), dim oculta del FFN \(d_\text{ffn} \approx 4d\), vocab size \(V\), tamaño de batch \(B\).
Para el MiniGPT de gramática (corpus §A13): \(L = 4\), \(H = 4\), \(d_h = 16\), \(d = 64\), \(d_\text{ffn} = 256\), \(V \approx 600\), \(B = 1\). Números triviales — cada fórmula de esta página puede evaluarse en una servilleta.
Un usuario envía un prompt de longitud \(P\). El sistema debe producir \(D\) tokens nuevos. Secuencia total al final: \(S = P + D\). Para el ejemplo recurrente "Yesterday I worked and he": \(P = 2\) (prefill "Yesterday I"), \(D = 3\) ("worked", "and", "he"), \(S = 5\).
Fase de prefill¶
Entrada: \(P\) tokens (el prompt). Salida: el hidden state del último token del prompt (usado para muestrear el primer token nuevo) y un KV cache poblado de longitud \(P\) en cada capa.
Por capa, el prefill calcula:
- \(X \in \mathbb{R}^{P \times d}\) — entradas embebidas.
- \(Q = X W_Q\), \(K = X W_K\), \(V = X W_V\). Cada una es \(P \times d\). Proyecciones lineales: \(P \cdot d^2\) FLOPs cada una, tres veces. Total: \(3 P d^2\) FLOPs.
- Atención: \(A = \text{softmax}(QK^\top / \sqrt{d_h} + M) V\), donde \(M\) es la máscara causal. El \(QK^\top\) es \(P \times P\) por cabeza; \(P \cdot d_h \cdot P\) FLOPs por cabeza; \(H\) cabezas \(\to\) \(P^2 \cdot d\) FLOPs. El \(\cdot V\) son otros \(P^2 \cdot d\). Atención total: \(2 P^2 d\) FLOPs.
- Proyección de salida \(A W_O\): \(P \cdot d^2\) FLOPs.
- FFN: dos matmuls \((P \times d) \cdot (d \times d_\text{ffn}) \cdot (d_\text{ffn} \times d)\). Con \(d_\text{ffn} = 4d\): \(8 P d^2\) FLOPs.
Total por capa: \(3 P d^2 + 2 P^2 d + P d^2 + 8 P d^2 = 12 P d^2 + 2 P^2 d\).
A lo largo de \(L\) capas: \(\boxed{F_\text{prefill} \approx 12 L P d^2 + 2 L P^2 d}\).
Dos regímenes: - Prompt corto (\(P \ll d\)): dominado por \(12 L P d^2\) — lineal en \(P\), cuadrático en \(d\). Matmul-bound. - Prompt largo (\(P \gtrsim d\)): el término de atención \(2 L P^2 d\) domina. Cuadrático en \(P\).
Para Llama-2-7B (\(L=32\), \(d=4096\)): los dos términos se cruzan en \(P \approx 6d = 24576\). Por debajo, los FFN son la mayoría del coste; por encima, lo es la atención. Para un prompt de 4096 tokens, FFN ~85%, atención ~15%. Para nuestro MiniGPT de gramática (\(L=4\), \(d=64\), típico \(P=2\)): \(12 \cdot 4 \cdot 2 \cdot 64^2 = 393\) K FLOPs — es decir, nada. El prefill de "Yesterday I" lleva microsegundos. Que la maquinaria del caché sea correcta sigue importando: el mismo código corre con 175 B de parámetros.
Tráfico de memoria. El prefill lee los pesos (\(\sim 12 L d^2\) bytes) una vez, lee \(X\) una vez. Las activaciones \(Q, K, V, A\) son \(P \cdot d\) cada una y viven en caché o se derraman a DRAM. Intensidad aritmética = \(F_\text{prefill}\) / bytes_movidos \(\approx P\) — crece con \(P\). El prefill está limitado por cómputo para cualquier prompt no trivial.
Fase de decode¶
Entrada en el paso \(t\): un token nuevo, más el caché poblado de los pasos \(1..t-1\). Longitud actual del caché: \(S = P + t - 1\). Salida: el hidden state del nuevo token, y el caché extendido a longitud \(S+1\).
Por capa, por paso de decode, calculamos:
- \(x \in \mathbb{R}^{1 \times d}\) — el nuevo token embebido.
- \(q = x W_Q\), \(k_\text{new} = x W_K\), \(v_\text{new} = x W_V\). Cada uno es \(1 \times d\). Coste: \(3 d^2\) FLOPs por capa.
- Añadir \(k_\text{new}\) y \(v_\text{new}\) al caché. El caché se convierte en \(K \in \mathbb{R}^{(S+1) \times d}\), \(V \in \mathbb{R}^{(S+1) \times d}\).
- Atención: \(a = \text{softmax}(q K^\top / \sqrt{d_h}) V\). Aquí \(q\) es \(1 \times d_h\) por cabeza; \(K\) es \((S+1) \times d_h\) por cabeza. El \(qK^\top\) es \(1 \times (S+1)\) por cabeza; \(d_h \cdot (S+1)\) FLOPs por cabeza; \(H\) cabezas \(\to (S+1) \cdot d\) FLOPs. El \(\cdot V\) añade otros \((S+1) \cdot d\). No hace falta máscara — la query es de longitud 1, el caché solo contiene tokens pasados (y el actual que se acaba de añadir, lo cual está bien: es la diagonal de la máscara causal). Atención total: \(2(S+1) d\) FLOPs.
- Proyección de salida: \(d^2\) FLOPs.
- FFN: \(8 d^2\) FLOPs.
Por capa por paso: \(12 d^2 + 2(S+1) d\).
A lo largo de \(L\) capas, por paso de decode: \(\boxed{F_\text{decode-step} \approx 12 L d^2 + 2 L S d}\).
Sumando sobre \(D\) pasos de decode, con \(S\) creciendo de \(P\) a \(P+D-1\):
Para generación larga (\(D \gg P\), \(D \gg d\)): el término \(L D^2 d\) domina. Cuadrático en tokens generados, incluso con el caché. El caché convirtió cúbico (\(\Theta(D^3)\) sin caché) en cuadrático (\(\Theta(D^2)\) con caché).
El desastre sin caché (sanity check)¶
Sin el caché, cada paso de decode re-ejecuta el prefill sobre \(S\) tokens. El paso \(t\) cuesta \(F_\text{prefill}(S=P+t)\) ≈ \(12 L (P+t) d^2 + 2 L (P+t)^2 d\). Sumando:
Eso es \(\Theta(D^3)\) en el régimen de generación larga — cúbico en tokens generados. El caché te compra un factor de \(D\). Para \(D = 1000\), son 1000×. Esto no es una optimización menor; es la diferencia entre "inferencia usable" y "ninguna inferencia".
Asimetría de tráfico de memoria¶
Esta es la tabla más importante de la Fase 22. Memorízala.
| Magnitud | Prefill | Decode (por paso) |
|---|---|---|
| FLOPs | \(12 L P d^2 + 2 L P^2 d\) | \(12 L d^2 + 2 L S d\) |
| Bytes movidos (pesos) | \(\sim 12 L d^2\) (una vez) | \(\sim 12 L d^2\) (¡cada paso!) |
| Bytes movidos (caché) | ninguno (caché vacío al inicio) | \(\sim 2 L S d \cdot s\) por paso |
| Intensidad aritmética | \(\sim P\) (crece con el prompt) | \(\sim O(1)\) — en realidad ~0.5 FLOPs/byte |
| Cuello de botella | Cómputo (FLOPS) | Ancho de banda de memoria (BW) |
Mira la fila de pesos. El decode relee toda la matriz de pesos en cada paso, aunque solo haga \(O(d^2)\) FLOPs contra ella (no \(O(P d^2)\)). Eso fija la intensidad aritmética del decode para las capas FFN en ~1 (1 FLOP por byte cargado) — independiente del tamaño del modelo o de la longitud de contexto.
Por eso un modelo de 70 B parámetros puede generar a quizá 10 tokens/s en una A100, mientras los 312 TFLOPS nominales de la GPU están al ~99% ociosos. Toda la jerarquía de memoria es el limitador de tasa. (El MiniGPT de gramática decodificando "Yesterday I worked" también está limitado por memoria, pero a una escala donde no puedes ver el cuello de botella — los cachés caben todos en L1. El experimento de curva de coste del DoD te obliga a extrapolar a un régimen donde sí importaría.)
También por eso el batching ayuda al decode dramáticamente. Si 16 usuarios decodifican concurrentemente contra el mismo modelo, la lectura de pesos se amortiza 16 veces — la GPU lee la matriz FFN una vez por paso y la aplica a 16 queries. La intensidad aritmética pasa de 1 a ~16. Este es el modelo de negocio completo de la infraestructura de serving de LLMs (Fase 28).
Lo que la dicotomía te dice sobre cada optimización que encontrarás¶
Úsalo como anillo decodificador:
- Flash-Attention — reestructura la atención del prefill para evitar materializar la matriz \(P \times P\). Gana en prefill con \(P\) largo. No ayuda al decode (no hay \(P \times P\) que evitar).
- Flash-Decoding — reestructura la atención del decode para paralelizar mejor el paso fila-contra-caché. Gana en decode con \(S\) largo. No ayuda al prefill.
- PagedAttention — reestructura el layout del caché para manejar serving por lotes de longitud variable. Afecta al layout de memoria en decode. Gana al servir a muchos usuarios con longitudes de secuencia distintas.
- Continuous batching — desacopla el prefill por-usuario de un decode batch compartido. Reorganiza el schedule, no la matemática.
- GQA / MQA — reduce \(H\) (o comparte K, V entre cabezas). Encoge linealmente el caché. Linealmente mejora el decode (menos que leer), apenas toca el prefill (ahí está compute-bound de todos modos).
- Cuantización (caché int8) — divide \(s\) por la mitad. Encoge linealmente la memoria del caché y el tráfico de memoria del decode. Toca al prefill solo en la medida en que reduce los bytes de pesos.
- Speculative decoding — hace varios pasos de decode "gratis" validando suposiciones de un modelo barato en paralelo. Reorganiza el decode para que se parezca más al prefill (paralelo en vez de serial).
Cada una de ellas lleva el nombre de un síntoma. Los síntomas son entradas de la tabla de arriba.
Pseudo-pseudocódigo para la Fase 22¶
def generate(prompt, max_new_tokens):
# Prefill: parallel pass over the prompt.
cache = KVCache.allocate(...)
h = embed(prompt)
for layer in layers:
q, k, v = layer.qkv(h)
cache.append(layer_idx, k, v) # cache filled to length P
h = layer.attention(q, k, v, causal_mask=True)
h = layer.ffn(h)
next_token = sample(h[-1]) # h[-1] is the last prompt token
# Decode: one new token at a time.
for step in range(max_new_tokens):
h = embed(next_token).unsqueeze(0) # shape: (1, d)
for layer in layers:
q, k_new, v_new = layer.qkv(h)
cache.append(layer_idx, k_new, v_new) # cache grows by 1
K, V = cache.read(layer_idx) # the full cached K, V
h = layer.attention(q, K, V, causal_mask=False) # q is len 1; no mask needed
h = layer.ffn(h)
next_token = sample(h[0])
yield next_token
Esto es esencialmente lo que Borja implementará en lab/01-implement-cache.md. Los dos bucles for layer in layers son el mismo código — solo difieren en la forma de \(Q\). La capa de caché es lo que hace que ese cambio de forma funcione sin coste cuadrático por paso.
Lo que esta página NO cubre¶
- Derivación de bytes-de-caché. Esbozada en la tabla; derivación completa en
theory/02-memory-cost.md. - Argumento de intensidad aritmética para el decode. Afirmado como ~0.5 FLOPs/byte; derivado en
theory/03-decode-as-memory-bound.md. - Especificidades de GPU. Las fórmulas de \(F\) y bytes son independientes del hardware. La Fase 23 las mapea sobre SMs, warps, HBM.
- Pasada hacia atrás / entrenamiento. El caché es solo para decode. El entrenamiento usa atención en forma de prefill completo en cada paso.
Siguiente: theory/02-memory-cost.md — derivar la fórmula de bytes y aplicarla al MiniGPT de gramática, Llama-2-7B y GPT-3.