Skip to content

English · Español

Lab 01 — Multi-head attention

Objetivo: extender single-head attention a multi-head attention, fijar la API de la clase MultiHeadAttention y verificar que con \(H = 1\) la salida coincide exactamente con la del lab previo.

Tiempo estimado: 90–120 minutos.

Requisito previo: lab 00 commiteado.


Qué produces

Un directorio experiments/15-multi-head/ que contiene:

  • mha.py — tu implementación multi-head en NumPy, importando desde src/minimodel/attention/attention.py.
  • verify.py — script de verificación.
  • verify_output.txt — salida capturada.
  • heatmap.png — figura de 4 paneles: patrón de atención para cada una de las \(H = 4\) cabezas sobre una entrada fija.
  • manifest.json.
  • README.md.

Contexto

theory/03-multi-head.md cubre: - La construcción split-and-stack. - Conteo de parámetros: \(4 d_\text{model}^2\) frente a \(3 d_\text{model}^2\) del caso single-head. - Por qué la proyección de salida \(W_O\) es esencial. - El truco de implementación "una matriz grande por rol, reshape en tiempo de ejecución".

src/minimodel/attention/BLUEPRINT.md (¡léelo!) fija la API de la clase:

class MultiHeadAttention:
    def __init__(self, d_model: int, n_heads: int, seed: int = 0) -> None: ...
    def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray: ...

La clase posee cuatro matrices de pesos: W_Q, W_K, W_V, W_O, cada una \(d_\text{model} \times d_\text{model}\).

TODOs

Bloque A — implementar la clase

  • En src/minimodel/attention/attention.py, implementa MultiHeadAttention.
  • En __init__: reserva las cuatro matrices usando np.random.default_rng(seed). Escala por 1 / sqrt(d_model) (init de la Fase 10).
  • En forward:
  • Calcula Q = x @ self.W_Q, igual para K, V. Forma (T, d_model).
  • Haz reshape de cada una a (T, n_heads, d_head), transpón a (n_heads, T, d_head).
  • Para cada cabeza independientemente (o con einsum / matmul por batches):
    • scores = Q_h @ K_h.T / sqrt(d_head) — forma (T, T).
    • Aplica la máscara si se proporciona (aditiva).
    • attn = softmax(scores).
    • out_h = attn @ V_h — forma (T, d_head).
  • Reshape y concatena las cabezas de vuelta a (T, d_model).
  • Aplica out @ self.W_O. Devuelve.
  • Apunta a ≤ 30 LOC en el forward. Usa einsum si ayuda a la legibilidad — pero un bucle for h in range(H) también vale por claridad.

Bloque B — verificar equivalencia con single-head

Con \(H = 1\), la clase multi-head debe comportarse exactamente como la función single-head del lab 00 (salvo la proyección de salida).

  • Construye un MultiHeadAttention(d_model=4, n_heads=1, seed=42).
  • Extrae manualmente sus W_Q, W_K, W_V (forma (4, 4)) — llama a single_head_attention(X @ W_Q, X @ W_K, X @ W_V) del lab 00.
  • Luego aplica W_O a ese resultado.
  • Compara contra mha.forward(X).
  • Ambas deben coincidir con tolerancia 1e-5. Comprueba con assert.

Bloque C — explorar: especialización de cabezas

Para una entrada fija — la secuencia canónica de 8 tokens del corpus de gramática verbal <bos> I work , you work , he (usa la tokenización del lab 00 de la Fase 14):

  • Construye MultiHeadAttention(d_model=64, n_heads=4, seed=0).
  • Ejecuta el forward. Captura las matrices de atención attn_h de cada cabeza (modifica el forward para que opcionalmente las devuelva, o guárdalas como efecto secundario).
  • Dibuja 4 heatmaps en una cuadrícula 2×2. Cada heatmap es \(T \times T\), con la probabilidad de atención como color (usa viridis).
  • Anota los ejes con los tokens decodificados reales.
  • Guárdalo como heatmap.png.

Con pesos aleatorios, las cabezas son patrones aleatorios — eso está bien. Lo que importa es la forma, no la semántica. (Los patrones de atención entrenada aparecen en la Fase 18.)

Bloque D — redactar

En README.md:

  1. Confirma la equivalencia con single-head (Bloque B). Indica la diferencia máxima.
  2. Describe la forma de los heatmaps (Bloque C). ¿Son distinguibles las cuatro cabezas entre sí? Con pesos aleatorios, deberían verse distintas (aleatorio es distinto de aleatorio). Anota cualquier patrón que veas (probablemente ninguno — es el resultado nulo esperado para init aleatorio).

Bloque E — manifest

{
  "experiment": "15-multi-head",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "versions": { "python": "3.11.x", "numpy": "X.Y.Z", "matplotlib": "X.Y.Z" },
  "config": {
    "d_model": 64,
    "n_heads": 4,
    "T": 8,
    "input_snippet": "<bos> I work , you work , he"
  },
  "results_summary": {
    "single_head_equivalence_max_diff": null,
    "heads_visibly_distinct": null
  }
}

Restricciones

  • Sin PyTorch.
  • Reshape, no bucles, siempre que sea posible. Ambas funcionan; reshape es más rápido y es lo que hace el código de producción. El bucle se permite por claridad en tu primera pasada.
  • La máscara es None en este lab. El lab 02 añade la máscara causal (causal mask).

Condiciones de parada

Hecho cuando:

  1. Los seis archivos están commiteados.
  2. La aserción de equivalencia con single-head pasa (max_diff < 1e-5).
  3. El heatmap muestra cuatro patrones visiblemente distintos (aunque sean sin estructura).
  4. README.md responde a ambas preguntas del Bloque D.

Trampas

  • El orden del reshape importa. x.reshape(T, H, d_head) es distinto de x.reshape(T, d_head, H). El primero deja las features de cada cabeza adyacentes en memoria; el segundo no. Usa el primero.
  • Transpose para el matmul. Tras hacer reshape a (T, H, d_head), necesitas (H, T, d_head) para el matmul por batches. Usa .transpose(1, 0, 2).
  • ¿Sesgo (bias) en W_Q, W_K, W_V? La Fase 15 usa sin sesgo en estas proyecciones (estándar en transformers desde 2017). \(W_O\) tampoco lleva sesgo. Documéntalo en README.md.
  • No te olvides de \(W_O\). Concatenar las cabezas no es el final. La proyección de salida es esencial.
  • Verificar con pesos aleatorios. Los patrones no van a parecerse a atención entrenada. Es lo esperado.

Cuándo consultar solutions/

Cuando los seis archivos estén commiteados y la aserción del Bloque B pase. Solución en solutions/01-multi-head-ref.md.


Siguiente lab: 02-causal-mask.md.