Skip to content

English · Español

04 — Estabilidad del Entrenamiento a Escala

🇪🇸 A escala, los modelos no son estables por defecto. La curva de loss se vuelve un campo minado: spikes, divergencias, NaNs por Adam β₂. Los recursos que se gastan diagnosticando un solo spike a las 14 horas pueden costar más que toda la corrida de X1. Conocer el repertorio de respuesta (skip-batch, μP, gradient clip, restart-from-prior-ckpt) es la diferencia entre éxito y un cráter de $50k.

Un modelo de 50M parámetros rara vez spikee. Un modelo de 7B+ spikee regularmente, y a $10k+/h de tiempo de cluster (cluster), cada minuto de spike no reconocido es dinero real. El playbook de laboratorios frontera para estabilidad es prevención + detección + recuperación, cada uno con sus propios mecanismos.

Qué es un loss spike

Un loss spike es un salto súbito de 2-100× en el loss de entrenamiento sobre <100 steps, a menudo precedido por un spike de norma del gradiente. Tres trayectorias de la literatura canónica:

  1. Recuperable. El loss sube de \(L=2.1\) a \(L=4.5\) sobre 20 steps, luego decae de vuelta a \(L=2.1\) sobre los siguientes 500 steps. Se reanuda la trayectoria pre-spike. Coste: ~500 steps de cómputo desperdiciado.
  2. Persistente. El loss sube, se recupera parcialmente pero se estabiliza más alto que el pre-spike. Daño permanente; el modelo nunca alcanza. Coste: re-lanzar desde un checkpoint previo.
  3. Divergente. El loss sube y nunca se recupera. El loss continúa subiendo u oscila salvajemente. Aparecen NaNs. Coste: hard restart, a veces re-diseño.

Empíricamente: ~80% de los spikes son tipo 1 (recuperables), ~15% tipo 2, ~5% tipo 3 — pero esto depende mucho de arquitectura y config.

El mecanismo dominante: gradientes de gran magnitud por datos raros

La causa raíz más común del spike:

  1. Un batch raro contiene tokens o secuencias-de-tokens extremadamente sub-representadas en el entrenamiento hasta el momento.
  2. El modelo asigna probabilidad muy baja a esos tokens.
  3. El loss de cross-entropy para esos tokens es grande (p. ej. \(-\log(10^{-8}) = 18.4\)).
  4. El gradiente retropropagado es correspondientemente grande.
  5. El optimizador (Adam) da un paso grande.
  6. El paso empuja al modelo a una región del espacio de parámetros donde las activaciones explotan o las probs de atención (attention) se vuelven degeneradas.
  7. El loss del siguiente batch es enorme → spike.

Este mecanismo es load-bearing sobre β₂ de Adam. El segundo momento \(v\) de Adam promedia \(\hat{g}^2\) con tasa \(1-β_2\) (default 0.999, así que \(v\) tiene ventana de promediado efectiva de 1000 steps). Cuando llega un gradiente 100×, \(v\) no se actualiza lo suficientemente rápido para amortiguar el paso. El siguiente paso también es grande porque \(v\) todavía está stale.

La solución (PaLM, Chowdhery 2022): bajar β₂ a 0.95. Ahora \(v\) se adapta en ~20 steps, amortiguando gradientes anómalos mucho más rápido. Estándar para entrenamiento frontera post-2022.

μP (Maximal Update Parametrization)

Yang & Hu 2021 ("μTransfer") proponen una re-parametrización en la que el learning rate óptimo (y otros HPs) es invariante al ancho del modelo. El mecanismo:

  • Inicializa cada capa (layer) con varianza \(\sigma^2 \propto 1/\text{fan\_in}\) (estándar).
  • Escala la salida forward de cada capa por \(1/\sqrt{\text{fan\_in}}\) en lugar de dejar que la varianza crezca.
  • Usa un learning rate que escala como \(1/\text{ancho}\) para matrices de pesos.

Por qué ayuda a la estabilidad. La parametrización estándar tiene la propiedad de que, al escalar el ancho del modelo, las magnitudes de activación derivan. El learning rate "correcto" a 100M es demasiado pequeño a 1B y demasiado grande a 70B. μP fija la magnitud de activación a través de anchos.

El truco de transferencia de HP (Yang 2022): afina el learning rate sobre un modelo pequeño (p. ej. 100M), luego aplica el mismo LR a 70B bajo μP. Cuesta ~\(1k de búsqueda de HP de modelo pequeño en vez de ~\)100k de búsqueda de HP de modelo grande. Este es el uso industrial principal de μP — Cerebras, Eleuther, OLMo todos lo usan.

Para X1 no implementamos μP — una sola corrida de 50M no se beneficia. Pero deberías ser capaz de leer un fichero de config con mup_base_width=256 y saber qué significa.

Weight decay y estabilidad

