Skip to content

English · Español

00 — Por qué existe un KV Cache

🇪🇸 La caché KV existe porque la generación autorregresiva es estructuralmente repetitiva: en cada paso, la atención del nuevo token mira a exactamente las mismas claves y valores que ya miró el token anterior, más una fila nueva. Recalcularlas todas en cada paso es trabajo redundante que multiplica el coste por O(n).

Esta es la página de motivación. Léela antes que las páginas de derivación — sin el porqué, las fórmulas son áridas. Con el porqué, son inevitables.


El ejemplo trabajado

A lo largo de la Fase 22 anclamos cada concepto a un prompt concreto extraído del corpus de gramática de verbos §A13:

Prompt:  "Yesterday  I"        (2 tokens — adverbial de tiempo pasado + pronombre singular de 1ª persona)
Decode → "Yesterday  I  worked"           (un token nuevo)
Decode → "Yesterday  I  worked  and"      (otro token nuevo)
Decode → "Yesterday  I  worked  and  he"  (y otro)

El MiniGPT de gramática (Fase 17, vocab ≈ 600 formas) está entrenado para asignar alta probabilidad a verbos en pasado simple tras "Yesterday I", así que el primer token decodificado es plausiblemente "worked", "played", "talked" — cualquier regular en pasado simple. La preferencia del modelo es irrelevante para esta fase; lo que importa es que en el paso t+1 las claves y valores para "Yesterday" y "I" son los mismos que en el paso t. Esa redundancia es lo que el caché existe para eliminar.

Lo que hace el decoding autorregresivo

Un LM causal genera un token a la vez. Tras t pasos el modelo ha producido los tokens \(x_1, x_2, \ldots, x_t\). Para producir \(x_{t+1}\):

  1. Embed de los \(t\) tokens.
  2. Pasarlos por \(L\) capas transformer. En cada capa, la atención calcula \(\text{softmax}(QK^\top / \sqrt{d_k}) V\) — una matriz \(t \times t\) de pesos de atención, por \(V\).
  3. Tomar el hidden state del último token, proyectar a vocabulario, samplear.

Ahora producir \(x_{t+2}\). La forma ingenua:

  1. Embed de los \(t+1\) tokens.
  2. Pasarlos por \(L\) capas. Cada atención ahora calcula una matriz \((t+1) \times (t+1)\).
  3. Tomar el hidden state del último token, samplear.

Fíjate. El bloque superior izquierdo \(t \times t\) de la matriz de atención \((t+1) \times (t+1)\) del paso 2 es idéntico a la matriz de atención completa del paso anterior. La hemos recalculado. La recalcularemos otra vez en el siguiente paso como el bloque superior izquierdo \((t+1) \times (t+1)\) de un \((t+2) \times (t+2)\). Y otra vez. Y otra.

Trabajo redundante total de atención sobre \(n\) tokens generados: \(\sum_{t=1}^{n} t^2 = \Theta(n^3)\). Cómputo lineal por token se vuelve cómputo cúbico total. Esa es la enfermedad.

El remedio

En cualquier capa transformer, la atención lee tres matrices:

  • \(Q = X W_Q\) — query, una fila por token actual.
  • \(K = X W_K\) — key.
  • \(V = X W_V\) — value.

En el paso \(t+1\), \(K\) y \(V\) para los tokens \(1..t\) son byte a byte idénticos a lo que eran en el paso \(t\) — porque las entradas \(x_1..x_t\) no han cambiado, los pesos no han cambiado, nada ha cambiado. Solo hace falta calcular una fila nueva: la fila de \(x_{t+1}\).

El caché guarda las filas de \(K\) y \(V\) de pasos previos. En el paso \(t+1\):

  1. Calcular la fila nueva de \(K\) y \(V\) solo para \(x_{t+1}\). Coste: \(O(d^2)\).
  2. Añadir esa fila nueva al caché.
  3. Calcular la atención como \(Q_\text{new} \cdot K_\text{cached}^\top\) — una fila \(1 \times (t+1)\), no un bloque \((t+1) \times (t+1)\). Coste: \(O((t+1) d)\).
  4. Softmax de esa fila, multiplicar por \(V_\text{cached}\), salida.

Cómputo total para el nuevo token: \(O(t \cdot d)\). Cómputo total para toda la secuencia: \(\sum_{t=1}^{n} t \cdot d = O(n^2 d)\). Cúbico se vuelve cuadrático.

Concretamente sobre "Yesterday I worked": el paso 1 (decode del tercer token) calcula \(q\) solo para el slot de la posición 2, proyecta \(x_{\text{"I"}}\) a una única fila nueva de \(K\) y \(V\), y escribe esa fila en el índice 1 del caché (base cero). Las filas del caché para "Yesterday" (índice 0) se escribieron durante el prefill y solo se tocan como lecturas. Ninguna multiplicación por \(W_K\) o \(W_V\) se repite para ellas. Esa es toda la optimización.

El precio

Las \(K, V\) cacheadas tienen que vivir en algún sitio. Tras \(S\) tokens generados, el caché contiene:

\[\text{bytes}_\text{cache} = 2 \cdot L \cdot H \cdot d_h \cdot S \cdot B \cdot s\]
  • \(2\): uno para K, uno para V.
  • \(L\): capas (caché por capa; las capas no comparten K, V).
  • \(H \cdot d_h = d\): cabezas por dim-de-cabeza igual a dim del modelo.
  • \(S\): longitud actual de la secuencia en tokens.
  • \(B\): tamaño del batch.
  • \(s\): bytes por elemento (4 para fp32, 2 para fp16/bf16, 1 para int8).

