English · Español
Break 00 — Quitar el escalado sqrt(d_k) en la atención¶
🇪🇸 Quitamos el
/ sqrt(d_k)del cálculo de los logits de atención. Ad_kpequeño no pasa nada visible. Ad_k = 64la varianza de los logits crece × 64, el softmax satura, los gradientes mueren. Demostramos el efecto con el §A13 yd_k ∈ {4, 16, 64}.Anchors:
LYNX_CORTEX.md§4 / PHASE 15; theory §02 scaled dot-product; theory §05 worked length-4;.claude/commands/break.md.
La rotura¶
En src/minimodel/nn/attention.py:
class ScaledDotProductAttention(Module):
def forward(self, Q: Tensor, K: Tensor, V: Tensor, mask: Tensor | None = None) -> Tensor:
d_k = Q.shape[-1]
# BUG: removed the /sqrt(d_k) scaling.
scores = Q @ K.transpose(-2, -1)
# was: scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores + mask
attn = scores.softmax(dim=-1)
return attn @ V
Edición de una sola línea.
Predice, luego ejecuta¶
Los logits pre-softmax tienen varianza proporcional a d_k. Con Q, K ~ N(0, 1) (init), cada entrada de Q K^T es una suma de d_k productos de normales estándar, por lo que:
Tras el softmax con logits de varianza d_k, la distribución se hace más afilada por un factor del orden de exp(sqrt(d_k)). Cuantitativamente, para un logit grande z_max y d_k - 1 más pequeños, el softmax se concentra hasta 1 - O(exp(-z_max)). Con d_k = 64, z_max ≈ sqrt(64) = 8, así que exp(-8) ≈ 3e-4 — la fila del softmax es [0.9997, ε, ε, ..., ε] — saturada.
Un softmax saturado tiene gradiente casi nulo: ∂softmax_i/∂z_j ≈ 0 para todo i ≠ j. El bloque de atención entero deja de aprender.
Predicciones¶
d_k = 4: sutil. La atención más afilada que con escalado pero la pérdida (loss) converge.d_k = 16: notable. La curva de loss se aplana 30-50% antes.d_k = 64: catastrófico. La loss apenas se mueve; la atención es one-hot desde el paso 1.- Inspeccionando
attn(la salida del softmax): las filas se ven como[1.0, 0, 0, ...]en lugar del esperado[0.25, 0.30, 0.20, 0.25].
Escribe las predicciones en learners/borja/phase-15/notes/breaks.md antes de ejecutar.
Observa¶
Diagnósticos:
- Dibuja
attn.max(axis=-1)por fila a lo largo de los pasos de entrenamiento — debería estar cerca de 1 inmediatamente si está roto. - Dibuja la curva de loss para
d_k ∈ {4, 16, 64}con y sin escalado. - Calcula la norma del gradiente en la proyección QKV en el paso 1 — debería ser ~0 en el caso roto con
d_k = 64.
Síntoma que verá Borja¶
d_k = 4: el entrenamiento funciona (por eso el bug es sutil en modelos pequeños).d_k = 16: el entrenamiento funciona pero converge un 30% más lento.d_k = 64: la loss es plana.attn.max(axis=-1).min() > 0.99(saturada).- Gradiente a través del bloque de atención: se desvanece.
Causa oculta (en una frase)¶
El producto Q K^T tiene varianza proporcional a d_k; sin la normalización /sqrt(d_k), el softmax satura para cualquier d_k ≥ ~16 y los gradientes se desvanecen.
Cascada de pistas¶
- Imprime
attn.max(axis=-1).mean()durante el entrenamiento. En init debería ser ~1/Tpara una ejecución sana. - Calcula la varianza empírica de los logits pre-softmax. Compárala con
d_k. ¿Qué dice Vaswani et al. 2017 §3.2.1 que debería pasar? - Mira el
forwardenattention.py. ¿Está el matmul dividido porsqrt(d_k)?
Diff del arreglo¶
Por qué esto enseña el concepto¶
Vaswani et al. 2017 §3.2.1 deriva el factor de escalado precisamente para controlar la saturación del softmax a medida que d_k crece. El bug es el bug sutil de atención más habitual porque pasa los tests con d_k=8 y se rompe con d_k=64. La rotura deja claro que sqrt(d_k) no es un detalle estético — es una pieza estructural de preservación de varianza que se vuelve esencial a cualquier tamaño realista de modelo. El mini-GPT de la Fase 17 usa d_k = 64 por cabeza; la Flash attention de la Fase 27 preserva este escalado dentro de su algoritmo por tiles.
Siguiente: el /break de la Fase 16 sobre codificaciones posicionales barajadas.