Phase 19 — Stability check decision tree (loss spikes, NaN, mixed precision)

Phase 18 taught you how to configure the run. Phase 19 teaches you how to read its dynamics when something breaks mid-run. This tree covers spikes, NaN, fp16 overflow, and divergence. Before using it, make sure you have passed the Phase 18 tree.


A runnable checklist for when a Phase-18-configured run starts misbehaving mid-flight. Prerequisite: you walked docs/phase-18-training-loop/stability-check.md and it all passed.

How to use

  1. The run already started healthily. Something went wrong after step \(K\).
  2. Have dashboard.html and experiments/<run>/manifest.json open.
  3. Start at §1 (the NaN check is fastest). If the loss is NaN or Inf, work through §1 and stop. Otherwise go to §2 (spike), then §3 (mixed precision), then §4 (silent divergence).

§1 — Is loss NaN or Inf?

Q1.1 Is loss(step \(t\)) NaN or Inf at any logged step?

  • If yes: you have a destructive update. Go §1.2.
  • If no: go §2.

Q1.2 What was the dtype regime in the run?

  • fp32: NaN in fp32 is rare. It is almost always log(0) in the loss (a probability went to exactly 0 because of a giant pre-softmax logit) or 0/0 in the loss reduction. Find the offending op by reloading and running a single forward pass with np.isnan(...).any() checks at each layer.
  • fp16: an activation magnitude exceeded 65504, the fp16 maximum. Go §3 (mixed precision).
  • bf16: less common, because bf16 shares fp32's exponent range and only overflows at \(\sim 10^{38}\). Follow the fp32 path, with the additional possibility that the gradient overflowed (bf16 has only 7 mantissa bits, roughly 1% relative noise per op).

Q1.3 Where did the NaN originate?

  • Reload the last NaN-free checkpoint.
  • Step forward one batch with np.isnan checks at: embedding output, each block's attention output, each block's MLP output, final LN, LM-head logits, softmax, log-softmax, loss reduction.
  • The first layer that reports NaN with finite input is the culprit.
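The walk in Q1.3 can be sketched as a loop over named stages. This is a minimal illustration, not the project's tooling: `first_nonfinite_stage` and the lambda "layers" are hypothetical stand-ins, and the sketch uses `np.isfinite` rather than `np.isnan` so that Inf (also covered by Q1.1) is caught too.

```python
import numpy as np

def first_nonfinite_stage(stages, x):
    """Run `stages` in order and return the name of the first stage whose
    output is non-finite while its input was still finite, else None."""
    for name, fn in stages:
        y = fn(x)
        if np.isfinite(x).all() and not np.isfinite(y).all():
            return name
        x = y
    return None

# Hypothetical stand-ins for real layers (embedding, logits, softmax, log).
stages = [
    ("embedding", lambda x: x * 2.0),
    ("logits", lambda x: x - x.max()),
    ("softmax", lambda x: np.exp(x) / np.exp(x).sum()),
    ("log", lambda x: np.log(x)),  # log(0) gives -inf for the underflowed probability
]

with np.errstate(divide="ignore"):
    culprit = first_nonfinite_stage(stages, np.array([0.0, 1000.0]))
```

In a real model the stages would be the checkpoints listed above (embedding output, per-block attention and MLP outputs, final LN, and so on), captured with whatever hook mechanism your framework offers.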

Q1.4 Fix:

  • log(0) in loss: add \(\epsilon\) inside the log, i.e. log(p + \(\epsilon\)), or use log_softmax instead of log(softmax(x)).
  • Embedding gradient = \(\infty\): clip-by-value on gradients (in addition to clip-by-norm). Cap at \(\pm 100\).
  • fp16 activation overflow: enable dynamic loss scaling (§3) or switch to bf16.
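To see why log_softmax is preferred over log(softmax(x)), and what clip-by-value does to an infinite gradient, here is a small NumPy sketch. The function names and the toy inputs are illustrative assumptions; only the \(\pm 100\) cap comes from the text.

```python
import numpy as np

def naive_log_softmax(x):
    """log(softmax(x)): small probabilities underflow to 0 before the log."""
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return np.log(e / e.sum(axis=-1, keepdims=True))

def log_softmax(x):
    """Stable form: subtract the log-sum-exp, so no probability is ever materialized."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

logits = np.array([0.0, 1000.0])  # one giant pre-softmax logit
with np.errstate(divide="ignore"):
    naive = naive_log_softmax(logits)  # first entry is -inf
stable = log_softmax(logits)           # first entry is -1000.0

# Clip-by-value caps infinite entries that clip-by-norm cannot rescale.
grads = np.array([np.inf, -np.inf, 3.0])
clipped = np.clip(grads, -100.0, 100.0)
```

Note that clip-by-value does not repair NaN entries; those still have to be traced to their origin as in Q1.3.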

If the re-run still produces NaN after the §1.4 fixes are applied, the bug is structural: go back to Phase 18 stability-check §5 (model architecture).


§2 — Loss spike (recoverable or not)

Q2.1 Is there a single step where loss is more than \(3\sigma\) above the rolling 100-step mean?

  • Compute \(\mu = \text{mean}(\text{loss}[t-100:t])\), \(\sigma = \text{std}(\text{loss}[t-100:t])\) for \(t > 100\).
  • A spike is loss[t] > μ + 3σ.
  • If yes: go §2.2. If no: go §3.
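The spike rule above translates almost directly into code. A minimal sketch, assuming the loss history is available as a 1-D array; `find_spikes` and the synthetic curve are illustrative, not part of the project.

```python
import numpy as np

def find_spikes(loss, window=100, k=3.0):
    """Return every step t where loss[t] > mu + k*sigma, with mu and sigma
    computed over the `window` steps immediately before t."""
    loss = np.asarray(loss, dtype=np.float64)
    spikes = []
    for t in range(window, len(loss)):
        hist = loss[t - window:t]
        mu, sigma = hist.mean(), hist.std()
        if loss[t] > mu + k * sigma:
            spikes.append(t)
    return spikes

# Synthetic loss: a gentle oscillation around 2.0 with one injected spike.
loss = 2.0 + 0.01 * np.sin(np.arange(300))
loss[150] = 5.0
spikes = find_spikes(loss)
```

Because the spike itself enters the window for the next 100 steps, it inflates \(\sigma\) and makes follow-up spikes harder to flag; keep that in mind when reading the panel.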

Q2.2 Does loss recover to within \(\mu + 0.5\sigma\) within 50 steps?

  • Yes (recoverable spike): go §2.3.
  • No (persistent elevation): go §2.4.
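A sketch of the recovery test, under one stated assumption: \(\mu\) and \(\sigma\) are frozen at their pre-spike values (the text does not say whether they keep rolling). `is_recoverable` and the two synthetic curves are hypothetical.

```python
import numpy as np

def is_recoverable(loss, t, window=100, horizon=50, tol=0.5):
    """True if loss drops back to <= mu + tol*sigma within `horizon` steps
    after the spike at step t. mu and sigma use the pre-spike window."""
    loss = np.asarray(loss, dtype=np.float64)
    hist = loss[t - window:t]
    threshold = hist.mean() + tol * hist.std()
    after = loss[t + 1:t + 1 + horizon]
    return bool((after <= threshold).any())

# Two synthetic curves: an isolated spike, and a jump that never comes back.
base = 2.0 + 0.01 * np.sin(np.arange(300))
recoverable = base.copy()
recoverable[150] = 5.0
persistent = base.copy()
persistent[150:] += 1.0
```

Here `is_recoverable(recoverable, 150)` routes to §2.3 and `is_recoverable(persistent, 150)` routes to §2.4.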

Q2.3 Recoverable spike: find the cause.

  • Panel: grad-norm pre-clip at the spike step. If it exceeds \(30 \mu_g\), where \(\mu_g\) is the rolling grad-norm mean: long-tail token in the batch. Open theory/04-loss-spike-postmortem-template.md and follow it.
  • Panel: LR at the spike step. If discontinuous (jumped): schedule bug.
  • Panel: batch composition. If the batch index is a duplicate of a recent batch: data-loader bug.
  • Apply the fix (stratified batching, schedule fix, or loader fix). Re-run.
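The three panel checks can be run together against logged arrays. This is a rough sketch: `diagnose_spike`, the log layout, and the 2× step-to-step definition of an LR "jump" are assumptions made for illustration; only the \(30\mu_g\) factor comes from the text.

```python
import numpy as np

def diagnose_spike(t, grad_norm, lr, batch_ids, window=100,
                   gn_factor=30.0, lr_jump=2.0):
    """Return the list of candidate causes for a spike at step t."""
    causes = []
    mu_g = np.mean(grad_norm[t - window:t])  # rolling pre-clip grad-norm mean
    if grad_norm[t] > gn_factor * mu_g:
        causes.append("long-tail batch")
    # Assumed definition of "discontinuous": LR changed by more than lr_jump in one step.
    if lr[t] > lr_jump * lr[t - 1] or lr[t] < lr[t - 1] / lr_jump:
        causes.append("schedule bug")
    if batch_ids[t] in batch_ids[t - window:t]:
        causes.append("data-loader bug")
    return causes

# Synthetic logs: flat grad-norm with one 50x outlier, constant LR, unique batches.
grad_norm = np.ones(200)
grad_norm[150] = 50.0
lr = np.full(200, 1e-3)
batch_ids = list(range(200))
causes = diagnose_spike(150, grad_norm, lr, batch_ids)
```

An empty list means none of the three panels explains the spike, which is itself worth recording in the post-mortem.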

Q2.4 Persistent elevation: find the cause.

  • The optimizer's moments are now corrupted; the model is stuck in a worse region.
  • Reload the last pre-spike checkpoint (you do checkpoint every 200 steps, right?).
  • Apply a preventative fix (lower the clip threshold to 0.5; enable stratified batching) before restarting.
  • If no checkpoint exists, the cost is a full re-train.


§3 — Mixed precision (fp16/bf16) checks

Only relevant if dtype is fp16 or bf16 in manifest.json.

Q3.1 Is there a step with grad-norm = inf or NaN?

  • Yes: fp16 overflow. An activation magnitude exceeded \(65504\) (the fp16 max). Go §3.2.
  • No: go §4.

Q3.2 What does the loss-scale history show?

  • The dynamic loss scaler should be doubling after every \(N\) consecutive non-overflow steps and halving on every overflow.
  • If the loss scale is constant: dynamic scaling isn't engaged. Enable it.
  • If the loss scale collapsed to \(< 1\): persistent overflow. The activation magnitudes are not just transient; they are systematically too large. Check init (Phase 18 stability §5) and check residual stream magnitude.
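The doubling and halving behavior described in Q3.2 can be simulated in a few lines of plain Python, which helps when reading a loss-scale panel. `DynamicLossScaler` and `diagnose_scale_history` are hypothetical names for this sketch; real scalers (for example PyTorch's) expose more options.

```python
class DynamicLossScaler:
    """Toy scaler: halve on overflow, double after `growth_interval` clean steps."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000, factor=2.0):
        self.scale = init_scale
        self.growth_interval = growth_interval  # the N in Q3.2
        self.factor = factor
        self._good_steps = 0

    def update(self, found_overflow):
        if found_overflow:
            self.scale /= self.factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.factor
                self._good_steps = 0
        return self.scale

def diagnose_scale_history(history):
    """Map a loss-scale history onto the two failure cases in Q3.2."""
    if len(set(history)) == 1:
        return "constant: dynamic scaling not engaged"
    if history[-1] < 1.0:
        return "collapsed: persistent overflow"
    return "ok"

scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=3)
history = [scaler.update(o) for o in [False, False, False, True, True, False]]
```

With three clean steps followed by two overflows, the scale goes up once and then halves twice, which is the healthy sawtooth you expect to see on the panel.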

Q3.3 fp16 overflow signature.

The diagnostic fingerprint:

| Signal | fp16 overflow | fp32 NaN |
|---|---|---|
| First NaN appears in | gradient or activation | loss directly |
| Loss scale before NaN | finite, then halving | N/A |
| Forward pass NaN-free? | sometimes | rarely |
| Fix | enable / fix loss scaling | architectural |

If unsure: re-run a single step in fp32 with the same data. If it works in fp32, you have an fp16 overflow, not a structural bug.
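The same-data, different-dtype comparison can be illustrated with NumPy alone. The values below are toy assumptions chosen so that a squared activation crosses the fp16 maximum; `sum_of_squares` stands in for whatever single step you re-run.

```python
import numpy as np

def sum_of_squares(x, dtype):
    """Run the same computation in the requested dtype."""
    x = x.astype(dtype)
    with np.errstate(over="ignore"):
        return (x ** 2).sum()

activations = np.array([300.0, 400.0])  # 300**2 = 90000 exceeds the fp16 max

fp16_max = float(np.finfo(np.float16).max)          # 65504.0
in_fp16 = sum_of_squares(activations, np.float16)   # overflows to inf
in_fp32 = sum_of_squares(activations, np.float32)   # finite: 250000.0
```

A finite fp32 result next to an infinite fp16 result on identical inputs is the overflow signature; if both are non-finite, look for a structural bug instead.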

Q3.4 Fixes:

  • Enable torch.cuda.amp.GradScaler (PyTorch native; recent releases expose it as torch.amp.GradScaler) or your equivalent.
  • Reduce LR by 2× if the loss scale keeps collapsing.
  • Switch fp16 → bf16 (wider exponent range, less precision; usually safer for transformers).


§4 — Silent divergence

Loss is not spiking, not NaN, just slowly going up.

Q4.1 Has loss increased monotonically over the last 200 steps?

  • Yes: go §4.2.
  • No: the run is healthy (or healthily noisy); stop walking.
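A strict reading of the Q4.1 test is a one-liner over the logged loss. This is a sketch with a hypothetical helper name; on noisy per-step losses you may prefer to apply it to a smoothed curve, which is an assumption beyond the text.

```python
import numpy as np

def monotonic_increase(loss, window=200):
    """True if each of the last `window` losses is higher than the one before it."""
    recent = np.asarray(loss[-window:], dtype=np.float64)
    return len(recent) == window and bool(np.all(np.diff(recent) > 0))

rising = np.linspace(2.0, 3.0, 200)   # steadily increasing loss
noisy = rising.copy()
noisy[100] = 1.0                      # a single dip breaks strict monotonicity
```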

Q4.2 Has LR increased over the same window?

  • Yes: schedule misconfigured (cosine implemented as inverse cosine, or warmup never ends).
  • No: go §4.3.
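For reference when checking the schedule, here is a minimal sketch of linear warmup followed by cosine decay. The function name and every default (base LR, warmup length, total steps) are illustrative assumptions, not the project's configuration.

```python
import math

def lr_at(step, base_lr=3e-4, warmup=100, total=1000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr at `total`."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def lr_rose(lrs):
    """Q4.2 check over a window of logged learning rates."""
    return lrs[-1] > lrs[0]

post_warmup = [lr_at(s) for s in range(100, 1001)]
```

After warmup the curve should only go down. If `lr_rose` is true on a window that lies past warmup, the usual suspects are a flipped sign on the cosine term or a warmup branch that never hands over.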

Q4.3 Has any parameter group's weight norm changed by more than 2× in the last 200 steps?

  • Yes: weight decay is too low (norm climbing) or too high (norm collapsing). Adjust by 2× and re-run from checkpoint.
  • No: go §4.4.
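The weight-norm check reduces to a ratio per parameter group. A sketch, assuming you have logged per-group norms 200 steps apart as dictionaries; the group names and `norm_change_flags` are hypothetical.

```python
def norm_change_flags(norm_then, norm_now, factor=2.0):
    """Flag groups whose weight norm grew or shrank by more than `factor`."""
    flags = {}
    for name, then in norm_then.items():
        ratio = norm_now[name] / then
        if ratio > factor:
            flags[name] = "climbing (decay too low)"
        elif ratio < 1.0 / factor:
            flags[name] = "collapsing (decay too high)"
    return flags

# Hypothetical per-group norms logged 200 steps apart.
norm_then = {"embed": 10.0, "attn": 20.0, "mlp": 30.0}
norm_now = {"embed": 25.0, "attn": 21.0, "mlp": 12.0}
flags = norm_change_flags(norm_then, norm_now)
```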

Q4.4 Is the gradient norm rolling-mean trending up?

  • Yes: the loss landscape is getting harder (you're moving away from a minimum). The cause is usually upstream: bad data, init drift, or LR too high. Reduce LR by 2×.
  • No: the model has plateaued. This isn't divergence; this is "you're done training". Check eval metrics. If eval has also plateaued, stop.
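One way to turn "trending up" into a number is to fit a line to the rolling mean and look at the slope. This is a sketch under assumptions: the 50-step smoothing window and the least-squares fit are choices made here, not something the text prescribes.

```python
import numpy as np

def rolling_mean(x, window):
    """Simple moving average; output is shorter than the input by window - 1."""
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode="valid")

def trend_slope(grad_norm, window=50):
    """Least-squares slope (per step) of the rolling-mean gradient norm."""
    smooth = rolling_mean(np.asarray(grad_norm, dtype=np.float64), window)
    steps = np.arange(len(smooth))
    return float(np.polyfit(steps, smooth, 1)[0])

rising_slope = trend_slope(1.0 + 0.01 * np.arange(300))  # steadily growing norm
flat_slope = trend_slope(np.ones(300))                   # plateaued norm
```

What counts as a meaningful positive slope depends on the scale of your gradient norms, so compare it against the rolling mean itself rather than a fixed constant.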


Numerical thresholds — quick reference

| Symptom | Threshold | First action |
|---|---|---|
| loss NaN/Inf | any | §1: find originating layer |
| single-step spike | loss > μ + 3σ | §2.3: check grad-norm panel |
| persistent elevation | spike + no recovery in 50 steps | §2.4: reload checkpoint |
| grad-norm = inf in fp16 | any | §3: loss scaling |
| loss increasing | 200 steps monotonic | §4: check LR, decay, gradient |
| weight norm change | > 2× in 200 steps | §4.3: adjust decay |

Cross-refs

  • Phase 18 stability-check.md (configuration-level checks; do those first).
  • theory/04-loss-spike-postmortem-template.md (the post-mortem you write after a spike).
  • theory/03-three-failure-modes.md (the three engineered failures and their signatures).