04 — Training Stability at Scale
At scale, models are not stable by default. The loss curve becomes a minefield: spikes, divergences, NaNs traced to Adam's β₂. The resources spent diagnosing a single spike at hour 14 can cost more than the entire X1 run. Knowing the response repertoire (skip-batch, μP, gradient clipping, restart from a prior checkpoint) is the difference between success and a $50k crater.
A 50M-param model rarely spikes. A 7B+ model spikes regularly, and at $10k+/hr of cluster time, every minute a spike goes unnoticed costs real money. The frontier-lab playbook for stability is prevention + detection + recovery, each with its own mechanisms.
What is a loss spike?
A loss spike is a sudden 2-100× jump in training loss over <100 steps, often preceded by a gradient-norm spike. The literature describes three trajectories:
- Type 1 (recoverable). Loss spikes from \(L=2.1\) to \(L=4.5\) over 20 steps, then decays back to \(L=2.1\) over the next 500 steps, and the pre-spike trajectory resumes. Cost: ~500 steps of wasted compute.
- Type 2 (lingering). Loss spikes, then partially recovers but plateaus above the pre-spike curve. The damage is permanent; the model never catches up. Cost: relaunch from a prior checkpoint.
- Type 3 (divergent). Loss spikes and never recovers: it keeps climbing or oscillates wildly, and NaNs appear. Cost: hard restart, sometimes a redesign.
Empirically, ~80% of spikes are type 1 (recoverable), ~15% type 2 and ~5% type 3, but the split is highly architecture- and config-dependent.
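To keep classifications consistent across post-mortems, the three types can be encoded as a small rule. This is a minimal sketch, not a standard definition: the name `classify_spike`, the 2% tolerance for "back to the pre-spike level", and the rule that a non-finite loss or one still at or above the peak counts as divergent are all assumptions. It also compares against the pre-spike loss level rather than the extrapolated pre-spike trajectory.

```python
import math


def classify_spike(pre_loss, peak_loss, post_loss, tol=0.02):
    """Classify a spike as 1 (recoverable), 2 (lingering) or 3 (divergent).

    post_loss is the loss a fixed number of steps after the peak (e.g. 500).
    tol is a hypothetical tolerance for "back on the pre-spike level".
    """
    if not math.isfinite(post_loss) or post_loss >= peak_loss:
        return 3  # NaN/inf, or still at or above the peak: divergent
    if post_loss <= pre_loss * (1 + tol):
        return 1  # back to (roughly) the pre-spike loss: recoverable
    return 2      # partially recovered but stuck higher: lingering


print(classify_spike(2.1, 4.5, 2.1))           # → 1
print(classify_spike(2.1, 4.5, float("nan")))  # → 3
```

The tolerance matters: pick it from the loss noise you see in healthy stretches of the run.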
The dominant mechanism: large-magnitude gradients from rare data
The most common spike root cause:
- A rare batch contains tokens or token-sequences extremely under-represented in training so far.
- The model assigns very low probability to those tokens.
- The cross-entropy loss for those tokens is large (e.g. \(-\log(10^{-8}) = 18.4\)).
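- The back-propagated gradient is correspondingly large. The gradient with respect to the logits is \(p - y\): nearly zero for a well-predicted token, but nearly \(-1\) on the target logit when \(p_{\text{target}} \approx 0\). A batch full of such tokens produces an outsized global gradient norm.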
- The optimizer (Adam) takes a large step.
- The step pushes the model into a region of parameter space where activations explode or attention probs become degenerate.
- The next batch's loss is enormous → spike.
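The cross-entropy and gradient claims in the list above can be checked in a few lines of plain Python. This sketch (function names are hypothetical) computes softmax cross-entropy and its logit gradient \(p - y\) for a confident model that sees either a well-predicted token or an under-represented one. The per-token logit gradient is bounded (its L2 norm never exceeds \(\sqrt{2}\)), but for the rare token it is many orders of magnitude larger than for the well-predicted one.

```python
import math


def cross_entropy_and_grad(logits, target):
    """Softmax cross-entropy loss and its gradient w.r.t. the logits (p - one_hot)."""
    m = max(logits)                                  # subtract max for stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    loss = -(logits[target] - m - math.log(total))   # -log p[target], in log space
    grad = [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]
    return loss, grad


def l2(xs):
    return math.sqrt(sum(x * x for x in xs))


print(f"{-math.log(1e-8):.1f}")                      # → 18.4, loss at p = 1e-8

# A confident model: logit 20 on token 0, 0 elsewhere.
logits = [20.0, 0.0, 0.0, 0.0]
common_loss, common_grad = cross_entropy_and_grad(logits, target=0)  # well predicted
rare_loss, rare_grad = cross_entropy_and_grad(logits, target=3)      # rare token
print(f"common: loss={common_loss:.2e} |grad|={l2(common_grad):.2e}")
print(f"rare:   loss={rare_loss:.1f} |grad|={l2(rare_grad):.2f}")
```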
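This mechanism hinges on Adam's β₂. Adam's second moment \(v\) is an exponential moving average of \(g^2\) with weight \(1-β_2\) (default β₂ = 0.999, so \(v\) has an effective averaging window of \(1/(1-β_2) = 1000\) steps). When a 100× gradient arrives, \(v\) does not grow fast enough to dampen the step. The following steps are also large, because the momentum \(m\) still carries the spike while \(v\) has absorbed only a fraction of it.
The fix: lower β₂ to 0.95. Now \(v\) adapts in ~20 steps, dampening anomalous gradients much faster. GPT-3 (Brown 2020) trained with β₂ = 0.95, and it has been the standard LLM-pretraining setting since; PaLM (Chowdhery 2022) instead used a step-dependent schedule, \(β_2 = 1 - k^{-0.8}\) at step \(k\).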
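A toy scalar version of Adam makes the β₂ argument concrete. The sketch below (`adam_step_sizes` is a hypothetical helper, not any library's API) starts from a warmed-up state where every past gradient was 1, so \(m = v = 1\) and bias correction can be ignored. It then feeds one 100× gradient followed by normal ones and reports each step's size relative to a normal step.

```python
import math


def adam_step_sizes(beta2, grads, beta1=0.9, m0=1.0, v0=1.0, eps=1e-8):
    """|update| / lr per step for one scalar parameter.

    m0 = v0 = 1 models a long history of unit gradients (bias correction ignored).
    """
    m, v, sizes = m0, v0, []
    for g in grads:
        m = beta1 * m + (1 - beta1) * g        # first moment (momentum)
        v = beta2 * v + (1 - beta2) * g * g    # second moment
        sizes.append(abs(m) / (math.sqrt(v) + eps))
    return sizes


grads = [100.0, 1.0, 1.0]                      # one 100x gradient, then normal ones
for beta2 in (0.999, 0.95):
    print(beta2, [round(s, 2) for s in adam_step_sizes(beta2, grads)])
```

With β₂ = 0.999 the spike step is about 3.3× a normal step and the next two stay above 2.7×; with β₂ = 0.95 all three are below 0.5×, because \(v\) absorbs most of the 100× gradient in the same step.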
μP (Maximal Update Parametrization)
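Yang & Hu 2021 introduce μP, a re-parametrization under which the optimal learning rate (and other HPs) stays stable as model width changes. The mechanism, in its Adam form:
- Initialize hidden weight matrices with variance \(\sigma^2 \propto 1/\text{fan\_in}\) (as in standard parametrization).
- Multiply the output logits by \(1/\text{width\_mult}\), where \(\text{width\_mult} = \text{width}/\text{base\_width}\), and scale attention logits by \(1/d_{\text{head}}\) instead of \(1/\sqrt{d_{\text{head}}}\), so logits do not grow as the model widens.
- Scale the Adam learning rate of hidden weight matrices as \(1/\text{width\_mult}\); embeddings, biases and norm gains keep the base learning rate.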
Why this helps stability. Under standard parametrization, activation and update magnitudes drift as model width grows, so the learning rate that is optimal at 100M is too large at 1B and far too large at 70B. μP keeps these magnitudes roughly constant across widths.
The HP-transfer trick (μTransfer, Yang 2022): tune the learning rate on a small model (e.g. 100M), then apply the same LR to the 70B under μP. This costs ~$1k of small-model HP search instead of ~$100k of big-model HP search. It is the principal industrial use of μP; Cerebras, Eleuther and OLMo all use it.
For X1 we do not implement μP — a single 50M run does not benefit. But you should be able to read a config file with mup_base_width=256 and know what it means.
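As a sketch of how such a config is typically consumed, assuming `mup_base_width` is the width of the small proxy on which the HPs were tuned, every scaling follows from the width multiplier `width / mup_base_width`. The helper below (`mup_scaling` is hypothetical) applies the commonly used μP rules for Adam, in the style of the open-source `mup` package: hidden-matrix learning rates divided by the width multiplier, embedding and other vector-like learning rates unchanged, output logits multiplied by `1 / width_mult`, and attention logits scaled by `1 / d_head`. Real implementations also set initialization per parameter group.

```python
def mup_scaling(width, base_width, base_lr, d_head, base_output_mult=1.0):
    """Hypothetical helper: per-group μP settings for a model of `width`, given a
    learning rate tuned on a proxy model of `base_width` (e.g. mup_base_width=256)."""
    width_mult = width / base_width
    return {
        "width_mult": width_mult,
        # Hidden weight matrices: the Adam LR shrinks as 1/width_mult.
        "hidden_lr": base_lr / width_mult,
        # Embeddings, biases, norm gains: the LR does not change with width.
        "vector_lr": base_lr,
        # Readout: logits are multiplied by 1/width_mult so they stay O(1).
        "output_mult": base_output_mult / width_mult,
        # Attention logits are scaled by 1/d_head instead of 1/sqrt(d_head).
        "attn_scale": 1.0 / d_head,
    }


# LR tuned on a width-256 proxy, transferred to a width-4096 model.
for key, value in mup_scaling(4096, 256, base_lr=1e-2, d_head=64).items():
    print(f"{key}: {value:g}")
```

At `width == base_width` every multiplier is 1, which is why the proxy can be tuned with an ordinary sweep.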
Weight decay and stability
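Weight decay (shrinking weights toward zero at every step) interacts subtly with Adam:
- AdamW (Loshchilov 2017): decouples weight decay from the gradient, so the decay is not rescaled by Adam's adaptive denominator. The usual LLM-pretraining coefficient is 0.1 (PyTorch's AdamW default is 0.01).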
- Too low (<0.01): weights drift to large magnitudes, attention logits grow until the softmax saturates, gradient flow degrades, and the run eventually becomes unstable.
- Too high (>0.5): the model can't fit the data, training plateaus high.
Frontier-lab default: AdamW with wd=0.1, β=(0.9, 0.95), gradient clip at 1.0. Use these for X1.
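The difference between decoupled decay and plain "Adam + L2" shows up in a single step. This scalar sketch uses a hypothetical `adam_update`, not PyTorch's implementation; its decoupled branch shrinks the weight before the Adam step, the same order `torch.optim.AdamW` uses. A weight of 1.0 with zero loss gradient is updated once with lr = 0.1 and wd = 0.1.

```python
import math


def adam_update(w, g, m, v, t, lr, wd, betas=(0.9, 0.95), eps=1e-8, decoupled=True):
    """One Adam step on a scalar weight; returns (w, m, v).

    decoupled=True:  AdamW, decay applied directly to w, untouched by v.
    decoupled=False: "Adam + L2", decay folded into g and then normalized by sqrt(v).
    """
    b1, b2 = betas
    if not decoupled:
        g = g + wd * w                      # L2 penalty gradient joins the moments
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)               # bias correction (t counts from 1)
    v_hat = v / (1 - b2 ** t)
    if decoupled:
        w = w * (1 - lr * wd)               # AdamW: shrink by exactly lr * wd
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


w_adamw, _, _ = adam_update(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, wd=0.1, decoupled=True)
w_l2, _, _ = adam_update(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, wd=0.1, decoupled=False)
print(round(w_adamw, 4), round(w_l2, 4))    # → 0.99 0.9
```

Adam + L2 moves the weight by a full learning rate, ten times AdamW's shrink, because dividing by \(\sqrt{\hat v}\) normalizes the decay term. Under AdamW the decay strength is exactly lr·wd per step, independent of gradient scale.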
Gradient clipping
Gradient clipping is almost universal in transformer pretraining.
When the global gradient L2 norm exceeds 1.0, scale all gradients down to make the norm exactly 1.0. This is the first line of defense against spikes. Without it, a single bad batch can wreck the run.
Default: max_norm=1.0. PaLM used 1.0; Llama-2 used 1.0; OLMo used 1.0. The number is not magic — it's small enough to cap pathological gradients and large enough that healthy gradients are not constrained.
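In PyTorch this is one call, `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)`, placed between `loss.backward()` and `optimizer.step()`. It returns the total pre-clip norm, which is the number worth logging. The plain-Python sketch below (a hypothetical `clip_grad_norm` over lists of floats) shows what that call computes.

```python
import math


def clip_grad_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale gradient vectors in place so their global L2 norm is at most max_norm.

    grads is a list of per-parameter gradient lists; returns the pre-clip norm.
    """
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = max_norm / (total + eps)
    if scale < 1.0:                      # only scale down, never up
        for vec in grads:
            for i in range(len(vec)):
                vec[i] *= scale
    return total


grads = [[3.0], [4.0]]                   # global norm 5.0, well above max_norm
print(clip_grad_norm(grads), grads)      # every gradient shrunk by the same factor
```

Because every gradient is multiplied by the same factor, clipping preserves the update direction and only limits its length.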
Detection: what to log
For X1 lab 00, log every 10 steps:
- Loss (the obvious one).
- Gradient norm (global L2), pre-clip. A spike here precedes the loss spike by ~1-3 steps.
- Parameter norm (global L2). Steady upward drift means instability is brewing.
- Adam \(v\) norm. If it lags far behind a grad-norm spike, β₂ is too high for the current gradient variance.
- Learning rate. Easy to forget to log, and without it schedule bugs are easy to miss.
- Tokens-per-second. Sudden drops indicate dataloader stalls, NVLink issues, or thermal throttling.
- Activation L2 norm at layer N (sampled), for late-stage diagnostics.
mlflow (from Phase 18) handles all of this with mlflow.log_metric(name, value, step=step), or mlflow.log_metrics(metrics, step=step) to log a whole dict at once.
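A minimal sketch of the per-step bookkeeping, assuming gradients, parameters and Adam's \(v\) are available as flat lists of floats (with PyTorch you would reduce tensors instead). `log_stability_metrics` is a hypothetical helper that accepts any `log_fn(metrics, step=...)`, so `mlflow.log_metrics` can be passed directly, while the stub below keeps the sketch runnable without a tracking server.

```python
import math


def global_l2(tensors):
    """Global L2 norm across a list of flat float lists."""
    return math.sqrt(sum(x * x for t in tensors for x in t))


def log_stability_metrics(step, loss, grads, params, adam_v, lr, tokens, seconds,
                          log_fn, every=10):
    """Log the stability dashboard every `every` steps. Call it before clipping
    so grad_norm is the pre-clip value. Returns the metrics dict, or None."""
    if step % every != 0:
        return None
    metrics = {
        "loss": loss,
        "grad_norm": global_l2(grads),
        "param_norm": global_l2(params),
        "adam_v_norm": global_l2(adam_v),
        "lr": lr,
        "tokens_per_sec": tokens / seconds,
    }
    log_fn(metrics, step=step)
    return metrics


def print_log(metrics, step):              # stand-in for mlflow.log_metrics
    print(step, {k: round(v, 4) for k, v in metrics.items()})


log_stability_metrics(20, loss=3.4, grads=[[3.0, 4.0]], params=[[1.0, 2.0, 2.0]],
                      adam_v=[[1e-4, 1e-4]], lr=3e-4, tokens=500_000, seconds=1.0,
                      log_fn=print_log)
```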
Recovery: the response playbook
If a spike is detected mid-run, the choices are:
Recovery A: continue and hope
If loss is recovering on its own (type-1 spike), do nothing. Most spikes resolve. Check 200 steps later. If still recovering, let it ride.
Recovery B: skip-batch
Roll back to a checkpoint (model and optimizer state) from shortly before the spike, skip the batches around it, and resume. PaLM rewound to a checkpoint ~100 steps before each spike and skipped ~200–500 batches. This works because the spike was triggered by those specific batches; skipping them removes the trigger.
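The skip-batch procedure can be sketched as a training loop that keeps its last checkpoint in memory. Everything here is a toy under stated assumptions: `train_with_skip_batch`, the spike test (loss above `spike_factor` times a running average, or non-finite) and the default window sizes are hypothetical choices, far smaller than PaLM's settings, and a real run restores model and optimizer state from disk.

```python
import copy
import math


def train_with_skip_batch(batches, train_step, state, spike_factor=3.0,
                          skip=10, ckpt_every=50, ema_decay=0.9):
    """On a spike, rewind to the last checkpoint and skip `skip` batches starting
    at the offending one. train_step(state, batch) -> (new_state, loss)."""
    skipped = set()
    ckpt = (0, copy.deepcopy(state), None)        # (batch index, state, loss EMA)
    i, ema, history = 0, None, []
    while i < len(batches):
        if i in skipped:
            i += 1
            continue
        if i % ckpt_every == 0:
            ckpt = (i, copy.deepcopy(state), ema)
        new_state, loss = train_step(state, batches[i])
        spiked = not math.isfinite(loss) or (ema is not None and loss > spike_factor * ema)
        if spiked:
            skipped.update(range(i, i + skip))
            i, state, ema = ckpt[0], copy.deepcopy(ckpt[1]), ckpt[2]
            history = [(j, lj) for j, lj in history if j < i]   # drop replayed steps
            continue
        state = new_state
        ema = loss if ema is None else ema_decay * ema + (1 - ema_decay) * loss
        history.append((i, loss))
        i += 1
    return state, history, sorted(skipped)


def toy_step(state, batch):
    loss = 100.0 if batch == "bad" else 1.0      # one poisoned batch
    return {"step": state["step"] + 1}, loss


batches = ["ok"] * 30
batches[17] = "bad"
state, history, skipped = train_with_skip_batch(batches, toy_step, {"step": 0},
                                                skip=10, ckpt_every=10)
print(state, len(history), skipped[0], skipped[-1])
```

Here the spike at batch 17 rewinds the loop to the checkpoint at batch 10, batches 17-26 are skipped, and 20 of the 30 batches end up trained.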
Recovery C: restart from prior checkpoint
For type-2 / type-3 spikes, the optimizer state is corrupted; no amount of skip-batch helps. Restart from a checkpoint ≥1000 steps before the spike.
This is why checkpoint cadence matters. Saving every 1000 steps (a few minutes of compute) limits a restart to roughly 30 minutes of lost work; saving every 100k steps can cost 5 hours.
X1 lab 00 cadence: every 30 minutes (~50k steps at our throughput). Acceptable for a 24-hour run.
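The cadence trade-off is simple arithmetic. In this sketch (hypothetical helpers), checkpoints land at multiples of `ckpt_every`, and a restart goes back to the newest checkpoint at least `min_rollback` steps before the spike, so the discarded work falls between `min_rollback` and `min_rollback + ckpt_every` steps. The X1 figure of ~50k steps per 30 minutes implies roughly 28 steps/s.

```python
def steps_per_checkpoint(minutes, steps_per_sec):
    """Convert a wall-clock checkpoint cadence into a step cadence."""
    return int(round(minutes * 60 * steps_per_sec))


def lost_steps(spike_step, ckpt_every, min_rollback):
    """Steps discarded when restarting from the newest checkpoint that is
    at least min_rollback steps before the spike."""
    restart_from = max(0, (spike_step - min_rollback) // ckpt_every * ckpt_every)
    return spike_step - restart_from


x1_steps_per_sec = 50_000 / (30 * 60)            # ~27.8 steps/s, from the X1 cadence
print(steps_per_checkpoint(30, x1_steps_per_sec))
print(lost_steps(14_265, ckpt_every=1_000, min_rollback=1_000))
print(lost_steps(14_265, ckpt_every=50_000, min_rollback=1_000))
```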
Recovery D: lower the LR
If spikes recur across restarts, the LR is too high for this point in training. Drop LR by 2-3× and restart. Cost: re-tune.
Recovery E: data swap
If a specific dataset subset is poisoning the run (e.g., a CommonCrawl dump has too much repeated boilerplate), swap it out and continue.
Llama-2 documents doing this; the "mid-training intervention" is a frontier-lab norm.
Numerical precision: bf16 vs fp16 vs fp8
X1 uses bf16. The reasons:
- fp16: 5-bit exponent, 10-bit mantissa. Dynamic range too narrow for transformer gradients without loss-scaling (Micikevicius 2017). Loss-scaling is an extra moving piece that can itself cause spikes.
- bf16: 8-bit exponent (same as fp32), 7-bit mantissa. Dynamic range matches fp32; precision is lower. No loss-scaling needed. The pretraining default since 2022.
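- fp8: E4M3 (4-bit exponent, 3-bit mantissa) or E5M2 (5-bit exponent, 2-bit mantissa). Hopper-era (H100+). Training in fp8 is bleeding-edge; see FP8-LM (Peng 2023) and "FP8 Formats for Deep Learning" (Micikevicius 2022), which defines the formats NVIDIA's Transformer Engine uses. Not in X1 scope.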
X1 trains in bf16 with fp32 master weights for the optimizer (standard mixed precision). Loss-scaling is not used (bf16 dynamic range covers it).
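The exponent/mantissa split translates directly into range and precision. The sketch below computes the largest finite value, smallest normal value and machine epsilon for IEEE-754-style formats, then uses Python's `struct` module, whose `'e'` code is IEEE half precision, to show a small gradient vanishing in fp16. The bf16 conversion here simply truncates a float32 (round toward zero), whereas hardware typically rounds to nearest even. fp8 E4M3 is left out because it does not reserve its top exponent for infinity, so its maximum (448) does not follow this formula.

```python
import struct


def format_limits(exp_bits, mantissa_bits):
    """(max finite, min normal, epsilon) for an IEEE-754-style binary format
    whose top exponent code is reserved for inf/NaN."""
    bias = 2 ** (exp_bits - 1) - 1
    max_finite = (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias
    min_normal = 2.0 ** (1 - bias)
    return max_finite, min_normal, 2.0 ** -mantissa_bits


for name, e, m in [("fp32", 8, 23), ("bf16", 8, 7), ("fp16", 5, 10), ("fp8 E5M2", 5, 2)]:
    mx, mn, eps = format_limits(e, m)
    print(f"{name:9s} max={mx:.3g}  min_normal={mn:.3g}  eps={eps:.3g}")


def to_fp16(x):
    """Round-trip through IEEE half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Truncate a float32 to bf16 by keeping its top 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


tiny_grad = 1e-8
print(to_fp16(tiny_grad), to_bf16(tiny_grad))   # fp16 flushes it to 0; bf16 keeps it
```

fp16 tops out at 65504 and flushes anything much below 6e-8 to zero, which is why it needs loss-scaling; bf16 shares fp32's range and gives up precision (epsilon 2⁻⁷) instead.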
Architecture-level stability tricks
- Pre-LN over post-LN. "Pre-LN" puts LayerNorm before the attention/FFN sublayer, inside the residual branch; "post-LN" puts it after the residual addition. Pre-LN is much more stable for deep stacks and is the modern default (Wang 2019, "Learning Deep Transformer Models for Machine Translation").
- RMSNorm over LayerNorm. Slightly faster, no observed stability cost. Llama-1 onward.
- SwiGLU over ReLU/GeLU. Better empirical performance; same stability. Shazeer 2020.
- QK-norm. Normalize queries and keys before the dot product. Used by Chameleon, Idefics-2. Reduces attention-logit explosion.
- Z-loss. Auxiliary loss \(10^{-4} \cdot \log^2 Z\) on the output softmax's normalizer (log-partition function) \(Z\), used by PaLM. It keeps \(\log Z\) near 0 and penalizes extreme logits.
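Both QK-norm and z-loss target oversized logits, and both fit in a few lines. This plain-Python sketch uses hypothetical helper names: `z_loss` computes a PaLM-style penalty \(10^{-4} \cdot \log^2 Z\) on the log-partition function of a logit vector, and `qk_logit` shows how RMS-normalizing queries and keys caps the scaled attention logit at \(\sqrt{d}\). Production QK-norm layers also carry learnable gains, which this sketch omits.

```python
import math


def log_partition(logits):
    """log Z = logsumexp(logits), computed stably."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))


def z_loss(logits, coeff=1e-4):
    """Auxiliary loss coeff * (log Z)^2: zero when log Z = 0, grows as logits drift."""
    return coeff * log_partition(logits) ** 2


def rms_normalize(x, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def qk_logit(q, k, qk_norm=True):
    """Scaled dot-product logit q.k / sqrt(d), optionally after QK-norm."""
    if qk_norm:
        q, k = rms_normalize(q), rms_normalize(k)
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))


q = k = [100.0] * 4                       # large-magnitude query and key
print(qk_logit(q, k, qk_norm=False))      # 4 * 100 * 100 / sqrt(4)
print(qk_logit(q, k))                     # capped near sqrt(4) = 2
print(z_loss([0.0, 0.0]), z_loss([30.0, 0.0]))
```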
X1 uses pre-LN + RMSNorm + SwiGLU (the modern default). No QK-norm, no z-loss — they help past ~1B params more than at 50M.
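To make that block layout concrete, here is a plain-Python sketch of one residual sublayer in each arrangement, plus RMSNorm and a SwiGLU feed-forward. All helper names are hypothetical, and real code uses tensor libraries and learned weights. The property to notice is in `pre_ln_block`: the residual stream passes through untouched, so the identity path from input to output is never normalized, which is the usual explanation for why deep pre-LN stacks train more stably.

```python
import math


def rms_norm(x, gain=None, eps=1e-6):
    """RMSNorm: rescale x to unit root-mean-square (no mean subtraction, unlike LayerNorm)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    gain = gain if gain is not None else [1.0] * len(x)
    return [g * v / rms for g, v in zip(gain, x)]


def matvec(w, x):
    """w is a list of rows (out_dim x in_dim)."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def silu(z):
    return z * 0.5 * (1.0 + math.tanh(0.5 * z))    # z * sigmoid(z), overflow-free


def swiglu_ffn(x, w_gate, w_up, w_down):
    """SwiGLU FFN: W_down (silu(W_gate x) * (W_up x))."""
    hidden = [silu(a) * b for a, b in zip(matvec(w_gate, x), matvec(w_up, x))]
    return matvec(w_down, hidden)


def pre_ln_block(x, sublayer):
    """Pre-LN: x + f(norm(x)); the residual stream itself is never normalized."""
    return [a + b for a, b in zip(x, sublayer(rms_norm(x)))]


def post_ln_block(x, sublayer):
    """Post-LN: norm(x + f(x)); every block renormalizes the residual stream."""
    return rms_norm([a + b for a, b in zip(x, sublayer(x))])


eye = [[1.0, 0.0], [0.0, 1.0]]


def toy_ffn(h):
    return swiglu_ffn(h, eye, eye, eye)            # toy weights: identity matrices


x = [3.0, 4.0]
print(pre_ln_block(x, toy_ffn), post_ln_block(x, toy_ffn))
```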
Mid-training interventions
For runs longer than ~10 days, frontier labs intervene: LR resets, data swaps, restarts with different config. Examples:
- Llama-3 paper (Meta 2024): describes multiple LR schedule resets and data curriculum changes mid-run.
- OLMo (Groeneveld 2024): documents 4 mid-training data swaps and 1 architecture tweak.
- BLOOM (BigScience 2022): logged every spike publicly; classic open postmortem reference.
The principle: a pretraining run is not "run and watch." It's a 1-3 month observed system, monitored by 2-5 engineers, with interventions logged like clinical trials.
For X1's 24-hour run, the intervention cadence is much shorter, and the lab includes a spike injection step so you practice the response procedure.
Loss-spike post-mortem template
Required for X1 DoD check 3. Structure:
# Spike #N — YYYY-MM-DD HH:MM UTC
## Symptoms
- Pre-spike loss: 3.4
- Peak loss: 8.1
- Recovery loss (500 steps later): 3.6
- Grad norm peak: 47.3 (pre-clip)
## Evidence
- mlflow run: <URI>
- Batch index range: 14,250–14,280
- Token histogram for offending batches: [attached PNG]
- Per-layer activation L2 at spike: [attached log]
## Classification
- Type: 1 (recoverable) / 2 (lingering) / 3 (divergent)
- Root cause: rare-token gradient / β₂ stale / LR too high / data corruption
## Recovery action
- Skip-batch (k=10) / restart-from-ckpt-N / LR drop 2× / data swap
## Outcome
- Loss curve resumed pre-spike trajectory at step 14,800.
- Cost of incident: 550 steps × 500k tokens/step ÷ 500k tokens/s = 550 s ≈ 0.153 h; 0.153 h × $1.10/hr ≈ $0.17. Acceptable.
## Lessons / followups
- (e.g.) Lower β₂ from 0.99 to 0.95 next run.
- (e.g.) Add per-batch entropy logging.
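The cost line generalizes to a small helper. This sketch (a hypothetical `incident_cost_usd`) converts wasted steps into wall-clock time via tokens per step and throughput, then prices that time. The call reproduces the template's numbers, reading its formula as 500k tokens per step at 500k tokens/s.

```python
def incident_cost_usd(wasted_steps, tokens_per_step, tokens_per_sec, usd_per_hour):
    """Dollar cost of the compute spent on steps that were thrown away."""
    seconds = wasted_steps * tokens_per_step / tokens_per_sec
    return seconds / 3600 * usd_per_hour


print(f"${incident_cost_usd(550, 500_000, 500_000, 1.10):.2f}")   # → $0.17
```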
What does NOT belong in this theory file
- Diverging-and-blame-the-optimizer. "It diverged; we switched to a new optimizer" is a folk myth. The optimizer is rarely the root cause; data and initialization usually are. Resist that narrative.
- The hyperparameter cookbook. This file is mechanisms, not settings. Lab 00 will quote specific HPs.
- The big-budget bake-off. Comparing Adafactor / Adam / Lion / Shampoo at scale needs a real run; X1's single config does not justify it.
One-paragraph recap
Loss spikes happen at scale and are dominantly caused by large gradients from rare data + Adam β₂ being too high to dampen them. The frontier-lab playbook is prevent (β₂=0.95, grad clip 1.0, wd=0.1, pre-LN+RMSNorm+SwiGLU, bf16, μP for cross-scale HP transfer), detect (grad-norm, param-norm, \(v\)-norm, loss logged every 10 steps), recover (do-nothing for type-1, skip-batch for known-cause, restart-from-prior-ckpt for lingering, LR drop or data swap for recurring). X1 lab 00 ships all of this, with a synthetic spike injection so you write the post-mortem.
Next: lab/00-one-day-cloud-pretraining.md.