English · Español
Break — entrenar sin clip de gradiente; reproducir una spike de loss a propósito¶
Apagamos el clip de gradiente y dejamos que un token raro del corpus §A13 (
ttendewritten) genere una spike cuando aparece concentrado en un batch. Es la versión sintética del post-mortem del archivo de teoría 04 — ahora Borja la causa, no la observa.
Síntoma que Borja verá¶
Dos runs:
- Run A (control): umbral de
grad-clip= 1.0, batching por defecto, seed 42. - Run B (break): umbral de
grad-clip= \(\infty\) (efectivamente sin clipping; en código:clip = float("inf")), misma seed, mismo batching.
Hacia el paso ~312 del Run B, el panel de loss mostrará una spike vertical de ~2.3 a ~12+, y o bien:
- (60% probable) se recupera lentamente en 100-200 pasos, asentándose 0.5-1.0 por encima de la curva de loss del Run A, con una línea base de norma de gradiente permanentemente elevada;
- (40% probable) diverge a NaN en 5 pasos y nunca se recupera.
El panel de grad-norm mostrará una única spike aislada de 30-80× la línea base en ese paso.
La rotura, mecánicamente¶
En experiments/19-break-no-clip/config.yaml:
O en código: en src/minitrain/loop.py, cambia
por
Y ya está. Toda la rotura es quitar una red de seguridad.
Por qué esto enseña el concepto¶
A escala §A13, el tokenizador BPE (Fase 11) produce un vocabulario de tokens donde el verbo write y sus conjugaciones (writes, wrote, written, writing) se parten en secuencias multi-token. El token tten (de written) es raro — aparece unas 5 veces en las 240 frases del set de entrenamiento, sólo dentro de conjugaciones de write.
Cuando el batching estocástico mete por azar 3 frases que contienen written en el mismo batch de 8, la fila de embedding del token raro recibe una señal de gradiente proporcional a 3 instancias de "esta fila se equivocó por ~\(\ln V\) nats". El gradiente de una sola fila tiene norma de Frobenius \(\sim 50\), y la norma global de gradiente está dominada por esta única fila.
Sin clipping: el optimizador da un paso gigante sobre la fila de embedding de tten (y un paso menor-pero-aún-grande sobre cualquier otro parámetro, porque las estimaciones de momentos de AdamW son globales). El modelo se mueve a una parte del espacio de parámetros donde:
- El embedding de
ttense ha sobrepasado — los gradientes del siguiente batch sobrecorrigen, oscilando. - Las estimaciones de momentos \(v_t\) contienen ahora una spike que tarda \(\sim 1/(1 - \beta_2) = 20\) pasos en desvanecerse.
- Otros parámetros se han actualizado con
lr · m̂ / √v̂dondev̂es más pequeño de lo que debería para este batch (se actualizó con lag²más pequeña del batch anterior), así que sus updates son también demasiado agresivos.
Resultado: la spike no es un evento de un solo paso, sino una desestabilización multi-paso. La "recuperación" que muestra la curva de loss es en realidad el optimizador re-calibrando lentamente sus momentos tras una corrupción.
Esta es la versión a escala §A13 del modo de fallo que Chowdhery et al. (2022) describen para PaLM. Misma forma, números más pequeños.
Escalera diagnóstica que Borja debe recorrer¶
- Primera comprobación: el panel de loss. La spike está en el paso 312, nítida e inconfundible.
- Segunda comprobación: el panel de
grad-norm. La norma pre-clip en el paso 312 es ~50, línea base ~0.6. La norma post-clip es... también ~50 (porque no hay clip). Esta es la pistola humeante. - Tercera comprobación: el log de composición del batch en el paso 312 (la instrumentación de la Fase 19 lo incluye). Muestra 3 frases que contienen el verbo
writeen forma de pasado participio. - Cuarta comprobación: el histograma de loss por token en el paso 312. Hay una cola derecha pesada con masa concentrada en el token
tten. - Diagnóstico: token raro + batch concentrado + sin clip = desestabilización de un solo paso.
Reproductor¶
# Control
seed=42 grad_clip=1.0 just phase-19-train
# Break
seed=42 grad_clip=inf just phase-19-train
# Compara
just phase-19-compare experiments/19-control experiments/19-break-no-clip
Cascada de pistas¶
- (Suave) "Mira el panel de
grad-normcerca de la spike de loss. ¿Algo antes de ese panel insinúa la causa?" - (Media) "¿Cuál es la norma de gradiente post-clip en el paso 312? ¿Qué te dice sobre el umbral del clip?"
- (Directa) "El clip está desactivado. Con un token raro concentrado en un batch, ¿cuál es el impacto de un solo paso sobre las estimaciones de momentos del optimizador?"
Fix¶
Restaurar grad_clip = 1.0. O, para enseñar una lección complementaria, restaurar grad_clip = 0.5 y observar que el umbral algo más estrecho deja la media móvil de la norma de gradiente (0.6) justo por debajo del clip, así que la mayoría de pasos no se ven afectados, pero la spike del paso 312 sí queda contenida.
Cualquiera de los dos fixes demuestra: el gradient clipping (clip de gradiente) es la defensa barata contra este modo de fallo. El fix más profundo — batching estratificado para impedir la concentración de tokens raros — es la defensa correcta, pero requiere un cambio en el data-loader.
Lo que esta rotura NO es¶
- No es una rotura por overflow numérico (estamos en fp32 todo el rato).
- No es una rotura de init (el modelo arranca sano, la spike ocurre en el paso 312, no en el paso 0).
- No es una rotura de schedule de LR (el LR es un coseno suave).
Es una rotura de defensas-retiradas, y enseña que el grad-clip no es opcional — es el seguro barato que permite al optimizador sobrevivir a una concentración probabilística de tokens de cola larga en un único batch.
Referencias cruzadas¶
theory/04-loss-spike-postmortem-template.md— el ejemplo trabajado coincide con esta rotura.stability-check.md§2 — el árbol de decisión para detectar spikes.- Fase 18
theory/02-optimizer-and-schedule.md— la matemática del clipping de gradiente.