Skip to content

English · Español

Lab 04 — MQA / GQA y la ganancia en KV cache

Objetivo: implementar multi-query attention (MQA) y grouped-query attention (GQA) como variantes drop-in de la attention de MiniGPT, y medir la reducción de tamaño del KV cache sobre la secuencia de 64 tokens del corpus de verbos.

Tiempo estimado: 4–6 horas.

Prerrequisito: theory 04 leída; src/minimodel/attention.py (Fase 15) familiar; KV cache de la Fase 22 comprendido.


Lo que produces

src/minimodel/attention_mqa_gqa.py — el módulo de attention variante.

experiments/27-mqa-gqa-kv-cache/ que contenga:

  • bench.py — ejecuta el forward de MHA / GQA / MQA sobre la misma secuencia de 64 tokens del corpus de verbos y mide: (a) bytes de KV cache por token generado, (b) deriva del PPL de salida vs la referencia MHA.
  • results.json — medidas.
  • kv_bytes.png — gráfico de barras, bytes por token para cada variante.
  • manifest.json.
  • README.md — interpretación.

Tests en tests/test_minimodel_mqa_gqa.py (Claude scaffoldea los que fallan).

La matemática (recapitulación de theory 04)

La multi-head attention (MHA) estándar tiene n_heads proyecciones independientes de Q, K, V. KV cache por token = 2 × n_heads × d_head × bytes_per_element.

Multi-query attention (MQA): todas las cabezas comparten un K y un V. Cache = 2 × 1 × d_head × bytes_per_element. Reducción n_heads×.

Grouped-query attention (GQA): las cabezas se reparten en n_kv_heads grupos; cada grupo comparte K y V. Cache = 2 × n_kv_heads × d_head × bytes_per_element. Reducción n_heads / n_kv_heads×.

Para MiniGPT con n_heads = 4: - Cache MHA por token = 2 × 4 × d_head × 2 bytes (fp16) = 16 × d_head bytes. - GQA con n_kv_heads = 2: 2 × 2 × d_head × 2 = 8 × d_head bytes. Reducción 2×. - MQA: 2 × 1 × d_head × 2 = 4 × d_head bytes. Reducción 4×.

Para d_head = 16 (típico de MiniGPT), el ahorro de cache por token es: 256 → 128 → 64 bytes. Sobre una secuencia de 64 tokens, cache total: 16 KiB → 8 KiB → 4 KiB. Números absolutos pequeños, pero el ratio es lo que generaliza a producción.

TODOs

Bloque A — diseñar el módulo

  • class AttentionVariant(nn.Module) acepta un objeto de configuración con d_model, n_heads, n_kv_heads, d_head, dropout, causal.
  • Cuando n_kv_heads == n_heads, se comporta idéntico a MHA (Fase 15). Cuando n_kv_heads == 1, es MQA. En otros casos, GQA.
  • La proyección Q produce n_heads × d_head por token; las proyecciones K y V producen n_kv_heads × d_head por token.
  • Attention: cada cabeza Q consulta el K, V del grupo correspondiente (broadcast / repeat-interleave de las cabezas KV para casar con las cabezas Q dentro del matmul). Mantente numéricamente equivalente a MHA cuando n_kv_heads == n_heads.

Bloque B — implementar y testear

  • Forma del forward: (B, T, d_model) → (B, T, d_model). Soporte de máscara causal.
  • Test: con n_kv_heads = n_heads, la salida coincide con src/minimodel/attention.py (Fase 15) a 1e-5.
  • Test: con n_kv_heads = 1, el número de parámetros baja en (2 × (n_heads - 1) × d_head × d_model) (el ahorro en pesos de las proyecciones K y V).
  • Test: el tensor de KV cache asignado por el módulo tiene la forma esperada (B, T, n_kv_heads, d_head) para K y la misma para V.

Bloque C — medir sobre la secuencia del corpus de verbos

  • Carga MiniGPT (checkpoint de la Fase 17) más una frase held-out del corpus de verbos de longitud 64 tokens.
  • Construye tres variantes de MiniGPT, intercambiando sólo el módulo de attention: MHA / GQA (n_kv_heads = 2) / MQA (n_kv_heads = 1).
  • Para cada una: ejecuta forward; mide los bytes totales del KV cache; calcula el PPL de salida sobre la distribución del siguiente token.
  • Nota: el PPL probablemente se degradará bajo MQA/GQA porque el compartir cabezas cambia la capacidad representacional del modelo. Es esperado. La Fase 28 (LoRA) es donde re-tunearías las variantes para recuperar; este lab sólo mide el trade-off.

Bloque D — interpretar en README.md

Cuatro preguntas:

  1. ¿Qué reducción de tamaño del KV cache mediste para MQA vs MHA? Espera ≈ n_heads× (4× para n_heads=4).
  2. ¿Cuál es la diferencia de PPL MQA vs MHA en el eval del corpus de verbos? Espera una degradación de 5–20% si el modelo no se entrenó pensando en MQA (que es el caso aquí — estamos intercambiando en tiempo de inferencia).
  3. ¿A qué escala de producción se vuelve MQA un no-brainer? Calcula el KV cache para un modelo de 32 capas / 32 cabezas / 128-d_head con longitud de secuencia 8192. Compara los bytes totales de cache MHA vs MQA. El número debería estar en GiB; esa es la respuesta.
  4. ¿Cuál es la relación entre MQA y flash attention? ¿Son ortogonales? ¿Componibles? (Pista: sí a ambas — MQA reduce los bytes K/V, flash reduce la matriz S materializada. Cuellos de botella diferentes.)

Restricciones

  • Sin reentrenar. Es un intercambio estructural en tiempo de inferencia. Mide la degradación con honestidad.
  • CPU vale. La secuencia de verbos es corta; no hace falta GPU para este lab. (Los labs GPU de la Fase 27 son el 01 y el 02.)
  • El mismo checkpoint de MiniGPT para las tres variantes — sólo cambia el módulo de attention.

🇪🇸 MQA/GQA no son aceleraciones gratis — comparten K/V entre cabezas y eso reduce capacidad. La elección real es: ¿cuánto PPL puedes ceder a cambio de cuánto KV cache? En modelos pequeños tipo MiniGPT, MQA es agresivo. En 70B+, casi todos los modelos modernos usan GQA.

Condiciones de parada

  • src/minimodel/attention_mqa_gqa.py implementado; los tests pasan.
  • Los cinco archivos del experimento commiteados.
  • La reducción del KV cache coincide con la fórmula teórica dentro del 1%.
  • README responde las cuatro preguntas.

Errores típicos

  • Olvidar repeat-interleave de KV antes del matmul. Si Q tiene 4 cabezas y KV tiene 1 cabeza (MQA), necesitas expandir KV a 4 cabezas (vía K.repeat_interleave(4, dim=heads_dim) o un einsum equivalente) antes del producto escalar. De lo contrario, las formas no casan.
  • El PPL no cambió nada. Probablemente no intercambiaste realmente la attention — comprueba que id(model.layers[0].attn) es distinto del de la referencia MHA.
  • El descenso del recuento de parámetros no casa. Quizá dejaste los biases KV en FP16 al mismo tamaño; ignora biases o calcula el recuento sólo de pesos.

Cuándo consultar solutions/

Tras cumplir todas las condiciones de parada. solutions/04-mqa-gqa-ref.md (apertura de fase) recorre la lógica del repeat-interleave.


Fin de los labs de la Fase 27. Escribe PHASE_27_REPORT.md a continuación.