04: Log-sum-exp and numerical stability
The math is beautiful, but exponentials overflow. This is the last piece: how to compute softmax and cross-entropy without a pile of small details leaving you staring at NaN.
The problem
The softmax of logits \(z \in \mathbb{R}^V\) is:
\[ p_i = \sigma(z)_i = \frac{\exp(z_i)}{\sum_{j=1}^{V} \exp(z_j)}, \qquad i = 1, \dots, V. \]
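Two failure modes in float32:
- Overflow. \(\exp(z_i)\) exceeds the largest finite float32 (about \(3.4 \times 10^{38}\)) once \(z_i > \ln(3.4 \times 10^{38}) \approx 88.7\), and becomes \(+\infty\). A single such logit turns the denominator into \(+\infty\), and the softmax then contains \(\infty / \infty = \text{NaN}\).
- Underflow. Below \(z_i \approx -87\), \(\exp(z_i)\) drops under the smallest normal float32 (about \(1.2 \times 10^{-38}\)) and loses precision as a subnormal; below roughly \(-104\) it rounds to 0. The corresponding \(p_i\) then becomes 0 even when its true value is tiny but nonzero, so \(\log p_i = -\infty\). If every logit is that negative, the denominator is 0 and the softmax is \(0/0 = \text{NaN}\).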
Both happen routinely. Transformer logits can reach the hundreds (especially after temperature scaling with \(T < 1\), which multiplies them by \(1/T\)), and a negative log-likelihood loss computed as naïve softmax followed by log will then silently produce inf/NaN losses or zero gradients.
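Both failure modes are easy to reproduce. The sketch below uses a hypothetical helper, `naive_softmax`, and casts to `float32` explicitly, because NumPy defaults to float64, whose overflow threshold sits much further out (around 709.8).

```python
import numpy as np

def naive_softmax(z):
    # Direct translation of the formula: exponentiate, then normalise.
    e = np.exp(z)
    return e / e.sum()

with np.errstate(over="ignore", under="ignore", invalid="ignore"):
    # exp overflows float32 just above ln(FLT_MAX) ≈ 88.72.
    print(np.exp(np.float32(88.0)))   # finite, about 1.65e38
    print(np.exp(np.float32(89.0)))   # inf

    # Large logits: every exp overflows, so inf / inf gives nan.
    big = np.array([100.0, 101.0, 102.0], dtype=np.float32)
    print(naive_softmax(big))         # [nan nan nan]

    # Very negative logits: every exp underflows to 0, so 0 / 0 gives nan,
    # although the correct answer is softmax([0, 1, 2]) ≈ [0.090, 0.245, 0.665].
    small = np.array([-200.0, -199.0, -198.0], dtype=np.float32)
    print(naive_softmax(small))       # [nan nan nan]
```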
The log-sum-exp identity
For any constant \(c \in \mathbb{R}\):
\[ \log \sum_{i=1}^{V} \exp(z_i) = c + \log \sum_{i=1}^{V} \exp(z_i - c). \]
Choose \(c = \max_j z_j\). Then the largest term inside the sum is \(\exp(0) = 1\), so nothing overflows, and the sum is at least 1, which keeps its log finite. The smallest term is \(\exp(z_{\min} - z_{\max})\); it can underflow to 0, but only when it is negligible next to the 1 contributed by the maximum.
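Why the shift is exact: multiplying and dividing each term by \(e^{c}\) and pulling the constant out of the sum gives, for any real \(c\),
\[ \log \sum_{i=1}^{V} e^{z_i} = \log\!\left(e^{c} \sum_{i=1}^{V} e^{z_i - c}\right) = c + \log \sum_{i=1}^{V} e^{z_i - c}. \]
The same factoring shows that softmax itself is shift-invariant: \(e^{z_i - c} / \sum_j e^{z_j - c} = e^{z_i} / \sum_j e^{z_j}\).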
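Stable log-sum-exp:
\[ \operatorname{LSE}(z) = \max_j z_j + \log \sum_{i=1}^{V} \exp\!\left(z_i - \max_j z_j\right). \]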
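A minimal NumPy sketch of the stable formula for a 1-D array. The name `logsumexp_stable` is ours; SciPy provides a production version as `scipy.special.logsumexp`. The last line shows the naïve computation overflowing on the same input.

```python
import numpy as np

def logsumexp_stable(z):
    """log(sum(exp(z))) for a 1-D array, via the max-shift identity."""
    c = np.max(z)                        # c = max_j z_j
    # After the shift the largest term is exp(0) = 1, so the sum is >= 1
    # and its log is finite; tiny terms may underflow to 0 harmlessly.
    return c + np.log(np.sum(np.exp(z - c)))

z = np.array([1000.0, 1001.0, 1002.0])
print(logsumexp_stable(z))               # ≈ 1002.4076
with np.errstate(over="ignore"):
    print(np.log(np.sum(np.exp(z))))     # inf: exp(1000) overflows float64
```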
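Stable log-softmax
Subtracting the log-sum-exp from each logit gives the log-probabilities directly, using the same shift \(c = \max_j z_j\):
\[ \log p_i = z_i - \log \sum_{j=1}^{V} \exp(z_j) = (z_i - c) - \log \sum_{j=1}^{V} \exp(z_j - c). \]
```python
import numpy as np

def log_softmax(z):
    """Stable log-softmax of a 1-D array of logits."""
    c = z.max()                  # shift so the largest logit becomes 0
    z_shifted = z - c            # every exp(z_shifted) is now at most 1
    return z_shifted - np.log(np.exp(z_shifted).sum())  # (z_i - c) - log sum_j exp(z_j - c)
```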
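Borja implements this in lab 02. Compare it with the naïve \(\log(\sigma(z))\): the two should agree on well-conditioned inputs and disagree on adversarial inputs like `z = [1000, 1001, 1002]`, where the stable version returns finite values and the naïve one returns NaN (\(\exp(1000)\) overflows even in float64).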
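The comparison can be sketched in a few lines. `naive_log_softmax` is a hypothetical name for the naïve path, and `log_softmax` repeats the stable version so the snippet runs on its own.

```python
import numpy as np

def naive_log_softmax(z):
    # log(softmax(z)) computed literally: exponentiate, normalise, then log.
    e = np.exp(z)
    return np.log(e / e.sum())

def log_softmax(z):
    # Stable version: shift by the max, then subtract log-sum-exp.
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

well_conditioned = np.array([0.5, -1.0, 2.0])
print(np.allclose(naive_log_softmax(well_conditioned), log_softmax(well_conditioned)))  # True

adversarial = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_log_softmax(adversarial))   # [nan nan nan]
print(log_softmax(adversarial))             # ≈ [-2.4076 -1.4076 -0.4076]
```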
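Stable cross-entropy from logits
For a single example with true label \(y^*\):
\[ \mathcal{L}(z, y^*) = -\log p_{y^*} = -z_{y^*} + \log \sum_{j=1}^{V} \exp(z_j), \]
where the log-sum-exp term is evaluated with the max shift.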
Crucially: do not compute the softmax first and then take the log of the result, because any probability that underflowed to 0 turns into \(\log 0 = -\infty\). Always go from logits to log-softmax in one step.
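A minimal single-example sketch in NumPy, assuming `z` is a 1-D array of logits and `y` an integer class index (the name `cross_entropy` is ours, not from lab 02). It picks the true-class entry out of the log-softmax, so no probability is ever passed to `log`.

```python
import numpy as np

def cross_entropy(z, y):
    # Loss = -log softmax(z)[y] = log-sum-exp(z) - z[y], computed stably.
    c = z.max()
    z_shifted = z - c
    log_probs = z_shifted - np.log(np.exp(z_shifted).sum())
    return -log_probs[y]

print(cross_entropy(np.array([2.0, 1.0, 0.1]), 0))           # ≈ 0.4170
print(cross_entropy(np.array([1000.0, 1001.0, 1002.0]), 2))  # ≈ 0.4076, still finite
```

The second call uses logits that would overflow a naïve softmax even in float64; here it returns the same loss as for `[0, 1, 2]` with label 2.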
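Vectorised over a batch of size \(B\):
```python
import numpy as np

def cross_entropy_batch(Z, Y):  # Z: (B, V) logits, Y: (B,) integer labels
    B = Z.shape[0]
    c = Z.max(axis=1, keepdims=True)          # per-row max, shape (B, 1)
    Z_shifted = Z - c                          # largest logit in each row is now 0
    log_probs = Z_shifted - np.log(np.exp(Z_shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(B), Y].mean()  # mean of -log p_{y*} over the batch
```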
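The gradient (foreshadowing Phase 07)
We'll derive this in detail in Phase 07, but it's worth seeing now: the gradient of CE-from-logits with respect to the logits is
\[ \frac{\partial \mathcal{L}}{\partial z_i} = p_i - \mathbb{1}[i = y^*], \]
that is, the softmax output minus the one-hot target.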
This is elegant and numerically stable: no division by tiny probabilities, no logs of small numbers. It is also why frameworks typically fuse `log_softmax` and `nll_loss` into a single op (`cross_entropy_loss`), for both speed and stability.
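A finite-difference check of the gradient, as a sketch with our own hypothetical helpers (`ce_loss`, `ce_grad`, `numerical_grad`). It compares the analytic gradient, softmax output minus the one-hot target, against central differences of the loss; the step `h = 1e-5` is an assumption that works well in float64.

```python
import numpy as np

def log_softmax(z):
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

def ce_loss(z, y):
    return -log_softmax(z)[y]

def ce_grad(z, y):
    # Analytic gradient: softmax(z) - onehot(y). No division, no log of small numbers.
    g = np.exp(log_softmax(z))
    g[y] -= 1.0
    return g

def numerical_grad(f, z, h=1e-5):
    # Central differences, one coordinate at a time.
    g = np.zeros_like(z)
    for i in range(z.size):
        e = np.zeros_like(z)
        e[i] = h
        g[i] = (f(z + e) - f(z - e)) / (2 * h)
    return g

z = np.array([0.3, -1.2, 2.5, 0.0])
y = 2
analytic = ce_grad(z, y)
numeric = numerical_grad(lambda v: ce_loss(v, y), z)
print(np.max(np.abs(analytic - numeric)))   # should be tiny (well below 1e-6)
print(analytic.sum())                        # close to 0: both softmax and one-hot sum to 1
```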
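The label-smoothing variant
A common variant uses soft labels: instead of \(\mathbb{1}[i = y^*]\), the target is \(t_i = (1 - \alpha) \mathbb{1}[i = y^*] + \alpha / V\) for some smoothing \(\alpha \in [0, 1)\). The CE becomes:
\[ \mathcal{L} = -\sum_{i=1}^{V} t_i \log p_i = (1 - \alpha)\left(-\log p_{y^*}\right) + \frac{\alpha}{V} \sum_{i=1}^{V} \left(-\log p_i\right). \]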
Borja doesn't have to implement label smoothing in Phase 05 (it lands in Phase 18 as an optional regularisation), but it's worth seeing the form.
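For reference, a sketch of label-smoothed CE from logits for one example. The smoothed loss splits into \((1 - \alpha)\) times the usual loss plus \(\alpha\) times the average of \(-\log p_i\) over all \(V\) classes, so it can be computed from log-softmax without forming the soft target. The function name and the default `alpha=0.1` are illustrative assumptions, not values from the labs.

```python
import numpy as np

def log_softmax(z):
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

def smoothed_cross_entropy(z, y, alpha=0.1):
    """Label-smoothed CE for one example; alpha in [0, 1)."""
    log_probs = log_softmax(z)
    # Target t_i = (1 - alpha) * [i == y] + alpha / V, never materialised:
    nll = -log_probs[y]                 # usual CE term
    uniform = -log_probs.mean()         # CE against the uniform distribution
    return (1 - alpha) * nll + alpha * uniform

z = np.array([2.0, 1.0, 0.1])
print(smoothed_cross_entropy(z, 0, alpha=0.0))  # equals plain CE, ≈ 0.4170
print(smoothed_cross_entropy(z, 0, alpha=0.1))
```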
Numerical pitfalls catalogue
| Bug | Symptom | Cure |
|---|---|---|
| Naïve softmax with large logits | `inf` or `NaN` in probabilities | Stable log-softmax (above) |
| Cross-entropy computed as `-(p * log(q)).sum()` with `q` containing zeros | `inf` loss (`NaN` where `p` is also 0, since `0 * -inf` is `NaN`) | Compute from logits, not probabilities |
| Forgetting to mask padded positions in sequence CE | Loss includes padding tokens | Multiply the per-token loss by the attention mask before averaging |
| Forgetting to `.detach()` the target when the target has gradients (rare but real) | Wrong gradient | Use one-hot or index labels, not soft labels that carry gradients |
| Computing the per-batch loss as a sum instead of a mean | Loss scales with batch size, and so does the effective learning rate | Average by token count, not by sequence count |
| Using FP16 throughout cross-entropy | Underflow and precision loss in log-softmax | Promote to FP32 for the log-softmax and loss; keep FP16 for the matmul |
The last item is a classic: mixed-precision training (Phase 18) requires exactly this upcasting pattern.
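The masking and averaging rows of the table can be sketched together for a batch of sequences. This assumes logits `Z` of shape `(B, T, V)`, integer targets `Y` of shape `(B, T)` (padding positions hold any valid class index, e.g. 0), and a 0/1 `mask` of shape `(B, T)` marking real tokens; the function name is hypothetical. The loss is summed over real tokens and divided by their count, so padding never contributes.

```python
import numpy as np

def masked_token_ce(Z, Y, mask):
    """Mean cross-entropy over non-padding tokens."""
    # Stable log-softmax over the vocabulary axis.
    Z_shifted = Z - Z.max(axis=-1, keepdims=True)
    log_probs = Z_shifted - np.log(np.exp(Z_shifted).sum(axis=-1, keepdims=True))
    # Per-token loss -log p(target), shape (B, T).
    per_token = -np.take_along_axis(log_probs, Y[..., None], axis=-1)[..., 0]
    # Zero out padding, then average by token count (not sequence count).
    return (per_token * mask).sum() / mask.sum()

rng = np.random.default_rng(0)
B, T, V = 2, 4, 5
Z = rng.normal(size=(B, T, V))
Y = rng.integers(0, V, size=(B, T))
mask = np.array([[1, 1, 1, 0],      # last position of sequence 0 is padding
                 [1, 1, 0, 0]], dtype=float)
print(masked_token_ce(Z, Y, mask))
```

Because the stable log-softmax keeps every per-token loss finite, multiplying by the mask is safe here; with a naïve softmax, an `inf` at a padded position would turn into `NaN` after the multiplication.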
Test discipline
Borja's lab 02 must include:
- Property test: `log_softmax(z) ≈ log_softmax(z + c)` for any scalar \(c\) (shift invariance).
- Reference test: compare against `scipy.special.log_softmax` for a battery of inputs.
- Stress test: inputs like `z = [1000, 1001, 1002]` and `z = [-1000, -999, -998]` produce finite outputs; both are shifts of `[0, 1, 2]`, so both should match `log_softmax([0, 1, 2])`.
These tests should land in tests/test_logsoftmax.py and stay green for the rest of the project. Phase 19 (training-dynamics debugging) will reference this test discipline.
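A sketch of what `tests/test_logsoftmax.py` might contain, written as pytest-style test functions that also run when called directly. The in-file `log_softmax` is a stand-in for the lab's implementation (the real file would import it from the lab module, whose path we don't assume), the reference test needs SciPy installed, and the tolerances are illustrative.

```python
import numpy as np

def log_softmax(z):
    # Stand-in for the lab 02 implementation under test.
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

def test_shift_invariance():
    rng = np.random.default_rng(0)
    for _ in range(100):
        z = rng.normal(scale=10.0, size=8)
        c = rng.uniform(-1000.0, 1000.0)
        assert np.allclose(log_softmax(z), log_softmax(z + c), atol=1e-9)

def test_matches_scipy():
    from scipy.special import log_softmax as ref_log_softmax  # reference implementation
    rng = np.random.default_rng(1)
    for _ in range(100):
        z = rng.normal(scale=50.0, size=16)
        assert np.allclose(log_softmax(z), ref_log_softmax(z), atol=1e-9)

def test_extreme_inputs():
    expected = log_softmax(np.array([0.0, 1.0, 2.0]))
    for z in ([1000.0, 1001.0, 1002.0], [-1000.0, -999.0, -998.0]):
        out = log_softmax(np.array(z))
        assert np.all(np.isfinite(out))
        # Both inputs are shifts of [0, 1, 2], so the outputs must match it.
        assert np.allclose(out, expected)
```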
What this file does NOT cover
- Numerical analysis of higher-order derivatives (rarely needed).
- Tropical / softmax-related semiring algebra (theoretically beautiful, practically not load-bearing).
- Mixed-precision implementation details (handled in Phase 18).