Esta fórmula se deriva en detalle en 02-memory-cost.md. Dos consecuencias que vale la pena absorber ya:

  1. El caché crece linealmente en \(S\). Doblar el contexto dobla la memoria del caché. El contexto largo es linealmente caro en memoria.
  2. El caché es enorme para modelos reales. Llama-2-7B a 4096 de contexto, fp16, batch 1: \(2 \times 32 \times 32 \times 128 \times 4096 \times 1 \times 2 = 2.15 \cdot 10^9\) bytes ≈ 2 GiB. Por secuencia. Por GPU. Por eso servir un modelo 7B con 16 usuarios concurrentes en una sola A100 es difícil aunque los pesos del modelo solo ocupen 14 GiB. En cambio, el MiniGPT de gramática (Fase 17 por defecto: 4 capas, 4 cabezas, \(d_h\) = 16, fp32) con nuestro prompt plausible más largo (ctx = 32) solo contiene \(2 \times 4 \times 4 \times 16 \times 32 \times 1 \times 4 \approx 16\) KiB — cabe en L1. El tamaño del caché no importa a nuestra escala; la fórmula sí, porque cada sistema sobre el que leas después asume que sabes aplicarla.

La dicotomía: prefill vs decode

El caché también reestructura la inferencia en dos fases distintas:

  • Prefill (también "context encoding", "prompt processing"). El usuario envía un prompt de longitud \(P\). Calculamos \(K, V\) para los \(P\) tokens en una pasada paralela — una sola atención \(P \times P\) por capa. Cómputo: \(O(L P^2 d)\). Memoria: \(O(L H d_h P) = O(L d P)\). El prefill está limitado por cómputo para \(P\) moderado — hay mucha aritmética por byte cargado.
  • Decode (también "generación", "decoding incremental"). Tras el prefill, producimos tokens uno a uno. Cada paso es una fila de atención \(1 \times S\) por capa, leyendo el caché entero. Cómputo por paso: \(O(L S d)\). Tráfico de memoria por paso: O(tamaño del caché) = \(O(L d S)\). Intensidad aritmética: \(O(d) / O(d) = O(1)\) — en realidad \(\sim 0.5\) FLOPs/byte. El decode está limitado por memoria, profundamente.

Estas dos fases quieren hardware distinto, schedulers distintos, optimizaciones distintas. Cada sistema de serving sobre el que leerás — vLLM, TensorRT-LLM, SGLang — es en el fondo una forma de mantener la GPU ocupada con trabajo de prefill mientras espera memoria para trabajo de decode. El caché es el artefacto que crea la asimetría.

Lo que el caché no es

Unas aclaraciones que evitan confusión más adelante:

  • El caché guarda \(K\) y \(V\), no \(Q\). \(Q\) se recalcula cada paso porque el nuevo token tiene su propia fila de query.
  • El caché es por capa. Las \(K, V\) cacheadas de la capa \(\ell\) no son reusables en la capa \(\ell+1\). El caché es una lista de \(L\) tensores.
  • El caché es por cabeza. Dentro de una capa, cada una de las \(H\) cabezas de atención tiene su propia K y V \((S, d_h)\). (Algunas arquitecturas comparten K, V entre cabezas — grouped query attention, multi-query attention. Son temas de la Fase 26 / 27. La Fase 22 asume multi-head completo con K, V por cabeza.)
  • El caché es por secuencia en el batch. El serving por lotes con secuencias de distinta longitud es difícil precisamente porque el caché de cada secuencia es de distinto tamaño. Este es el problema de fragmentación que PagedAttention resuelve (vista previa en 04-toward-paged-attention.md).
  • El caché no es "memoización" en el sentido Lisp. Es un anillo de tensores de estructura fija, no un hash map. Solo añade durante la generación; se limpia en el siguiente prompt.

Lo que esta página NO cubre

  • Derivación de la fórmula de bytes. Esbozada aquí; rigurosa en theory/02-memory-cost.md.
  • Por qué el decode está limitado por memoria en hardware real. Nombrado aquí; argumento de intensidad aritmética en theory/03-decode-as-memory-bound.md.
  • PagedAttention / serving por lotes de longitud variable. Nombrado aquí; solo enunciado del problema en theory/04-toward-paged-attention.md, derivación completa en la Fase 27.
  • Layout del caché específico de GPU. La Fase 22 corre en DRAM vía NumPy; el layout HBM / SRAM / registros es Fase 23–24.
  • Cuantización del caché (int8 / fp16). Fase 26.

Lo que deberías saber hacer al terminar esta fase

  1. Esbozar la dicotomía prefill/decode en una pizarra, con los costes asintóticos correctos etiquetados, usando "Yesterday I worked" como ejemplo recurrente.
  2. Derivar la fórmula de bytes desde primeros principios — es decir, contar lo que se guarda, no memorizar la fórmula.
  3. Predecir el tamaño del caché para cualquier modelo a partir de su config (Llama, Mistral, clase GPT) sin ejecutar nada.
  4. Explicar en una frase por qué el decode está limitado por memoria, y qué implica eso sobre qué optimización (Flash-decoding, paging, GQA, cuantización) ataca qué síntoma.

Si alguno de esos cuatro flaquea, el laboratorio lo cazará. Si los cuatro están claros, el laboratorio es en su mayoría mecánico.


Siguiente: theory/01-prefill-vs-decode.md — formalizar las dos fases y la contabilidad de FLOPs que justifica la dicotomía.