Skip to content

English · Español

03 — Modelo de coste del decode: por qué necesitamos la KV cache

🇪🇸 La cuenta de coste del decode sin caché es brutal: cada token cuesta más que el anterior porque el modelo recomputa la atención sobre todo el prefijo. Este capítulo cuantifica el desperdicio y motiva la KV cache de Fase 22.

El coste pelado

Generar \(T\) tokens dado un prompt de longitud \(L\):

Paso Longitud que ve el forward Coste del forward
Generar token 0 \(L\) \(O(L \cdot d^2)\)
Generar token 1 \(L + 1\) \(O((L+1) \cdot d^2)\)
Generar token 2 \(L + 2\) \(O((L+2) \cdot d^2)\)
... ... ...
Generar token \(T-1\) \(L + T - 1\) \(O((L+T-1) \cdot d^2)\)

Total: \(\sum_{t=0}^{T-1} O((L+t) \cdot d^2) = O\big(T \cdot L \cdot d^2 + T^2 \cdot d^2 / 2\big)\).

Para nuestro Mini-GPT (\(d = 64, L = 8, T = 20\)):

  • Término lineal: \(20 \cdot 8 \cdot 64^2 = 655{,}360\) ops.
  • Término cuadrático: \(20^2 \cdot 64^2 / 2 = 819{,}200\) ops.

Los dos son comparables aquí. A mayor escala (p. ej., \(L = 1024, T = 200, d = 4096\)), el término cuadrático domina por órdenes de magnitud.

Dónde está el trabajo redundante

Mira el forward del Mini-GPT (Fase 17). En cada bloque transformer, la atención calcula:

\[Q, K, V = X W_Q, X W_K, X W_V \quad \text{shape } (T, d)$$ $$A = \text{softmax}(QK^\top / \sqrt{d_h}) \quad \text{shape } (T, T)$$ $$\text{output} = A V \quad \text{shape } (T, d)\]

Cuando generamos el token \(t+1\):

  • Pasamos la secuencia completa de longitud \(L + t + 1\) como entrada.
  • Calculamos \(Q, K, V\) para las \(L + t + 1\) posiciones.
  • Pero: \(Q, K, V\) para las posiciones \(0, 1, \ldots, L + t - 1\) son los mismos que en el paso anterior. Solo la posición \(L + t\) es nueva.

Así que en el paso \(t+1\) recomputamos \(K, V\) para \(L + t\) posiciones que ya calculamos en el paso \(t\). Este es el desperdicio que queremos eliminar.

La idea de la KV cache (referencia adelantada)

La Fase 22 implementará esto: almacenar \(K^{(l)}, V^{(l)}\) para cada capa \(l\) y cada posición vista hasta ahora. En el paso \(t+1\), solo computar los \(K, V\) para la nueva posición y concatenar a la cache.

# After Phase 22:
def decode_step(model, last_token, cache: KVCache) -> tuple[int, KVCache]:
    """Process one new token, updating the cache."""
    logits, new_cache = model.forward_one(last_token, cache)
    next_token = sample(logits[-1])
    return next_token, new_cache

El coste por paso pasa de \(O((L + t) \cdot d^2)\) a \(O(d^2)\)constante en \(t\) (ignorando el escaneo lineal-en-longitud-cacheada de la atención, mucho más barato que el matmul). El coste total pasa a \(O((L + T) \cdot d^2)\).

Para nuestro ejemplo (\(L = 8, T = 20, d = 64\)): \(28 \cdot 4096 = 114{,}688\) ops. Eso es ~13× más barato que los \(1{,}474{,}560\) pelados de arriba.

Por qué la Fase 21 no usa la cache

Razones pedagógicas:

  1. Sentir el coste. El benchmark de decode del lab 03 sobre \(T = 50\) tokens será visiblemente lento. Deberías ver cómo sube la curva \(t \cdot\)coste.
  2. Limpieza. Sin la cache, el bucle de decode tiene la misma forma que el forward de entrenamiento. Sin nueva maquinaria de estado. La Fase 22 introduce el estado y lo testea contra la referencia sin cache.
  3. Corrección primero. La cache es una optimización; debe producir salidas idénticas bit a bit al decode sin cache (módulo reordenamiento de punto flotante). El primer test de la Fase 22 será: cache vs sin-cache, ¿coinciden las salidas?

Coste de memoria (preview)

La KV cache almacena \(K^{(l)}, V^{(l)}\) para cada capa \(l\) y cada posición. Tamaño:

\[\text{KV cache size} = 2 \cdot n_\text{layers} \cdot (L + T) \cdot d \cdot \text{bytes-por-float}\]

Para el Mini-GPT (\(n_\text{layers} = 2, d = 64, L + T = 28\), float32 = 4 bytes): \(2 \cdot 2 \cdot 28 \cdot 64 \cdot 4 = 28{,}672\) bytes = 28 KB. Diminuto.

Para GPT-3 (\(n_\text{layers} = 96, d = 12{,}288, L + T = 2048\), fp16): \(2 \cdot 96 \cdot 2048 \cdot 12{,}288 \cdot 2 = 9.7\) GB. Enorme.

Por eso la Fase 27 cubre optimizaciones de KV cache: paged attention, multi-query attention, grouped-query attention. La cache es el coste de memoria dominante en la inferencia de LLMs a escala.

Una nota sobre continuous batching (referencia adelantada)

Si tienes muchos requests, puedes hacer batching de sus pasos de decode: en el paso \(t\), ejecuta un forward que procese (el token \(t_A\) del request A + el token \(t_B\) del request B + ...) juntos. El truco: los requests tienen distintas longitudes de prompt y distintas longitudes de generación. La Fase 33 ("Inference Serving: From FastAPI to Continuous Batching") lo cubre en detalle.

La Fase 21 no hace batching. Solo decode de un único request. La infraestructura para batching se añade encima en la Fase 33.

Lo que este archivo NO cubre

  • Implementación de la KV cache. Fase 22.
  • Paged attention, MQA, GQA. Fase 27.
  • Continuous batching. Fase 33.

Siguiente: ../lab/00-greedy.md