Phase 19: Stability check decision tree (loss spikes, NaN, mixed precision)
Phase 18 taught you how to configure the run. Phase 19 teaches you how to read its dynamics when something breaks mid-run. This tree covers spikes, NaN, fp16 overflow, and divergence. Before invoking it, make sure you have passed the Phase 18 tree.
A runnable checklist for when a Phase-18-configured run starts misbehaving mid-flight. Prerequisite: you walked docs/phase-18-training-loop/stability-check.md and it all passed.
How to use
- The run already started healthily. Something went wrong after step \(K\).
- Have dashboard.html and experiments/<run>/manifest.json open.
- Start at §1 (the NaN check is fastest). If the loss is NaN, work through §1 and stop. Otherwise go to §2 (spike), then §3 (mixed precision), then §4 (silent divergence).
§1: Is loss NaN or Inf?
Q1.1 Is loss(step \(t\)) NaN or Inf at any logged step?
- If yes: you have a destructive update. Go §1.2.
- If no: go §2.
Q1.2 What was the dtype regime in the run?
- fp32: NaN in fp32 is rare. Almost always log(0) in the loss (probability went to exactly 0 due to a giant pre-softmax logit) or 0/0 in the loss reduction. Find the offending op via np.isnan(...).any() checks at each layer in a single forward pass after reload.
- fp16: the activation magnitude exceeded 65504. Go §3 (mixed precision).
- bf16: less common, but exponent overflow in bf16 still happens at \(\sim 10^{38}\). Follow the fp32 path, with the additional possibility that a gradient overflowed or was swamped by rounding noise (bf16 stores 7 explicit mantissa bits, so adjacent values are about 0.8% apart: roughly 1% noise per op).
Q1.3 Where did the NaN originate?
- Reload the last NaN-free checkpoint.
- Step forward one batch with np.isnan checks at: embedding output, each block's attention output, each block's MLP output, final LN, LM-head logits, softmax, log-softmax, loss reduction.
- The first layer that reports NaN with finite input is the culprit.
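The Q1.3 walk can be sketched as follows, assuming each stage of the forward pass can be called as a plain function on NumPy arrays. The helper `first_nonfinite_stage` and the toy `stages` list are hypothetical stand-ins, not part of the course code. The sketch uses `np.isfinite` rather than `np.isnan` so that Inf is caught alongside NaN.

```python
import numpy as np

def first_nonfinite_stage(x, stages):
    """Run x through ordered (name, fn) stages. Return the name of the first
    stage whose input is finite but whose output contains NaN or Inf,
    or None if every stage stays finite."""
    for name, fn in stages:
        input_ok = np.isfinite(x).all()
        with np.errstate(all="ignore"):  # silence warnings; we check explicitly
            y = fn(x)
        if input_ok and not np.isfinite(y).all():
            return name
        x = y
    return None

def make_stages(logit_scale):
    # Toy stand-ins for real layers (hypothetical; swap in your model's stages).
    return [
        ("lm_head_logits", lambda h: h * logit_scale),
        ("softmax", lambda z: np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)),
        ("log", np.log),
        ("loss_reduction", lambda lp: -lp.mean()),
    ]

hidden = np.array([[0.1, 0.9, -0.5]])
healthy = first_nonfinite_stage(hidden, make_stages(1.0))
culprit = first_nonfinite_stage(hidden, make_stages(1000.0))
print(healthy, culprit)
```

With the giant logit scale, the naive softmax is the first stage that turns a finite input into a non-finite output, which is the "giant pre-softmax logit" case from Q1.2.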
Q1.4 Fix:
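- log(0) in loss: use log_softmax instead of log(softmax(x)), or add \(\epsilon\) inside the log, i.e. log(softmax(x) + \(\epsilon\)). Adding \(\epsilon\) to the softmax denominator alone does not help, because the numerator can still underflow to exactly 0.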
- Embedding gradient = \(\infty\): clip-by-value on gradients (in addition to clip-by-norm). Cap at \(\pm 100\).
- fp16 activation overflow: enable dynamic loss scaling (Phase 19 §3) or switch to bf16.
If §1.4 fixes are applied and re-run still NaNs, the bug is structural — go back to Phase 18 stability-check §5 (model architecture).
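To illustrate the log_softmax fix from Q1.4, here is a minimal NumPy sketch comparing the naive log(softmax(x)) with a log-sum-exp formulation. The function name `log_softmax` and the example logits are illustrative assumptions; in a real run you would call your framework's built-in log-softmax.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax: shift by the row max, then subtract
    the log of the summed exponentials."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

# One very negative logit: its probability underflows to exactly 0.
logits = np.array([[0.0, 200.0, -900.0]])

with np.errstate(all="ignore"):
    probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    naive = np.log(probs)  # log(0) produces -inf for the last class

stable = log_softmax(logits)
print(naive)
print(stable)
```

If the target token were the last class, the naive path would report an infinite loss, while the stable path reports a large but finite one.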
§2: Loss spike (recoverable or not)
Q2.1 Is there a single step where loss is more than \(3\sigma\) above the rolling 100-step mean?
- Compute \(\mu = \text{mean}(\text{loss}[t-100:t])\), \(\sigma = \text{std}(\text{loss}[t-100:t])\) for \(t > 100\).
- A spike is loss[t] > μ + 3σ.
- If yes: go §2.2. If no: go §3.
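The Q2.1 rule can be sketched as below, assuming the logged loss is available as a 1-D array with one value per step. The name `find_spikes` and the synthetic loss curve are hypothetical, and `np.std` here is the population standard deviation, which the text does not specify.

```python
import numpy as np

def find_spikes(loss, window=100, k=3.0):
    """Return every step t > window with loss[t] > mu + k * sigma, where mu
    and sigma are computed over loss[t - window:t]."""
    loss = np.asarray(loss, dtype=float)
    spikes = []
    for t in range(window + 1, len(loss)):
        prev = loss[t - window:t]
        if loss[t] > prev.mean() + k * prev.std():
            spikes.append(t)
    return spikes

# Synthetic curve: small deterministic wobble plus one injected spike.
steps = np.arange(300)
loss = 2.0 + 0.01 * np.sin(steps)
loss[200] = 3.0

spike_steps = find_spikes(loss)
print(spike_steps)
```

Note that once a spike enters the rolling window it inflates σ, so follow-up spikes within the next 100 steps are harder to flag with this rule.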
Q2.2 Does loss recover to below \(\mu + 0.5\sigma\) within 50 steps?
- Yes (recoverable spike): go §2.3.
- No (persistent elevation): go §2.4.
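A possible sketch of the Q2.2 test follows. It assumes μ and σ are frozen from the 100 steps before the spike, and that "recovers" means any of the next 50 steps falls at or below μ + 0.5σ; the text does not pin down either choice. The name `classify_spike` and the synthetic curves are hypothetical.

```python
import numpy as np

def classify_spike(loss, t_spike, window=100, horizon=50, k=0.5):
    """Label a spike at t_spike as 'recoverable' or 'persistent' using the
    pre-spike rolling statistics as the recovery threshold."""
    loss = np.asarray(loss, dtype=float)
    pre = loss[t_spike - window:t_spike]
    threshold = pre.mean() + k * pre.std()
    after = loss[t_spike + 1:t_spike + 1 + horizon]
    return "recoverable" if (after <= threshold).any() else "persistent"

steps = np.arange(300)
base = 2.0 + 0.01 * np.sin(steps)

recovers = base.copy()
recovers[200] = 3.0   # one bad step, then back to normal

stuck = base.copy()
stuck[200:] += 0.5    # loss jumps and stays elevated

print(classify_spike(recovers, 200), classify_spike(stuck, 200))
```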
Q2.3 Recoverable spike — find the cause.
- Panel: grad-norm pre-clip at the spike step. If \(> 30 \mu_g\) where \(\mu_g\) is the rolling grad-norm mean: long-tail token in the batch. Open theory/04-loss-spike-postmortem-template.md and follow.
- Panel: LR at the spike step. If discontinuous (jumped): schedule bug.
- Panel: batch composition. If batch index is a duplicate of a recent batch: data-loader bug.
- Apply fix (stratified batching, schedule fix, or loader fix). Re-run.
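The three panel checks in Q2.3 can be approximated from logged series. This is a sketch under assumptions: `spike_causes` is a hypothetical helper, the 50% relative LR jump threshold and the 50-step "recent" lookback are illustrative choices the text does not specify, and `batch_ids` is assumed to be a per-step list of batch indices.

```python
import numpy as np

def spike_causes(t, grad_norm, lr, batch_ids, window=100,
                 gn_factor=30.0, lr_jump=0.5, lookback=50):
    """Return the list of suspected causes for a spike at step t."""
    causes = []
    # Pre-clip grad-norm far above its rolling mean: long-tail token in batch.
    mu_g = np.mean(grad_norm[max(0, t - window):t])
    if grad_norm[t] > gn_factor * mu_g:
        causes.append("long-tail batch")
    # LR changed abruptly between consecutive steps (assumed threshold).
    if abs(lr[t] - lr[t - 1]) > lr_jump * lr[t - 1]:
        causes.append("schedule bug")
    # Same batch index seen again within the recent lookback (assumed window).
    if batch_ids[t] in batch_ids[max(0, t - lookback):t]:
        causes.append("data-loader bug")
    return causes

grad_norm = [1.0] * 200
grad_norm[150] = 40.0
lr = [1e-3] * 200
batch_ids = list(range(200))

found = spike_causes(150, grad_norm, lr, batch_ids)
print(found)
```

More than one cause can fire at once, so the helper returns a list rather than stopping at the first match.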
Q2.4 Persistent elevation: find the cause.
- The optimizer's moments are now corrupted; the model is stuck in a worse region.
- Reload the last pre-spike checkpoint (you do checkpoint every \(K=200\) steps, right?).
- Apply a preventative fix (lower clip threshold to 0.5; enable stratified batching) before restart.
- If no checkpoint exists, the cost is the full re-train.
§3: Mixed precision (fp16/bf16) checks
Only relevant if dtype is fp16 or bf16 in manifest.json.
Q3.1 Is there a step with grad-norm = inf or NaN?
- Yes: fp16 overflow. An activation or scaled gradient exceeded \(65504\) (the fp16 max). Continue with Q3.2.
- No: go §4.
Q3.2 What does the loss-scale history show?
- The dynamic loss scaler should double after every \(N\) consecutive non-overflow steps and halve on every overflow.
- If loss scale is constant: dynamic scaling isn't engaged. Enable it.
- If loss scale collapsed to \(< 1\): persistent overflow. The large activation magnitudes are not transient; they are systematically too large. Check init (Phase 18 stability §5) and the residual stream magnitude.
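To make the expected loss-scale behavior concrete, here is a small pure-Python simulation of the doubling and halving rule, plus a reader for the resulting history. Both function names are hypothetical. The defaults (initial scale 65536, growth interval 2000) are assumptions chosen to resemble common GradScaler settings; check your own configuration.

```python
def simulate_loss_scale(overflow_flags, init_scale=65536.0, growth_interval=2000):
    """Replay a per-step overflow log: halve on overflow, double after
    `growth_interval` consecutive clean steps. Returns the scale per step."""
    scale, clean_steps, history = init_scale, 0, []
    for overflow in overflow_flags:
        if overflow:
            scale /= 2.0
            clean_steps = 0
        else:
            clean_steps += 1
            if clean_steps == growth_interval:
                scale *= 2.0
                clean_steps = 0
        history.append(scale)
    return history

def diagnose_loss_scale(history):
    """Apply the Q3.2 reading to a logged loss-scale history."""
    if min(history) == max(history):
        # Only meaningful if the history is longer than the growth interval.
        return "constant: dynamic scaling not engaged"
    if history[-1] < 1.0:
        return "collapsed: persistent overflow"
    return "ok"

growing = simulate_loss_scale([False] * 4, init_scale=8.0, growth_interval=2)
collapsing = simulate_loss_scale([True] * 20)
print(growing, diagnose_loss_scale(collapsing))
```

Twenty consecutive overflows take the scale from 65536 down to 0.0625, which is the "collapsed to < 1" signature.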
Q3.3 fp16 overflow signature.
The diagnostic fingerprint:
| Signal | fp16 overflow | fp32 NaN |
|---|---|---|
| First NaN appears in | gradient or activation | loss directly |
| Loss scale before NaN | finite, then halving | N/A |
| Forward pass NaN-free? | sometimes | rarely |
| Fix | enable / fix loss scaling | architectural |
If unsure: re-run a single step in fp32 with the same data. If it works in fp32, you have an fp16 overflow, not a structural bug.
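The "re-run in fp32" test can be sketched with NumPy's float16 as a stand-in for a real fp16 training step. `is_fp16_overflow` and `toy_step` are hypothetical; in practice the step function would be one forward/backward pass of your model on the saved batch.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # largest finite fp16 value

def is_fp16_overflow(step_fn, batch):
    """Run the same step in fp16 and fp32. True means the result is
    non-finite only in fp16, i.e. an overflow rather than a structural bug."""
    with np.errstate(over="ignore", invalid="ignore"):
        out16 = step_fn(batch.astype(np.float16))
        out32 = step_fn(batch.astype(np.float32))
    return bool(not np.isfinite(out16).all() and np.isfinite(out32).all())

def toy_step(x):
    # Stand-in for a forward pass: squares activations, then reduces.
    return (x * x).sum()

big = np.array([300.0, 2.0])    # 300 * 300 = 90000 exceeds the fp16 max
small = np.array([3.0, 2.0])

print(FP16_MAX, is_fp16_overflow(toy_step, big), is_fp16_overflow(toy_step, small))
```

If both precisions produce non-finite output, the helper returns False, which points back at the structural (fp32 NaN) column of the table.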
Q3.4 Fixes:
- Enable torch.cuda.amp.GradScaler (PyTorch native) or your equivalent.
- Reduce LR by 2× if loss-scale keeps collapsing.
- Switch fp16 → bf16 (wider exponent range, less precision; usually safer for transformers).
§4: Silent divergence
Loss is not spiking, not NaN, just slowly going up.
Q4.1 Has loss increased monotonically over the last 200 steps?
- Yes: go §4.2.
- No: the run is healthy (or healthily noisy); stop walking.
Q4.2 Has LR increased over the same window?
- Yes: the schedule is misconfigured (cosine implemented as inverse cosine, or warmup never ends).
- No: go §4.3.
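For reference when checking Q4.2, a linear-warmup-plus-cosine schedule might look like the sketch below, together with a check that LR never rises once warmup is over. The function names and all hyperparameter values (`base_lr`, `warmup`, `total`) are illustrative assumptions, not the values from any Phase 18 config.

```python
import math

def lr_at(step, base_lr=3e-4, warmup=100, total=1000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay down to min_lr."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def lr_rises_after_warmup(lrs, warmup):
    """True if LR ever increases between consecutive post-warmup steps."""
    post = lrs[warmup:]
    return any(b > a for a, b in zip(post, post[1:]))

good = [lr_at(s) for s in range(1000)]

# The "inverse cosine" bug: sign flipped, so LR climbs instead of decaying.
inverted = [3e-4 * 0.5 * (1.0 - math.cos(math.pi * s / 999)) for s in range(1000)]

print(lr_rises_after_warmup(good, 100), lr_rises_after_warmup(inverted, 100))
```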
Q4.3 Has any parameter group's weight norm changed by more than 2× in the last 200 steps?
- Yes: weight decay is too low (norm climbing) or too high (norm collapsing). Adjust by 2× and re-run from checkpoint.
- No: go §4.4.
Q4.4 Is the gradient norm rolling-mean trending up?
- Yes: the loss landscape is getting harder (you're moving away from a minimum). The cause is usually upstream: bad data, init drift, or LR too high. Reduce LR by 2×.
- No: the model has plateaued. This isn't divergence; this is "you're done training". Check eval metrics; if eval has also plateaued, stop.
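The whole §4 walk can be condensed into one function over logged series. This is a sketch under several assumptions: `silent_divergence_check` is a hypothetical helper, it handles a single parameter group, "trending up" is approximated by comparing the mean grad-norm of the second half of the window against the first half, and the loss passed in is presumably smoothed, since raw per-step loss is rarely strictly monotone.

```python
import numpy as np

def silent_divergence_check(loss, lr, weight_norm, grad_norm, window=200):
    """Walk Q4.1 to Q4.4 over the last `window` steps and return a verdict."""
    loss, lr, wn, gn = (np.asarray(a, dtype=float)[-window:]
                        for a in (loss, lr, weight_norm, grad_norm))
    # Q4.1: strictly increasing loss over the window?
    if not np.all(np.diff(loss) > 0):
        return "healthy"
    # Q4.2: LR higher at the end of the window than at the start?
    if lr[-1] > lr[0]:
        return "schedule misconfigured"
    # Q4.3: weight norm moved by more than 2x in either direction?
    ratio = wn[-1] / wn[0]
    if ratio > 2.0 or ratio < 0.5:
        return "adjust weight decay by 2x"
    # Q4.4: grad-norm trend (assumed proxy: second-half mean vs first-half mean).
    half = len(gn) // 2
    if gn[half:].mean() > gn[:half].mean():
        return "reduce LR by 2x"
    return "plateaued: check eval metrics"

n = 200
rising_loss = np.linspace(2.0, 2.5, n)
flat_lr = np.full(n, 1e-4)
flat_wn = np.full(n, 10.0)
rising_gn = np.linspace(1.0, 2.0, n)

verdict = silent_divergence_check(rising_loss, flat_lr, flat_wn, rising_gn)
print(verdict)
```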
Numerical thresholds: quick reference
| Symptom | Threshold | First action |
|---|---|---|
| loss NaN/Inf | any | §1 — find originating layer |
| single-step spike | loss > μ + 3σ | §2.3 — check grad-norm panel |
| persistent elevation | spike + no recovery in 50 steps | §2.4 — reload checkpoint |
| grad-norm = inf in fp16 | any | §3 — loss scaling |
| loss increasing 200 steps | monotonic | §4 — check LR, decay, gradient |
| weight norm change | > 2× in 200 steps | §4.3 — adjust decay |
Cross-refs
- Phase 18 stability-check.md (configuration-level checks; do those first).
- theory/04-loss-spike-postmortem-template.md (the post-mortem you write after a spike).
- theory/03-three-failure-modes.md (the three engineered failures and their signatures).