Skip to content

English · Español

Break 00 — Quitar el escalado sqrt(d_k) en la atención

🇪🇸 Quitamos el / sqrt(d_k) del cálculo de los logits de atención. A d_k pequeño no pasa nada visible. A d_k = 64 la varianza de los logits crece × 64, el softmax satura, los gradientes mueren. Demostramos el efecto con el §A13 y d_k ∈ {4, 16, 64}.

Anchors: LYNX_CORTEX.md §4 / PHASE 15; theory §02 scaled dot-product; theory §05 worked length-4; .claude/commands/break.md.


La rotura

En src/minimodel/nn/attention.py:

class ScaledDotProductAttention(Module):
    def forward(self, Q: Tensor, K: Tensor, V: Tensor, mask: Tensor | None = None) -> Tensor:
        d_k = Q.shape[-1]
        # BUG: removed the /sqrt(d_k) scaling.
        scores = Q @ K.transpose(-2, -1)
        # was: scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores + mask
        attn = scores.softmax(dim=-1)
        return attn @ V

Edición de una sola línea.

Predice, luego ejecuta

Los logits pre-softmax tienen varianza proporcional a d_k. Con Q, K ~ N(0, 1) (init), cada entrada de Q K^T es una suma de d_k productos de normales estándar, por lo que:

\[ \mathrm{Var}((Q K^T)_{ij}) = d_k \]

Tras el softmax con logits de varianza d_k, la distribución se hace más afilada por un factor del orden de exp(sqrt(d_k)). Cuantitativamente, para un logit grande z_max y d_k - 1 más pequeños, el softmax se concentra hasta 1 - O(exp(-z_max)). Con d_k = 64, z_max ≈ sqrt(64) = 8, así que exp(-8) ≈ 3e-4 — la fila del softmax es [0.9997, ε, ε, ..., ε]saturada.

Un softmax saturado tiene gradiente casi nulo: ∂softmax_i/∂z_j ≈ 0 para todo i ≠ j. El bloque de atención entero deja de aprender.

Predicciones

  • d_k = 4: sutil. La atención más afilada que con escalado pero la pérdida (loss) converge.
  • d_k = 16: notable. La curva de loss se aplana 30-50% antes.
  • d_k = 64: catastrófico. La loss apenas se mueve; la atención es one-hot desde el paso 1.
  • Inspeccionando attn (la salida del softmax): las filas se ven como [1.0, 0, 0, ...] en lugar del esperado [0.25, 0.30, 0.20, 0.25].

Escribe las predicciones en learners/borja/phase-15/notes/breaks.md antes de ejecutar.

Observa

just exp 15-train-attn --tag broken-no-scaling --d-k 64

Diagnósticos:

  1. Dibuja attn.max(axis=-1) por fila a lo largo de los pasos de entrenamiento — debería estar cerca de 1 inmediatamente si está roto.
  2. Dibuja la curva de loss para d_k ∈ {4, 16, 64} con y sin escalado.
  3. Calcula la norma del gradiente en la proyección QKV en el paso 1 — debería ser ~0 en el caso roto con d_k = 64.

Síntoma que verá Borja

  • d_k = 4: el entrenamiento funciona (por eso el bug es sutil en modelos pequeños).
  • d_k = 16: el entrenamiento funciona pero converge un 30% más lento.
  • d_k = 64: la loss es plana. attn.max(axis=-1).min() > 0.99 (saturada).
  • Gradiente a través del bloque de atención: se desvanece.

Causa oculta (en una frase)

El producto Q K^T tiene varianza proporcional a d_k; sin la normalización /sqrt(d_k), el softmax satura para cualquier d_k ≥ ~16 y los gradientes se desvanecen.

Cascada de pistas

  1. Imprime attn.max(axis=-1).mean() durante el entrenamiento. En init debería ser ~1/T para una ejecución sana.
  2. Calcula la varianza empírica de los logits pre-softmax. Compárala con d_k. ¿Qué dice Vaswani et al. 2017 §3.2.1 que debería pasar?
  3. Mira el forward en attention.py. ¿Está el matmul dividido por sqrt(d_k)?

Diff del arreglo

scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)

Por qué esto enseña el concepto

Vaswani et al. 2017 §3.2.1 deriva el factor de escalado precisamente para controlar la saturación del softmax a medida que d_k crece. El bug es el bug sutil de atención más habitual porque pasa los tests con d_k=8 y se rompe con d_k=64. La rotura deja claro que sqrt(d_k) no es un detalle estético — es una pieza estructural de preservación de varianza que se vuelve esencial a cualquier tamaño realista de modelo. El mini-GPT de la Fase 17 usa d_k = 64 por cabeza; la Flash attention de la Fase 27 preserva este escalado dentro de su algoritmo por tiles.


Siguiente: el /break de la Fase 16 sobre codificaciones posicionales barajadas.