Weight decay (el regularizador L2 sobre los pesos) interactúa sutilmente con Adam:

  • AdamW (Loshchilov 2017): desacopla weight decay del gradiente. Coeficiente de weight decay default 0.1.
  • Demasiado bajo (<0.01): los pesos derivan a gran magnitud, el pre-softmax de la atención se satura, el flujo de gradiente se degrada, inestabilidad eventual.
  • Demasiado alto (>0.5): el modelo no puede ajustar los datos, el entrenamiento se estanca alto.

Default de laboratorio frontera: AdamW con wd=0.1, β=(0.9, 0.95), gradient clip a 1.0. Usa estos para X1.

Gradient clipping

Casi universal en pretraining de transformer:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Cuando la norma L2 global del gradiente excede 1.0, escala todos los gradientes hacia abajo para hacer la norma exactamente 1.0. Esta es la primera línea de defensa contra los spikes. Sin ella, un único batch malo puede arruinar la corrida.

Default: max_norm=1.0. PaLM usó 1.0; Llama-2 usó 1.0; OLMo usó 1.0. El número no es mágico — es lo suficientemente pequeño para acotar gradientes patológicos y lo suficientemente grande para que los gradientes sanos no estén restringidos.

Detección: qué loggear

Para el lab 00 de X1, loguea cada 10 steps:

  • Loss (el obvio).
  • Norma del gradiente (L2 global) — pre-clip. Un spike aquí precede al loss spike por ~1-3 steps.
  • Norma de parámetros (L2 global) — derivando hacia arriba = inestabilidad gestándose.
  • Norma de \(v\) de Adam — un spike aquí = β₂ es demasiado alto para la varianza actual de los datos.
  • Learning rate — fácil de olvidar, fácil de pasar por alto bugs de schedule.
  • Tokens-por-segundo — caídas súbitas indican stalls del dataloader, problemas de NVLink o throttle térmico.
  • Norma L2 de activaciones en la capa N (muestreada) — para diagnósticos de etapa tardía.

mlflow (de la Fase 18) maneja todo esto con mlflow.log_metric(name, value, step=).

Recuperación: el playbook de respuesta

Si se detecta un spike mid-corrida, las opciones son:

Recuperación A: continuar y esperar

Si el loss se está recuperando por sí solo (spike tipo-1), no hagas nada. La mayoría de spikes se resuelven. Comprueba 200 steps después. Si todavía se recupera, déjalo.

Recuperación B: skip-batch

Salta los siguientes \(k\) batches y reanuda desde el estado del optimizador justo antes del spike. PaLM y otros usan \(k=5..20\). Esto funciona porque el spike fue causado por ese batch específico; saltarlo elimina el detonante.

Pseudocódigo:

if grad_norm > spike_threshold:
    save_state("pre_spike.ckpt")
    skip_next_n_batches = 10

Recuperación C: reiniciar desde checkpoint previo

Para spikes tipo-2 / tipo-3, el estado del optimizador está corrupto; ninguna cantidad de skip-batch ayuda. Reinicia desde un checkpoint ≥1000 steps antes del spike.

Esto es por qué la cadencia de checkpoint importa. Guardar cada 1000 steps (unos pocos minutos de cómputo) te da ~30 min perdidos en un reinicio. Guardar cada 100k steps te da 5 horas perdidas.

Cadencia del lab 00 de X1: cada 30 minutos (~50k steps a nuestro throughput). Aceptable para una corrida de 24 horas.

Recuperación D: bajar el LR

Si los spikes recurren a través de reinicios, el LR es demasiado alto para este punto del entrenamiento. Baja LR por 2-3× y reinicia. Coste: re-tunear.

Recuperación E: intercambio de datos

Si un subset específico del dataset está envenenando la corrida (p. ej., un dump de CommonCrawl tiene demasiado boilerplate repetido), intercámbialo y continúa.

Llama-2 documenta hacer esto; la "intervención mid-training" es una norma de laboratorio frontera.

Precisión numérica: bf16 vs fp16 vs fp8

X1 usa bf16. Las razones:

  • fp16: exponente de 5 bits, mantisa de 10 bits. Rango dinámico demasiado estrecho para gradientes de transformer sin loss-scaling (Micikevicius 2017). Loss-scaling es una pieza móvil extra que puede causar spikes por sí misma.
  • bf16: exponente de 8 bits (igual que fp32), mantisa de 7 bits. Rango dinámico coincide con fp32; la precisión es más baja. No se necesita loss-scaling. El default de pretraining desde 2022.
  • fp8: 4 o 5 bits de exponente, 2 o 3 bits de mantisa. Era Hopper (H100+). Entrenar en fp8 está al filo; FP8-LM (Peng 2023), paper de Nvidia Transformer Engine. No en el scope de X1.

X1 entrena en bf16 con pesos maestros fp32 para el optimizador (precisión mixta estándar). Loss-scaling no se usa (el rango dinámico de bf16 lo cubre).

