Skip to content

English · Español

02 — Coste de memoria del KV Cache

🇪🇸 La fórmula bytes = 2 · L · H · d_h · S · B · s no es una receta a memorizar — es una cuenta de "qué guardo y para cuántos". Derivarla cada vez te protege contra los errores de "olvidé el factor 2" o "confundí d con d_h", que son los dos errores más comunes en este tema.

Este es el eje algebraico de la Fase 22. Derivamos la fórmula de tamaño del caché desde primeros principios, luego la aplicamos a modelos reales, luego derivamos qué te compra cambiar cada factor (L, H, d_h, S, B, s).


La fórmula, derivada contando

El caché guarda, por capa transformer, dos tensores: \(K\) y \(V\). Tras procesar \(S\) tokens, cada uno tiene forma:

\[K, V \in \mathbb{R}^{B \times H \times S \times d_h}\]
  • \(B\): tamaño de batch (número de secuencias concurrentes que comparten esta instancia de caché).
  • \(H\): número de cabezas de atención en esta capa.
  • \(S\): número de tokens cuyas K, V están actualmente guardadas.
  • \(d_h\): dimensión de cabeza (nota: \(H \cdot d_h = d\), la dim del modelo).

El número de elementos en \(K\) solo es \(B \cdot H \cdot S \cdot d_h\). En \(V\) es el mismo. Por capa, \(K + V\) juntos contienen \(2 B H S d_h\) elementos.

Tenemos \(L\) capas, cada una con su propio caché independiente. Las capas no comparten K, V — las K, V de cada capa se producen por las proyecciones \(W_K, W_V\) de esa capa a partir de la entrada de esa capa. Elementos totales:

\[\text{elements} = 2 \cdot L \cdot B \cdot H \cdot S \cdot d_h\]

Multiplicar por bytes por elemento \(s\):

\[\boxed{\text{bytes}_\text{cache} = 2 \cdot L \cdot H \cdot d_h \cdot S \cdot B \cdot s}\]

Convencionalmente se escribe con \(L, H, d_h\) delante porque son constantes de arquitectura del modelo, y \(S, B\) al final porque son palancas en tiempo de ejecución. \(s\) depende del dtype:

dtype \(s\)
fp64 8
fp32 4
fp16 / bf16 2
int8 1
int4 (empaquetado) 0.5

Sanity check: una derivación alternativa

Algunas referencias escriben la fórmula con la dim del modelo \(d\) en vez de \(H \cdot d_h\):

\[\text{bytes}_\text{cache} = 2 \cdot L \cdot d \cdot S \cdot B \cdot s \quad \text{(ya que } d = H \cdot d_h\text{)}\]

Ambas son correctas. La forma \(H \cdot d_h\) es útil al razonar sobre Grouped-Query Attention (Fase 27), donde K y V se comparten entre grupos de cabezas: GQA cambia \(H\) en el caché (al número de cabezas clave-valor, \(H_{KV} < H\)) pero no lo cambia en \(Q\) (siguen siendo \(H\) cabezas). Así que la fórmula del caché se convierte en \(2 L H_{KV} d_h S B s\), mientras que el cómputo de atención sigue usando \(H\) cabezas. La forma con \(d\) oculta esto; la forma \(H \cdot d_h\) lo hace visible.

Coste marginal por token

¿Cuántos bytes añade cada token generado adicional al caché?

Diferenciar (bueno — tomar la diferencia, ya que \(S\) es discreto):

\[\Delta\text{bytes} = 2 \cdot L \cdot H \cdot d_h \cdot B \cdot s \quad \text{(constante en } S \text{)}\]

El caché crece en un número constante de bytes por token — no depende del tamaño actual del caché. Esto es lo que significa cuantitativamente "memoria lineal en contexto".

Para Llama-2-7B (fp16, batch 1): $\(\Delta\text{bytes} = 2 \cdot 32 \cdot 32 \cdot 128 \cdot 1 \cdot 2 = 524288 = 512 \text{ KiB por token}\)$

Medio megabyte por token. Un contexto de 4096 tokens = 2 GiB.

Una tabla de escalado: tamaño del caché para modelos reales

Memoriza la forma de esta tabla; aparece en cada paper de sistema de serving.

Modelo \(L\) \(H\) \(d_h\) \(d\) dtype Por token ctx 4k ctx 32k ctx 128k
MiniGPT de gramática (Fase 17 por defecto, §A13) 4 4 16 64 fp32 512 B 2 MiB 16 MiB 64 MiB
GPT-2 small 12 12 64 768 fp16 36 KiB 144 MiB 1.1 GiB 4.5 GiB
Llama-2-7B 32 32 128 4096 fp16 512 KiB 2 GiB 16 GiB 64 GiB
Llama-2-13B 40 40 128 5120 fp16 800 KiB 3.1 GiB 25 GiB 100 GiB
Llama-2-70B (MHA, sin GQA — contrafactual) 80 64 128 8192 fp16 2.5 MiB 10 GiB 80 GiB 320 GiB
Llama-2-70B (GQA, \(H_{KV}=8\)) 80 8 128 8192 fp16 320 KiB 1.25 GiB 10 GiB 40 GiB
GPT-3 175B 96 96 128 12288 fp16 4.5 MiB 18 GiB 144 GiB 576 GiB

(Fila de MiniGPT: derivable de la config de la Fase 17 — confirmar al abrir la fase por si la config cambió.)

