English · Español
Lab 01 — Multi-head attention¶
Objetivo: extender single-head attention a multi-head attention, fijar la API de la clase
MultiHeadAttentiony verificar que con \(H = 1\) la salida coincide exactamente con la del lab previo.Tiempo estimado: 90–120 minutos.
Requisito previo: lab 00 commiteado.
Qué produces¶
Un directorio experiments/15-multi-head/ que contiene:
mha.py— tu implementación multi-head en NumPy, importando desdesrc/minimodel/attention/attention.py.verify.py— script de verificación.verify_output.txt— salida capturada.heatmap.png— figura de 4 paneles: patrón de atención para cada una de las \(H = 4\) cabezas sobre una entrada fija.manifest.json.README.md.
Contexto¶
theory/03-multi-head.md cubre:
- La construcción split-and-stack.
- Conteo de parámetros: \(4 d_\text{model}^2\) frente a \(3 d_\text{model}^2\) del caso single-head.
- Por qué la proyección de salida \(W_O\) es esencial.
- El truco de implementación "una matriz grande por rol, reshape en tiempo de ejecución".
src/minimodel/attention/BLUEPRINT.md (¡léelo!) fija la API de la clase:
class MultiHeadAttention:
def __init__(self, d_model: int, n_heads: int, seed: int = 0) -> None: ...
def forward(self, x: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray: ...
La clase posee cuatro matrices de pesos: W_Q, W_K, W_V, W_O, cada una \(d_\text{model} \times d_\text{model}\).
TODOs¶
Bloque A — implementar la clase¶
- En
src/minimodel/attention/attention.py, implementaMultiHeadAttention. - En
__init__: reserva las cuatro matrices usandonp.random.default_rng(seed). Escala por1 / sqrt(d_model)(init de la Fase 10). - En
forward: - Calcula
Q = x @ self.W_Q, igual para K, V. Forma(T, d_model). - Haz reshape de cada una a
(T, n_heads, d_head), transpón a(n_heads, T, d_head). - Para cada cabeza independientemente (o con einsum / matmul por batches):
scores = Q_h @ K_h.T / sqrt(d_head)— forma(T, T).- Aplica la máscara si se proporciona (aditiva).
attn = softmax(scores).out_h = attn @ V_h— forma(T, d_head).
- Reshape y concatena las cabezas de vuelta a
(T, d_model). - Aplica
out @ self.W_O. Devuelve. - Apunta a ≤ 30 LOC en el forward. Usa
einsumsi ayuda a la legibilidad — pero un buclefor h in range(H)también vale por claridad.
Bloque B — verificar equivalencia con single-head¶
Con \(H = 1\), la clase multi-head debe comportarse exactamente como la función single-head del lab 00 (salvo la proyección de salida).
- Construye un
MultiHeadAttention(d_model=4, n_heads=1, seed=42). - Extrae manualmente sus
W_Q, W_K, W_V(forma(4, 4)) — llama asingle_head_attention(X @ W_Q, X @ W_K, X @ W_V)del lab 00. - Luego aplica
W_Oa ese resultado. - Compara contra
mha.forward(X). - Ambas deben coincidir con tolerancia 1e-5. Comprueba con assert.
Bloque C — explorar: especialización de cabezas¶
Para una entrada fija — la secuencia canónica de 8 tokens del corpus de gramática verbal <bos> I work , you work , he (usa la tokenización del lab 00 de la Fase 14):
- Construye
MultiHeadAttention(d_model=64, n_heads=4, seed=0). - Ejecuta el forward. Captura las matrices de atención
attn_hde cada cabeza (modifica el forward para que opcionalmente las devuelva, o guárdalas como efecto secundario). - Dibuja 4 heatmaps en una cuadrícula 2×2. Cada heatmap es \(T \times T\), con la probabilidad de atención como color (usa
viridis). - Anota los ejes con los tokens decodificados reales.
- Guárdalo como
heatmap.png.
Con pesos aleatorios, las cabezas son patrones aleatorios — eso está bien. Lo que importa es la forma, no la semántica. (Los patrones de atención entrenada aparecen en la Fase 18.)
Bloque D — redactar¶
En README.md:
- Confirma la equivalencia con single-head (Bloque B). Indica la diferencia máxima.
- Describe la forma de los heatmaps (Bloque C). ¿Son distinguibles las cuatro cabezas entre sí? Con pesos aleatorios, deberían verse distintas (aleatorio es distinto de aleatorio). Anota cualquier patrón que veas (probablemente ninguno — es el resultado nulo esperado para init aleatorio).
Bloque E — manifest¶
{
"experiment": "15-multi-head",
"date": "YYYY-MM-DD",
"seed": 42,
"versions": { "python": "3.11.x", "numpy": "X.Y.Z", "matplotlib": "X.Y.Z" },
"config": {
"d_model": 64,
"n_heads": 4,
"T": 8,
"input_snippet": "<bos> I work , you work , he"
},
"results_summary": {
"single_head_equivalence_max_diff": null,
"heads_visibly_distinct": null
}
}
Restricciones¶
- Sin PyTorch.
- Reshape, no bucles, siempre que sea posible. Ambas funcionan; reshape es más rápido y es lo que hace el código de producción. El bucle se permite por claridad en tu primera pasada.
- La máscara es
Noneen este lab. El lab 02 añade la máscara causal (causal mask).
Condiciones de parada¶
Hecho cuando:
- Los seis archivos están commiteados.
- La aserción de equivalencia con single-head pasa (
max_diff < 1e-5). - El heatmap muestra cuatro patrones visiblemente distintos (aunque sean sin estructura).
README.mdresponde a ambas preguntas del Bloque D.
Trampas¶
- El orden del reshape importa.
x.reshape(T, H, d_head)es distinto dex.reshape(T, d_head, H). El primero deja las features de cada cabeza adyacentes en memoria; el segundo no. Usa el primero. - Transpose para el matmul. Tras hacer reshape a
(T, H, d_head), necesitas(H, T, d_head)para el matmul por batches. Usa.transpose(1, 0, 2). - ¿Sesgo (bias) en
W_Q,W_K,W_V? La Fase 15 usa sin sesgo en estas proyecciones (estándar en transformers desde 2017). \(W_O\) tampoco lleva sesgo. Documéntalo enREADME.md. - No te olvides de \(W_O\). Concatenar las cabezas no es el final. La proyección de salida es esencial.
- Verificar con pesos aleatorios. Los patrones no van a parecerse a atención entrenada. Es lo esperado.
Cuándo consultar solutions/¶
Cuando los seis archivos estén commiteados y la aserción del Bloque B pase. Solución en solutions/01-multi-head-ref.md.
Siguiente lab: 02-causal-mask.md.