English · Español
Lab 02 — Máscara causal (causal mask)¶
Objetivo: añadir enmascarado causal al forward de
MultiHeadAttention; verificar mediante perturbación que la salida en la posición \(i\) NO depende de la entrada en la posición \(i+1\).Tiempo estimado: 45–60 minutos.
Requisito previo: labs 00 y 01 commiteados;
theory/04-masking.mdleído.
Qué produces¶
Un directorio experiments/15-causal-mask/ que contiene:
mask.py— helper de máscara causal (causal mask).verify.py— script de test de perturbación.verify_output.txt— salida capturada.mask_visual.png— heatmap de la máscara causal + una matriz de atención con la máscara aplicada (lado a lado).manifest.json.README.md.
Contexto¶
La máscara causal: M[i, j] = 0 if j <= i else -inf. Aplicada aditivamente antes del softmax. El archivo de teoría 04 es la referencia.
El test de perturbación es la forma estándar de verificar una máscara causal en la práctica. Es mucho más convincente que leer el código:
- Ejecuta el modelo sobre la entrada \(X\), captura la salida \(Y\).
- Ejecuta el modelo sobre \(X'\) donde el último token difiere, captura la salida \(Y'\).
- Comprueba que \(Y[0..T-1]\) coincide con \(Y'[0..T-1]\) para todas las posiciones anteriores a la última.
Si la máscara funciona, la perturbación no puede propagarse hacia atrás en el tiempo. Si las posiciones anteriores a \(T-1\) difieren entre \(Y\) y \(Y'\), la máscara está rota.
TODOs¶
Bloque A — implementar el helper de la máscara¶
- En
src/minimodel/attention/attention.py, añadecausal_mask(T: int, dtype=np.float32) -> np.ndarray. - Devuelve una matriz \(T \times T\) con ceros en/por debajo de la diagonal y
-1e9por encima. - Usa
np.triu(np.ones(...), k=1) * -1e9. Una sola línea en el cuerpo. - Test unitario: para \(T = 4\), la máscara debe verse así
Bloque B — conectarla al forward¶
- Actualiza
MultiHeadAttention.forward(x, mask=None): - Si se pasa
mask, súmala ascoresantes del softmax. - Contrato de forma:
maskes(T, T)y se difunde por la dimensión de cabezas. - Re-ejecuta el lab 01 con
mask=Nonepara confirmar que no hay regresión (las aserciones del lab 01 siguen pasando).
Bloque C — el test de perturbación¶
En verify.py:
- Construye
mha = MultiHeadAttention(d_model=16, n_heads=2, seed=0). - Construye \(X\) de forma
(T=8, 16)con valores aleatorios semilla fija. - Construye \(X'\) =
X.copy(), luego ponX'[7] = nuevo vector aleatorio(semilla distinta). - Construye
mask = causal_mask(8). - Calcula
Y = mha.forward(X, mask=mask)yY' = mha.forward(X', mask=mask). - Comprueba
np.allclose(Y[0:7], Y'[0:7], atol=1e-6). (Las posiciones 0..6 deben ser idénticas entre Y e Y' — el cambio en la posición 7 no puede propagarse hacia atrás.) - Comprueba
not np.allclose(Y[7], Y'[7], atol=1e-3). (La posición 7 debe cambiar, ya que su propia entrada cambió.) - Imprime pass/fail y la diferencia máxima por posición. Captura a
verify_output.txt.
Bloque D — el modo de fallo (rotura intencional)¶
Verifica tu comprensión rompiendo la máscara intencionadamente y mirando cómo el test falla:
- Haz una copia del test donde la máscara se aplique multiplicativamente después del softmax (la forma incorrecta; ver
theory/04-masking.md). Es decir, calcula la atención como - Re-ejecuta el test de perturbación sobre esta versión rota.
- Confirma que FALLA — las posiciones anteriores de \(Y'\) ahora difieren de \(Y\) porque el gradiente/salida se ha filtrado a través de la normalización del softmax.
- Captura también esta salida. Anótalo en
README.md.
Bloque E — visualizar¶
- Dos subplots lado a lado:
- Izquierda: la propia máscara causal (visualiza 0 como blanco y -1e9 como negro).
- Derecha: la matriz de atención \(A\) después de aplicar la máscara a una matriz de scores aleatoria (visualiza 0 como blanco y 1 como oscuro).
- Ambas deben tener forma triangular inferior, y las filas de la derecha deben sumar 1 (porque se aplicó el softmax).
- Guárdalo como
mask_visual.png.
Bloque F — redactar¶
En README.md, responde:
- ¿Por qué falla el enmascarado multiplicativo post-softmax en el test de perturbación? Respuesta de dos frases; referénciate a
theory/04-masking.md§"The critical mistake". - ¿Por qué la suma por fila de la matriz de atención es exactamente 1, incluso después del enmascarado? (Pista: el softmax normaliza lo que sobreviva. Las posiciones prohibidas reciben exactamente 0, así que las probabilidades restantes suman 1.)
Bloque G — manifest¶
{
"experiment": "15-causal-mask",
"date": "YYYY-MM-DD",
"seed": 0,
"versions": { "python": "3.11.x", "numpy": "X.Y.Z" },
"config": {
"T": 8,
"d_model": 16,
"n_heads": 2
},
"results_summary": {
"correct_perturbation_max_diff_positions_0_6": null,
"broken_perturbation_max_diff_positions_0_6": null
}
}
La diferencia máxima en la versión correcta debe ser < 1e-6. La de la versión rota debe ser > 1e-3 (rota).
Restricciones¶
- Sin PyTorch.
-1e9, no-np.inf. Algunas reducciones de numpy se sorprenden con inf; un valor finito muy negativo es más seguro.- El test debe ser determinista. Pon semilla a
mha, semilla a \(X\) y a \(X'\).
Condiciones de parada¶
Hecho cuando:
- Los seis archivos están commiteados.
- El test de perturbación de la versión correcta pasa (diferencia máxima < 1e-6 en las posiciones 0..6).
- El test de perturbación de la versión rota demuestra el fallo (diferencia máxima > 1e-3).
README.mdresponde a ambas preguntas del Bloque F.
Trampas¶
- Off-by-one. La posición \(i\) atiende a las posiciones \(0, \ldots, i\) inclusive. Si usaras
k=0ennp.triuanularías la diagonal — mal (entonces la posición \(i\) no podría ni siquiera atenderse a sí misma, solo al pasado). - Desajuste de forma.
maskes(T, T). Debe difundirse por la dimensión de cabezas descores(que es(H, T, T)). NumPy lo hace automáticamente. - La máscara también se usa en inferencia. El enmascarado causal es necesario tanto en entrenamiento como en inferencia para un decodificador autorregresivo. No la desactives en inferencia.
Cuándo consultar solutions/¶
Cuando los seis archivos estén commiteados y ambos tests (correcto y de fallo) se comporten como se espera. Solución en solutions/02-causal-mask-ref.md.
Siguiente lab: 03-attention-perf.md.