Skip to content

English · Español

Lab 03 — Causalidad por perturbación

Lee theory/01-transformer-block.md (§"Causal masking — still required, even with RoPE"). No consultes solutions/.

Objetivo

Verifica que la máscara causal de tu Mini-GPT está conectada correctamente desde la entrada hasta la salida: cambiar el token de entrada en la posición \(i\) debe modificar los logits de salida sólo en las posiciones \(j \ge i\). Las entradas en la posición \(i\) no deben filtrarse hacia atrás a las posiciones \(j < i\). Esta es la prueba de cordura que distingue un modelo de lenguaje autoregresivo real de uno bidireccional (estilo BERT). Si esto falla, el entrenamiento de la Fase 18 hará trampa silenciosamente.

Background

En un modelo de lenguaje autoregresivo, la predicción en la posición \(i\) debe depender sólo de las posiciones \(0, 1, \ldots, i\). La máscara causal lo impone asignando \(-\infty\) a las puntuaciones de attention desde la posición \(i\) a las posiciones \(j > i\) antes del softmax. La posición 0 sólo presta atención a la posición 0; la posición 7 presta atención a 0–7.

Prueba de cordura estándar: elige dos secuencias de entrada que difieran únicamente en la posición \(k\). Ejecuta ambas. La salida en cada posición \(j < k\) debe ser idéntica (se cumple la causalidad). La salida en cada posición \(j \ge k\) puede diferir (y casi seguro lo hará, salvo coincidencia).

Tareas

Tarea 1 — prueba de perturbación

En tests/test_phase17_causality.py:

def test_causal_mask_holds_end_to_end():
    model = MiniGPT(config)
    tokens_a = np.array([3, 1, 4, 1, 5, 9, 2, 6])  # arbitrary 8 tokens
    tokens_b = tokens_a.copy()
    perturb_at = 5
    tokens_b[perturb_at] = (tokens_a[perturb_at] + 7) % config.vocab_size  # change

    logits_a = model(tokens_a)
    logits_b = model(tokens_b)

    # Causality: positions 0..(perturb_at-1) must be identical.
    for j in range(perturb_at):
        assert np.allclose(logits_a[j], logits_b[j], atol=1e-8), \
            f"causality broken at position {j} (perturbed at {perturb_at})"

    # Positions perturb_at..T-1 should differ at least somewhere.
    differs = False
    for j in range(perturb_at, len(tokens_a)):
        if not np.allclose(logits_a[j], logits_b[j], atol=1e-8):
            differs = True
            break
    assert differs, "perturbation had no downstream effect — model not connected"

Es una prueba de un solo disparo. Ejecútala.

Tarea 2 — barrido por todas las posiciones

Repite la prueba para cada \(k\) desde 1 hasta \(T-1\). Cada pasada debe respetar la causalidad. Recoge los resultados en una tabla:

Posición de perturbación \(k\) ¿Posiciones 0..\(k-1\) idénticas? ¿Posiciones \(k\)..\(T-1\) difieren?
1 ✓ / ✗ ✓ / ✗
2 ... ...
...

Todas las filas deben ser ✓ ✓. Cualquier ✗ en la segunda columna es un bug fatal.

Tarea 3 — ¿qué falla cuando la máscara está mal?

Por motivos didácticos, desactiva temporalmente la máscara causal en tu implementación de MHA (no commitees este cambio). Vuelve a ejecutar la Tarea 1. Deberías ver que perturbar la posición 5 cambia la salida en las posiciones 0–4 — que es exactamente lo que hacen los modelos estilo BERT, y exactamente el bug que un modelo autoregresivo no debe tener.

Documenta qué ha cambiado. Reactiva la máscara. Este paso construye la intuición: la máscara está haciendo trabajo real.

Tarea 4 — verifica con una secuencia más larga

Ejecuta la prueba con longitud de secuencia \(T = 32\) (el context_len fijado). Perturba cada uno de \(k = 1, 8, 16, 31\). Confirma la causalidad en las cuatro posiciones.

Esto detecta bugs de RoPE que sólo aparecen con secuencias más largas (p.ej., una tabla RoPE que no se extiende más allá de 8).

Tarea 5 — prueba de causalidad sobre gradientes (a futuro)

Esta tarea es opcional — requiere que el autograd de la Fase 8 esté conectado a través del modelo. Si la saltas, documenta por qué.

Si puedes: calcula \(\partial \text{logits}_j / \partial \text{input\_embed}_i\) para unos cuantos pares \(i, j\) de muestra.

  • Para \(i > j\): el gradiente debe ser cero (causalidad sobre el gradiente, no sólo sobre el paso forward (forward pass)).
  • Para \(i \le j\): el gradiente debe ser generalmente distinto de cero.

La prueba de gradiente es más fuerte que la de perturbación, porque detecta filtraciones sutiles que la perturbación podría no ver por coincidencia.

Mediciones a capturar

  • Tabla del barrido (Tarea 2): todas las filas ✓ ✓.
  • Diferencia contrafáctica (Tarea 3): documenta un caso donde el modelo sin máscara filtró información hacia atrás.
  • Barrido con secuencia larga (Tarea 4): las cuatro posiciones de perturbación respetan la causalidad.
  • (Opcional) Resultado de la causalidad sobre gradientes, si conseguiste hacer funcionar la Tarea 5.

Guarda los resultados en experiments/<date>-phase-17-causality/manifest.json más un CSV del barrido.

Aceptación

  • test_phase17_causality.py existe y pasa.
  • Tabla del barrido rellena para todo \(k\) desde 1 hasta \(T-1\).
  • Tarea 3 documentada: el modelo sin máscara rompe la causalidad; el modelo con máscara no.
  • La prueba con \(T = 32\) pasa.
  • Las notas del lab incluyen un párrafo sobre por qué RoPE por sí solo no es suficiente para la causalidad.

Trampas a esperar

  • Atol demasiado laxo. Usa atol=1e-8 para el chequeo de "idénticos", no 1e-5. La aritmética en punto flotante da identidad bit-exacta para el prefijo cuando la máscara es correcta; si necesitas holgura, tienes un bug.
  • Máscara aplicada en la fase equivocada. Bug común: máscara aplicada después del softmax en lugar de antes. Después del softmax, los tokens futuros ya han contribuido a la función de partición — tu modelo "causal" filtra. Siempre pre-softmax: suma -inf a las puntuaciones, luego softmax.
  • Forma de la máscara. Forma (T, T) con mask[i, j] = -inf if j > i else 0. Hace broadcast contra puntuaciones (n_heads, T, T) limpiamente.
  • RoPE aplicado después de la máscara. El orden importa: proyectar a Q/K, aplicar RoPE, calcular puntuaciones, aplicar máscara, softmax, suma ponderada de V, proyección de salida. Si RoPE se aplica después de la máscara (algo que se ve en código antiguo), puedes corromper el enmascarado.

Siguiente: Fase 18 — Bucle de entrenamiento, vista previa de precisión mixta, checkpointing (después de /quiz 17 y /phase-report 17).