English · Español
Break 00 — Forzar GQA a kv_heads = 1 (colapso extremo a MQA)¶
🇪🇸 Forzamos
kv_heads = 1en el bloque de atención GQA: las 4 cabezas de Mini-GPT comparten una sola clave/valor. La KV cache se reduce 4×, pero todas las cabezas atienden con el mismoK/V— la diversidad de cabezas que entrenamos en Fase 15 colapsa. La PPL salta y las muestras se vuelven repetitivas.
Este ejercicio /break apunta al trade-off calidad vs memoria en GQA. El bug es un solo número; el fallo es un golpe medible en PPL más una pérdida cualitativa de diversidad entre cabezas.
Anclas: theory/04-gqa-mqa-mla.md, theory/05-flash-walkthrough-and-gqa-math.md, .claude/commands/break.md.
Hipótesis¶
El learner predice: "Poner kv_heads = 1 (MQA) en un modelo entrenado con n_heads = 4 elimina 3 de las 4 proyecciones K, V. Las cuatro cabezas de query comparten ahora un único K, V. El modelo pierde la capacidad de atender a posiciones distintas por cabeza; la única diversidad que queda es en Q y W_O. El PPL sube; la diversidad de muestras baja."
El break¶
En src/minimodel/blocks/attention_gqa.py:
class GQABlock(Module):
- def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 2):
+ def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 1): # /break: extreme MQA
super().__init__()
self.n_heads = n_heads
self.kv_heads = kv_heads
self.d_h = d_model // n_heads
self.W_q = Linear(d_model, n_heads * self.d_h, bias=False)
self.W_k = Linear(d_model, kv_heads * self.d_h, bias=False)
self.W_v = Linear(d_model, kv_heads * self.d_h, bias=False)
self.W_o = Linear(n_heads * self.d_h, d_model, bias=False)
Un número cambiado: kv_heads: 2 → 1. El forward ya maneja el broadcast de K, V entre grupos de query; no hace falta ninguna otra edición. Es el break más limpio posible para este concepto.
Predice, luego ejecuta¶
Para Mini-GPT (n_heads = 4, kv_heads = 1 = MQA), la KV cache por token baja de 2·n_layers·2·d_h·4 = 1024 B (GQA-2) a 2·n_layers·1·d_h·4 = 512 B (MQA). 2× más ahorro de memoria. Pero las 4 cabezas de query comparten ahora un único K, V, así que atienden a las mismas posiciones, ponderadas sólo por la proyección Q.
Predicciones¶
- PPL en el set de eval §A13: notablemente más alta que la baseline GQA-2. Estimación tosca: GQA-2 ≈ 5.20; MQA ≈ 5.55 (un golpe del 7% en PPL, vs ~1-2% de GQA-2 sobre MHA — el colapso extremo es mucho peor que el moderado).
- Diversidad de muestras: el mismo prompt muestreado 10 veces produce 6-7 de la misma continuación. El modelo se vuelve mucho más "seguro" de forma degenerada porque la mayoría de las cabezas de query ven el mismo patrón de attention.
- Entropía de attention por cabeza: para las 4 cabezas del bloque 1, la entropía de la distribución de attention converge dentro del ~50% del mismo valor (vs entropías variadas en la baseline GQA-2).
- Bytes de KV cache: exactamente la mitad de la baseline GQA-2. La ganancia de memoria es real y no está en duda.
Escribe las predicciones en learners/borja/phase-27/notes/breaks.md antes de ejecutar.
Observa¶
Ejecuta el eval de la Fase 27 con la config rota:
Diagnósticos a plotear:
- Gráfico de barras de PPL:
mha(baseline),gqa-2(default de la Fase 27),mqa-kv1(este break). Tres barras; la de la derecha debería ser la más alta. - Entropía de attention por cabeza en la capa 1: plot de 4 líneas, una por cabeza de query. En MHA/GQA-2 las líneas divergen; en MQA las 4 líneas colapsan unas sobre otras tras las primeras posiciones (porque comparten
K). - Muestrea 10 continuaciones de "I work" a temperatura 0.8; cuenta salidas distintas. Baseline ≈ 7; roto ≈ 3.
Síntoma que verá Borja¶
- El PPL salta un 5-10% en el set de eval §A13.
- La diversidad de muestras colapsa (la mayoría de las continuaciones idénticas).
- Los bytes de KV cache bajan exactamente 2× comparados con GQA-2 (la ganancia de memoria es real — ese es el trade-off que se ilustra).
- No hay crash en tiempo de entrenamiento; es un cambio arquitectónico en tiempo de inferencia.
Causa oculta (una frase)¶
Poner kv_heads = 1 hace que las n_heads cabezas de query compartan una única proyección K, V, eliminando la diversidad de patrones de attention por cabeza y colapsando el modelo a la capacidad de routing de una única cabeza efectiva.
Cascada de pistas¶
- Plotea la entropía de attention por cabeza en la capa 1. ¿Siguen siendo distinguibles las cabezas?
- Imprime
self.W_k.weight.shape. ¿Cuántas proyecciones de clave independientes tiene realmente el modelo? - Relee
theory/04-gqa-mqa-mla.md§"GQA quality scaling". ¿Qué encuentra Ainslie et al. como ratiokv_heads / n_headsseguro? ¿Qué valor pusiste?
Diff de fix¶
-class GQABlock(Module):
- def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 1):
+class GQABlock(Module):
+ def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 2): # restored: GQA-2 is the default
O — si exploras más — vuelve a ejecutar con kv_heads = 4 (MHA completo) y kv_heads = 2 (GQA-2) para trazar la frontera de Pareto real en Mini-GPT.
Por qué enseña el concepto¶
El paper original de GQA (Ainslie et al., 2023) reporta que kv_heads = n_heads / 8 es el punto de operación seguro — por debajo, la calidad se degrada de forma visible. En Mini-GPT, ese ratio es 4/8 = 0.5, así que incluso GQA-2 (kv_heads = 2) está en la frontera. Empujar a kv_heads = 1 es pasarse del punto seguro, a propósito. Es el tipo de trade-off donde la línea abstracta "MQA ahorra memoria" en un paper de survey esconde un coste real en calidad. Ver saltar el número de PPL sobre un modelo que entrenaste tú hace el trade-off concreto de un modo que ningún survey logra.
Referencias¶
- Ainslie et al., GQA (arXiv:2305.13245), Figura 3 — calidad vs
kv_heads / n_heads. - Shazeer, Fast Transformer Decoding: One Write-Head is All You Need (arXiv:1911.02150) — el paper original de MQA; reporta un golpe en calidad pequeño pero real en T5.
Siguiente: restaura GQA-2 y ejecuta lab/04-mqa-gqa.md para la ablación completa.