Skip to content

English · Español

Break 00 — Forzar GQA a kv_heads = 1 (colapso extremo a MQA)

🇪🇸 Forzamos kv_heads = 1 en el bloque de atención GQA: las 4 cabezas de Mini-GPT comparten una sola clave/valor. La KV cache se reduce 4×, pero todas las cabezas atienden con el mismo K/V — la diversidad de cabezas que entrenamos en Fase 15 colapsa. La PPL salta y las muestras se vuelven repetitivas.

Este ejercicio /break apunta al trade-off calidad vs memoria en GQA. El bug es un solo número; el fallo es un golpe medible en PPL más una pérdida cualitativa de diversidad entre cabezas.

Anclas: theory/04-gqa-mqa-mla.md, theory/05-flash-walkthrough-and-gqa-math.md, .claude/commands/break.md.


Hipótesis

El learner predice: "Poner kv_heads = 1 (MQA) en un modelo entrenado con n_heads = 4 elimina 3 de las 4 proyecciones K, V. Las cuatro cabezas de query comparten ahora un único K, V. El modelo pierde la capacidad de atender a posiciones distintas por cabeza; la única diversidad que queda es en Q y W_O. El PPL sube; la diversidad de muestras baja."

El break

En src/minimodel/blocks/attention_gqa.py:

 class GQABlock(Module):
-    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 2):
+    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 1):  # /break: extreme MQA
         super().__init__()
         self.n_heads = n_heads
         self.kv_heads = kv_heads
         self.d_h = d_model // n_heads
         self.W_q = Linear(d_model, n_heads * self.d_h, bias=False)
         self.W_k = Linear(d_model, kv_heads * self.d_h, bias=False)
         self.W_v = Linear(d_model, kv_heads * self.d_h, bias=False)
         self.W_o = Linear(n_heads * self.d_h, d_model, bias=False)

Un número cambiado: kv_heads: 2 → 1. El forward ya maneja el broadcast de K, V entre grupos de query; no hace falta ninguna otra edición. Es el break más limpio posible para este concepto.

Predice, luego ejecuta

Para Mini-GPT (n_heads = 4, kv_heads = 1 = MQA), la KV cache por token baja de 2·n_layers·2·d_h·4 = 1024 B (GQA-2) a 2·n_layers·1·d_h·4 = 512 B (MQA). 2× más ahorro de memoria. Pero las 4 cabezas de query comparten ahora un único K, V, así que atienden a las mismas posiciones, ponderadas sólo por la proyección Q.

Predicciones

  • PPL en el set de eval §A13: notablemente más alta que la baseline GQA-2. Estimación tosca: GQA-2 ≈ 5.20; MQA ≈ 5.55 (un golpe del 7% en PPL, vs ~1-2% de GQA-2 sobre MHA — el colapso extremo es mucho peor que el moderado).
  • Diversidad de muestras: el mismo prompt muestreado 10 veces produce 6-7 de la misma continuación. El modelo se vuelve mucho más "seguro" de forma degenerada porque la mayoría de las cabezas de query ven el mismo patrón de attention.
  • Entropía de attention por cabeza: para las 4 cabezas del bloque 1, la entropía de la distribución de attention converge dentro del ~50% del mismo valor (vs entropías variadas en la baseline GQA-2).
  • Bytes de KV cache: exactamente la mitad de la baseline GQA-2. La ganancia de memoria es real y no está en duda.

Escribe las predicciones en learners/borja/phase-27/notes/breaks.md antes de ejecutar.

Observa

Ejecuta el eval de la Fase 27 con la config rota:

just exp 27-gqa --variant mqa-kv1

Diagnósticos a plotear:

  1. Gráfico de barras de PPL: mha (baseline), gqa-2 (default de la Fase 27), mqa-kv1 (este break). Tres barras; la de la derecha debería ser la más alta.
  2. Entropía de attention por cabeza en la capa 1: plot de 4 líneas, una por cabeza de query. En MHA/GQA-2 las líneas divergen; en MQA las 4 líneas colapsan unas sobre otras tras las primeras posiciones (porque comparten K).
  3. Muestrea 10 continuaciones de "I work" a temperatura 0.8; cuenta salidas distintas. Baseline ≈ 7; roto ≈ 3.

Síntoma que verá Borja

  • El PPL salta un 5-10% en el set de eval §A13.
  • La diversidad de muestras colapsa (la mayoría de las continuaciones idénticas).
  • Los bytes de KV cache bajan exactamente 2× comparados con GQA-2 (la ganancia de memoria es real — ese es el trade-off que se ilustra).
  • No hay crash en tiempo de entrenamiento; es un cambio arquitectónico en tiempo de inferencia.

Causa oculta (una frase)

Poner kv_heads = 1 hace que las n_heads cabezas de query compartan una única proyección K, V, eliminando la diversidad de patrones de attention por cabeza y colapsando el modelo a la capacidad de routing de una única cabeza efectiva.

Cascada de pistas

  1. Plotea la entropía de attention por cabeza en la capa 1. ¿Siguen siendo distinguibles las cabezas?
  2. Imprime self.W_k.weight.shape. ¿Cuántas proyecciones de clave independientes tiene realmente el modelo?
  3. Relee theory/04-gqa-mqa-mla.md §"GQA quality scaling". ¿Qué encuentra Ainslie et al. como ratio kv_heads / n_heads seguro? ¿Qué valor pusiste?

Diff de fix

-class GQABlock(Module):
-    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 1):
+class GQABlock(Module):
+    def __init__(self, d_model: int, n_heads: int = 4, kv_heads: int = 2):  # restored: GQA-2 is the default

O — si exploras más — vuelve a ejecutar con kv_heads = 4 (MHA completo) y kv_heads = 2 (GQA-2) para trazar la frontera de Pareto real en Mini-GPT.

Por qué enseña el concepto

El paper original de GQA (Ainslie et al., 2023) reporta que kv_heads = n_heads / 8 es el punto de operación seguro — por debajo, la calidad se degrada de forma visible. En Mini-GPT, ese ratio es 4/8 = 0.5, así que incluso GQA-2 (kv_heads = 2) está en la frontera. Empujar a kv_heads = 1 es pasarse del punto seguro, a propósito. Es el tipo de trade-off donde la línea abstracta "MQA ahorra memoria" en un paper de survey esconde un coste real en calidad. Ver saltar el número de PPL sobre un modelo que entrenaste tú hace el trade-off concreto de un modo que ningún survey logra.

Referencias

  • Ainslie et al., GQA (arXiv:2305.13245), Figura 3 — calidad vs kv_heads / n_heads.
  • Shazeer, Fast Transformer Decoding: One Write-Head is All You Need (arXiv:1911.02150) — el paper original de MQA; reporta un golpe en calidad pequeño pero real en T5.

Siguiente: restaura GQA-2 y ejecuta lab/04-mqa-gqa.md para la ablación completa.