English · Español
04 — GQA, MQA, MLA: compartir K y V¶
🇪🇸 Tres trucos para reducir el tamaño del KV cache compartiendo K y V entre cabezas. MQA: una sola K/V para todas. GQA: una K/V por grupo. MLA: K/V comprimidos a un espacio latente. Cada uno es una intervención sobre los bytes que cruzan memoria por token decodificado.
El KV cache como un término del roofline¶
Durante el decode autoregresivo (un token a la vez), la attention por paso funciona así:
- El
Qdel nuevo token tiene forma(1, n_heads, d). - Los
K, Vcacheados tienen forma(N, n_heads, d)dondeNes la longitud de secuencia actual. - Para cada cabeza, calcula
softmax((Q · K^T) / √d) · V.
El compute es O(N · d) por cabeza por paso — minúsculo. El cuello de botella es cargar K y V desde HBM. Para batch=1, la attention de decode de un solo token es extremadamente bandwidth-bound: la FPU hace quizá 2 N d FLOPs mientras 8 N d bytes (fp16) se mueven desde HBM. Intensidad I ≈ 0.25 FLOPs/byte. Memory-bound por órdenes de magnitud.
El tamaño del KV cache es el tráfico dominante de memoria por paso.
Si reducimos el KV cache, los bytes por paso caen proporcionalmente, la intensidad sube proporcionalmente, los tokens/seg de decode suben proporcionalmente. Esta es toda la premisa de GQA/MQA/MLA.
Multi-Query Attention (MQA, Shazeer 2019)¶
El cambio: todas las cabezas de attention comparten un único par K, V.
- MHA estándar:
K, V ∈ ℝ^{N × n_heads × d}. Bytes del KV cache =2 × N × n_heads × d × 2 (fp16) = 4 N · n_heads · d. - MQA:
K, V ∈ ℝ^{N × 1 × d}. Bytes del KV cache =4 N d.n_heads× más pequeño.
La matemática: cada cabeza de query i calcula softmax(Q_i K^T / √d) V, donde K, V son los mismos para todo i. La flexibilidad de K independiente por cabeza se cede; a cambio, el KV cache encoge por el número de cabezas (típicamente 32×).
Coste de calidad: los modelos MQA suelen entrenarse desde cero (no re-adaptados desde MHA). La calidad es ligeramente peor que MHA al mismo número de parámetros — pero el modelo puede hacerse más grande con la memoria ahorrada, así que el trato es net-positivo en la práctica. PaLM, Falcon y otros usan MQA.
Grouped-Query Attention (GQA, Ainslie 2023)¶
El cambio: las cabezas se agrupan; cada grupo comparte un par K, V.
n_groups = n_heads / group_size. Típicamentegroup_size = 8→ 4× de reducción.K, V ∈ ℝ^{N × n_groups × d}. Bytes del KV cache =4 N · n_groups · d.group_size× más pequeño.
GQA es una mejora Pareto en el espectro MHA-MQA. Con 32 cabezas y group_size=8, el KV cache es 4× más pequeño (entre el 32 de MHA y el 1 de MQA), y la calidad está más cerca de MHA. LLaMA 2 70B usa GQA con group_size=8. Mistral 7B usa GQA con group_size=4.
Receta de re-adaptación: para convertir un modelo entrenado con MHA a GQA, promedia los pesos de proyección K y V dentro de cada grupo, luego continúa entrenando unos pocos pasos para recuperar calidad. Barato.
Multi-Latent Attention (MLA, DeepSeek-V2 2024)¶
El cambio: proyecta K, V a un espacio latente de bajo rango; cachea el latente; reconstruye al vuelo.
- Estado KV por token: un único vector de bajo rango
c ∈ ℝ^{d_c}cond_c < d × n_heads. - En tiempo de attention, reconstruye
K, Vdesdecvía matrices de proyección aprendidas.
El KV cache se vuelve N × d_c en vez de N × n_heads × d. Para DeepSeek-V2: d_c ≈ 512, n_heads × d ≈ 16384. 32× de reducción vs MHA. Sustancialmente mejor que la "1 cabeza KV" de MQA porque el espacio latente es de alta calidad (elegido por entrenamiento, no promedio arbitrario).
Coste: la proyección de reconstrucción añade compute por llamada de attention. Matemáticamente, MLA compone con Flash — la proyección es un matmul adicional pequeño dentro del kernel. El paper de DeepSeek reporta overhead despreciable.
Calidad: MLA iguala o supera a MHA al mismo presupuesto efectivo de compute mientras tiene KV 30× más pequeño. El paper de MLA es uno de los papers de attention más importantes de 2024.
Dónde componen estos con Flash¶
FlashAttention tilea Q, K, V y procesa attention en SRAM. GQA/MQA/MLA reducen K, V. Componen libremente: Flash con GQA = Flash donde el bucle interno de cada tile-Q accede a tiles K/V compartidos (con K/V indexados por grupo, no por cabeza). Las implementaciones estándar (FlashAttn-v2) lo soportan.
Dónde componen estos con Paged¶
PagedAttention pagina KV por (layer, head, position). Con GQA, la paginación es por grupo, no por cabeza. Con MLA, la paginación es sobre los vectores latentes, no K/V. La lógica del block_manager generaliza; el soporte de MLA en vLLM es reciente (mediados de 2024 en adelante).
Números del roofline¶
Para un modelo 7B en decode con N=4096:
| Variante | Heads × d_per | Bytes KV/token | Proxy de tokens/seg de decode |
|---|---|---|---|
| MHA | 32 × 128 = 4096 | 16 KiB | 1× baseline |
| GQA-8 | 32 cabezas de query, 8 grupos KV (4 cabezas de query comparten cada par KV), d=128 | 2 KiB | ~7–8× |
| MQA | 1 × 128 = 128 | 0.5 KiB | ~30× |
| MLA (d_c=512) | latente 512 | 1 KiB | ~16× |
(El "proxy de tokens/seg" es el inverso de bytes por token, asumiendo decode bandwidth-bound e ignorando compute por paso. Las medidas reales varían, pero el orden es correcto.)
Sliding window vs reducción de KV — la pregunta equivocada¶
Una confusión común: ¿no son sliding window y reducción de KV lo mismo? Reducen KV.
No:
- Sliding window reduce el número de posiciones atendidas: cada query ve W posiciones, no N.
- GQA/MQA/MLA reducen el tamaño KV por posición: cada posición contribuye con menos bytes.
Son ejes ortogonales:
| Posiciones completas | Sliding window W | |
|---|---|---|
| MHA estándar | N · n_heads · d KV | W · n_heads · d KV |
| MQA | N · d KV | W · d KV |
| MLA | N · d_c KV | W · d_c KV |
Mistral 7B usa GQA-4 + sliding window 4096. Juntos: KV por paso ~16× más pequeño que LLaMA vanilla, más attention O(W) en vez de O(N).
Lo que medimos en el roofline de Borja¶
Para el overlay del roofline (experimento 27-roofline-overlay), graficamos:
- MHA naive + attention naive: menor intensidad, en el techo de memoria.
- MHA naive + Flash: mayor intensidad, parte del camino arriba.
- GQA + Flash: aún mayor intensidad (menos bytes-por-K-tile cargados).
- MQA + Flash: mayor intensidad de un único kernel.
Este es un análisis de un solo kernel. Las ganancias a nivel de sistema de PagedAttention requieren mediciones a nivel de servidor fuera de alcance para la Fase 27.
Problemas de práctica¶
Soluciones al abrir la fase en solutions/04-gqa-mqa-mla-ref.md.
- Para MHA con 32 cabezas, GQA group_size=8. La forma de la matriz de pesos de proyección K cambia de
(hidden, n_heads × d)a(hidden, n_groups × d). ¿Por qué factor encogen los parámetros de proyección de K y V? - El KV cache de MQA es
4 N dbytes (fp16). Para LLaMA-7B (d=128, n_layers=32, N=4096), calcula el tamaño del KV cache. Compara con MHA (n_heads=32). - La reconstrucción de MLA
K = W_K_up · c, dondec ∈ ℝ^{d_c},K ∈ ℝ^{n_heads × d},W_K_up ∈ ℝ^{(n_heads d) × d_c}. Calcula los FLOPs extra por paso de attention de la reconstrucción. Argumenta cuándo esto es despreciable vs el compute dominante. - GQA-8 reduce el KV cache 4× vs MHA. Muestra que en un régimen de decode bandwidth-bound esto implica ~4× de mejora de throughput, ignorando otros términos.
Recap de un párrafo¶
GQA, MQA y MLA reducen todos el tamaño del KV cache compartiendo o comprimiendo K y V entre cabezas. MQA comparte un K/V entre todas las cabezas (compresión máxima, coste de calidad máximo). GQA comparte dentro de grupos (sweet spot de Pareto). MLA comprime a un espacio latente de bajo rango (mejor calidad con fuerte compresión). Cada uno compone con Flash y con PagedAttention, multiplicando las ganancias de throughput. La lente unificadora es el roofline: el decode es bandwidth-bound por el tráfico KV, así que reducir KV directamente sube la intensidad. Los motores de inferencia modernos apilan Flash + Paged + GQA + sliding window simultáneamente, cada uno atacando un coste distinto. Con la teoría de la Fase 27 completa, los labs implementan Flash forward en Triton y anotan el asignador del KV cache de vLLM.
Siguiente: lab/00-online-softmax.md.