Trucos de estabilidad a nivel de arquitectura

  • Pre-LN sobre post-LN. "Pre-LN" pone LayerNorm antes de la atención/FFN; "post-LN" lo pone después. Pre-LN es mucho más estable para stacks profundos. Default moderno. (Wang 2019, "Learning Deep Transformers with Latent Depth").
  • RMSNorm sobre LayerNorm. Ligeramente más rápido, sin coste observado de estabilidad. Llama-1 en adelante.
  • SwiGLU sobre ReLU/GeLU. Mejor rendimiento empírico; misma estabilidad. Shazeer 2020.
  • QK-norm. Normaliza queries y keys antes del producto punto. Usado por Chameleon, Idefics-2. Reduce la explosión de logits de atención.
  • Z-loss. Loss auxiliar sobre la norma de la log-función-de-partición. Usado por PaLM. Penaliza logits extremos.

X1 usa pre-LN + RMSNorm + SwiGLU (el default moderno). Sin QK-norm, sin z-loss — ayudan pasado ~1B parámetros más que a 50M.

Intervenciones mid-training

Para corridas más largas que ~10 días, los laboratorios frontera intervienen: resets de LR, intercambios de datos, reinicios con config diferente. Ejemplos:

  • Paper de Llama-3 (Meta 2024): describe múltiples resets del schedule de LR y cambios de currículo de datos mid-corrida.
  • OLMo (Groeneveld 2024): documenta 4 intercambios de datos mid-training y 1 ajuste de arquitectura.
  • BLOOM (BigScience 2022): loggeó cada spike públicamente; referencia clásica de postmortem abierto.

El principio: una corrida de pretraining no es "ejecutar y mirar". Es un sistema observado durante 1-3 meses, monitoreado por 2-5 ingenieros, con intervenciones loggeadas como ensayos clínicos.

Para la corrida de 24 horas de X1, la cadencia de intervención es mucho más corta, y el lab incluye un paso de inyección de spike para que practiques el procedimiento de respuesta.

Plantilla de post-mortem de loss-spike

Requerido para el check 3 de DoD de X1. Estructura:

# Spike #N — YYYY-MM-DD HH:MM UTC

## Síntomas
- Loss pre-spike: 3.4
- Loss pico: 8.1
- Loss de recuperación (500 steps después): 3.6
- Pico de norma del gradiente: 47.3 (pre-clip)

## Evidencia
- Corrida mlflow: <URI>
- Rango de índice de batch: 14,250–14,280
- Histograma de tokens para los batches ofensores: [PNG adjunto]
- Norma L2 de activaciones por-capa en el spike: [log adjunto]

## Clasificación
- Tipo: 1 (recuperable) / 2 (persistente) / 3 (divergente)
- Causa raíz: gradiente de token raro / β₂ stale / LR demasiado alto / corrupción de datos

## Acción de recuperación
- Skip-batch (k=10) / restart-from-ckpt-N / bajada de LR 2× / intercambio de datos

## Resultado
- Curva de loss reanudada en la trayectoria pre-spike en el step 14,800.
- Coste del incidente: 550 steps × 500k tokens/s × $1.10/h / (3600 × 500k) = ~$0.17. Aceptable.

## Lecciones / followups
- (p. ej.) Bajar β₂ de 0.99 a 0.95 en la próxima corrida.
- (p. ej.) Añadir logging de entropía por-batch.

Qué NO pertenece a este fichero de teoría

  • Diverger-y-culpar-al-optimizador. "Divergió; cambiamos a un nuevo optimizador" es un mito popular. El optimizador rara vez es la causa raíz; los datos y la inicialización suelen serlo. Resístete a esa narrativa.
  • El cookbook de hiperparámetros. Este fichero es mecanismos, no ajustes. El lab 00 citará HPs específicos.
  • El bake-off de gran presupuesto. Comparar AdaFactor / Adam / Lion / Shampoo a escala necesita una corrida real; la única config de X1 no lo justifica.

Recapitulación de un párrafo

Los loss spikes ocurren a escala y son dominantemente causados por gradientes grandes desde datos raros + β₂ de Adam siendo demasiado alto para amortiguarlos. El playbook de laboratorios frontera es prevenir (β₂=0.95, grad clip 1.0, wd=0.1, pre-LN+RMSNorm+SwiGLU, bf16, μP para transferencia cross-escala de HP), detectar (norma del gradiente, norma de parámetros, norma de \(v\), loss loggeado cada 10 steps), recuperar (no-hacer-nada para tipo-1, skip-batch para causa conocida, restart-from-prior-ckpt para persistente, bajada de LR o intercambio de datos para recurrentes). El lab 00 de X1 entrega todo esto, con una inyección sintética de spike para que escribas el post-mortem.


Siguiente: lab/00-one-day-cloud-pretraining.md.