Cosas que observar al leer la tabla:

  1. GQA no es un retoque. Pasar de 70B-MHA-contrafactual a 70B-GQA encoge el caché 8×. Sin GQA, servir Llama-2-70B con contexto largo en una sola H100 es imposible. Esta es una de las tres razones por las que GQA existe; las otras dos son la latencia de inferencia (menos que leer cada paso) y el aprovechamiento del ancho de banda de memoria.
  2. El contexto largo no es "gratis con suficiente RAM". 128k de contexto en un modelo 70B son 40 GiB por secuencia. Una sola A100 (80 GB) contiene los pesos del modelo y el caché de una secuencia — a duras penas. Los usuarios concurrentes rompen esto.
  3. El caché del MiniGPT de gramática es trivial. Ese es el punto: la Fase 22 se queda en una escala donde Borja puede verificar cada byte a mano. La frase realista más larga del corpus §A13 son ~10 tokens ("Tomorrow he is going to study and finish"); el caché a esa longitud está muy por debajo de 1 KiB total. La Fase 24 pasa a una escala donde la medición reemplaza a la enumeración.

Dónde vive el caché en memoria

Opciones de implementación:

  1. Un gran tensor por capa, pre-asignado a S_max. Forma \((B, H, S_\text{max}, d_h)\). Escribir en slices [..., :S, :] a medida que \(S\) crece. Memoria: constante. Fragmentación: ninguna (un bloque contiguo por capa). Bytes desperdiciados: \(S_\text{max} - S_\text{current}\) filas por capa. Esto es lo que implementa la Fase 22.
  2. Una lista de tensores que crecen. Cada cache.append(k) hace K = np.concatenate([K, k_new]). Memoria: variable. Fragmentación: thrash del heap. Coste por append: O(S) — destruye el decode lineal por paso. No lo hagas.
  3. Paginado: una lista de bloques de tamaño fijo por secuencia, indexada por una "block table". Memoria: constante por bloque, variable en nº de bloques. Fragmentación: solo en los límites de bloque (pequeña). Esto es PagedAttention; Fase 27.

La Fase 22 usa (1). Los bytes desperdiciados son el precio de mantener la implementación pequeña y la matemática transparente. Con \(S_\text{max} = 64\) en el MiniGPT de gramática, el caché son 32 KiB total — cabe en L1. No nos importa.

Dos ecuaciones que interiorizar

Ambas siguen trivialmente de la fórmula enmarcada arriba; ambas deberían estar en la punta de la lengua.

1. El tamaño del caché se duplica cuando doblas el contexto. $\(\text{bytes}(2S) = 2 \cdot \text{bytes}(S)\)$ "Pasar de 4k a 8k de contexto dobla el caché" no es una observación empírica; es álgebra. Quien lo diga casualmente como si fuera una sorpresa no ha interiorizado la fórmula.

2. El tamaño del caché a presupuesto de bytes fijo da un techo de contexto. $\(S_\text{max} = \frac{\text{bytes}_\text{budget}}{2 \cdot L \cdot H \cdot d_h \cdot B \cdot s}\)$

Para Llama-2-7B fp16 en una A100 de 40 GB con 14 GiB ocupados por pesos, dejando 26 GiB para caché, batch 1: $\(S_\text{max} = \frac{26 \cdot 2^{30}}{2 \cdot 32 \cdot 32 \cdot 128 \cdot 1 \cdot 2} = \frac{27.9 \cdot 10^9}{524288} \approx 53200 \text{ tokens}\)$

Así que una sola A100 puede servir Llama-2-7B a ~53k de contexto, un usuario. Sube el batch a 16, y \(S_\text{max}\) cae a ~3300 tokens. Ese es el compromiso exacto que navegan los sistemas de serving.

Problemas de práctica

Soluciones en solutions/02-memory-cost-ref.md (no visible durante el pre-escrito). Trabájalas antes del laboratorio.

  1. Llama-2-7B, fp16, batch=4, S=8192. ¿Tamaño del caché en GiB?
  2. Mistral-7B usa GQA con \(H_{KV}=8\), por lo demás misma config que Llama-2-7B. Misma situación (batch=4, S=8192): ¿tamaño del caché en GiB?
  3. Cuantizar el caché a int8. Rehaz (1). ¿Cómo afecta esto a la precisión (cualitativo)? (Pista: K y V son post-rotary, post-proyección — son activaciones, no pesos. El ruido de cuantización se compone a lo largo de las capas.)
  4. Atención de ventana deslizante mantiene solo los últimos \(W = 1024\) tokens del caché. Llama-2-7B fp16, batch=1, ctx=32k. ¿Tamaño del caché?
  5. Invertir la fórmula. Tienes una GPU de 24 GiB; 10 GiB son pesos del modelo y activaciones forward. Tu modelo es GPT-2 small (config arriba), fp16, batch=8. ¿Cuál es el contexto máximo \(S_\text{max}\)?

Si esos cinco son mecánicos para ti, la fórmula ha aterrizado.

Lo que esta página NO cubre

  • Por qué el decode toca el techo de memoria en hardware real. Argumento de intensidad aritmética en theory/03-decode-as-memory-bound.md.
  • Layout de memoria del caché paginado. theory/04-toward-paged-attention.md previsualiza; la Fase 27 implementa.
  • Numéricos del caché int8 / fp16. La fórmula de bytes escala linealmente vía \(s\); el impacto en precisión es la Fase 26 (cuantización).
  • Techos de ancho de banda de memoria de GPU (HBM vs SRAM vs registros). Fase 23. La fórmula de tamaño es independiente del hardware; lo rápido que puedes leerlo no.

Siguiente: theory/03-decode-as-memory-bound.md — la intensidad aritmética de la atención de decode, y por qué decodificar desde caché es un problema de ancho de banda de memoria.