Break: Compute cross-entropy via log(softmax(x)) instead of log_softmax
One of the oldest traps: even after a numerically stable softmax, the probability of an unlikely class can be exactly 0 (underflow). Taking the log of 0 gives -inf, and the loss for the whole batch blows up to inf and then nan.
Target: cross-entropy implementation for the §A13 5-tense classifier.
Hypothesis
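The learner predicts: "Computing cross-entropy as -log(softmax(x)[target]) instead of the fused -log_softmax(x)[target] will silently work on small logits. It will then produce an inf loss (log(0) = -inf) and nan gradients whenever the target's softmax probability underflows to 0 in fp32. The danger zone starts once the target's logit is more than ~88 below the max, where exp() drops out of the normal fp32 range. The probability is exactly 0 from roughly 104 below the max, or already from ~88 below if subnormals are flushed to zero."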
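The ~88 figure comes from fp32's exponent range. The minimal NumPy sketch below computes the two relevant thresholds for exp(x): the log of the smallest normal float32 and the log of the smallest subnormal float32. The helper name fp32_exp_cutoffs is hypothetical, and the logs are taken in float64 so the subnormal input is handled exactly.

```python
import math
import numpy as np

def fp32_exp_cutoffs() -> tuple[float, float]:
    """Hypothetical helper: return (x_normal, x_subnormal) for float32 exp(x).

    For x < x_normal, exp(x) is below the smallest *normal* float32 (~1.18e-38),
    so it can only be stored as a subnormal (or 0 if subnormals are flushed).
    For x < x_subnormal, exp(x) is below the smallest subnormal (~1.4e-45);
    once it falls under half of that (x < about -103.97) it rounds to 0.0.
    """
    f32 = np.float32
    smallest_normal = float(np.finfo(f32).tiny)
    smallest_subnormal = float(np.nextafter(f32(0), f32(1)))
    # Logs in float64 so the subnormal value is not itself flushed or mishandled.
    return math.log(smallest_normal), math.log(smallest_subnormal)

x_normal, x_subnormal = fp32_exp_cutoffs()
print(f"normal fp32 range ends at    x = {x_normal:.2f}")
print(f"subnormal fp32 range ends at x = {x_subnormal:.2f}")
# A small positive normal value, then exactly 0.0.
print(np.exp(np.float32(-80.0)), np.exp(np.float32(-110.0)))
```

The gap between the two thresholds explains why different stacks fail at slightly different spreads. Code that flushes subnormals (common in GPU fast-math kernels) hits 0 near 88, while IEEE-conforming NumPy holds on until about 104.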
The break
In your cross-entropy implementation:
```diff
 def cross_entropy(logits: np.ndarray, target: int) -> float:
-    log_probs = logits - logsumexp(logits)
-    return -log_probs[target]
+    probs = softmax(logits)
+    return -np.log(probs[target])
```
Run procedure
Two test cases, one in the safe zone, one over the cliff:
```bash
uv run python -c "
import numpy as np
from scipy.special import logsumexp

def softmax(x):
    m = x.max()
    e = np.exp(x - m)
    return e / e.sum()

def ce_unsafe(logits, target):
    p = softmax(logits)
    return -np.log(p[target])

def ce_safe(logits, target):
    return -(logits[target] - logsumexp(logits))

# Case 1: small logits, target = past simple (index 2)
small = np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32)
print(f'small unsafe={ce_unsafe(small, 2):.6f} safe={ce_safe(small, 2):.6f}')

# Case 2: huge spread, target = the unlikely class (index 3)
huge = np.array([10.0, 92.0, 30.0, -50.0, 40.0], dtype=np.float32)
print(f'huge unsafe={ce_unsafe(huge, 3):.6f} safe={ce_safe(huge, 3):.6f}')
"
```
Expected failure mode
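```text
small unsafe=1.945282 safe=1.945282   <-- match in the safe zone
huge unsafe=inf safe=142.000000       <-- underflow disaster
```
The last printed digit of the small case may differ by one between NumPy/SciPy builds because of fp32 rounding.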
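The probability of the target class in huge is exp(-50 - 92) / Σ_j exp(x_j - 92) ≈ exp(-142) ≈ 2.1e-62. The denominator is ≈ 1 because the max logit contributes exp(0) = 1. That value is far below the smallest fp32 subnormal (~1.4e-45), so it underflows to exactly 0. Then log(0) = -inf and the cross-entropy is +inf. NumPy does not raise here; it only prints a RuntimeWarning (divide by zero encountered in log) to stderr. Backprop through the unfused graph then produces nan gradients, and training collapses at the first step N where this happens.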
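The nan gradients can be reproduced without an autograd library. The sketch below differentiates both forms by hand in NumPy. The function names grad_unsafe and grad_fused are hypothetical, and the manual chain rule only stands in for what an autograd engine would do on each graph.

On the unfused path, d(-log p_t)/dp_t = -1/p_t = -inf is multiplied by a softmax-Jacobian row that is exactly 0. The fused gradient simplifies analytically to softmax(x) - onehot(target), which never divides by a probability.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

def grad_unsafe(logits: np.ndarray, target: int) -> np.ndarray:
    """Hand-written chain rule through L = -log(softmax(x)[target])."""
    p = softmax(logits)
    onehot = np.zeros_like(p)
    onehot[target] = 1.0
    with np.errstate(divide="ignore", invalid="ignore"):
        dL_dp = -1.0 / p[target]            # -inf when p[target] underflowed to 0
        dp_dx = p[target] * (onehot - p)    # row `target` of the softmax Jacobian: all zeros then
        return dL_dp * dp_dx                # (-inf) * 0 -> nan in every component

def grad_fused(logits: np.ndarray, target: int) -> np.ndarray:
    """Gradient of -log_softmax(x)[target], simplified analytically."""
    p = softmax(logits)
    onehot = np.zeros_like(p)
    onehot[target] = 1.0
    return p - onehot                       # bounded: every entry lies in [-1, 1]

huge = np.array([10.0, 92.0, 30.0, -50.0, 40.0], dtype=np.float32)
print("unsafe:", grad_unsafe(huge, 3))     # all nan
print("fused: ", grad_fused(huge, 3))      # finite: -1 at the target, about +1 at the max logit
```

On small logits the two functions agree up to fp32 rounding, which is why the bug stays silent until a large spread appears. After an optimizer step writes a single nan into the weights, every later forward pass is nan too.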
Diagnostic
From training logs alone:
- loss == inf or loss == nan at step N is the loud symptom.
- The failing example's target is a class whose logit sits far (about 88 or more) below that example's max logit. In the huge case above it is also argmin(logits). Print the target index and argmin(logits) of the failing example.
- Print the smallest softmax probability per batch. If it is exactly 0.0 in fp32, any example whose target is that class will get an inf cross-entropy.
- Cross-check against scipy.special.log_softmax or torch.nn.functional.log_softmax. Both are fused and stable.
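The per-batch checks above can be bundled into one helper. This is a sketch under a few assumptions:

- underflow_report is a hypothetical name.
- Logits arrive as a (batch, classes) float32 array with integer targets.
- The fused log-softmax is written inline in NumPy rather than calling scipy.special.log_softmax, so the snippet needs no SciPy.

```python
import numpy as np

def underflow_report(logits: np.ndarray, targets: np.ndarray) -> dict:
    """Hypothetical diagnostic. logits: (B, C) float32, targets: (B,) int class indices."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # each row's max becomes 0
    exp_shifted = np.exp(shifted)
    denom = exp_shifted.sum(axis=1, keepdims=True)          # always >= 1 (contains exp(0))
    probs = exp_shifted / denom                             # unfused softmax, may hold exact 0s
    log_probs = shifted - np.log(denom)                     # fused log-softmax, stays finite
    rows = np.arange(len(targets))
    bad = np.flatnonzero(probs[rows, targets] == 0.0)      # rows whose unfused loss is inf
    return {
        "min_prob": float(probs.min()),
        "bad_rows": bad.tolist(),
        "bad_targets": targets[bad].tolist(),
        "bad_argmin": logits[bad].argmin(axis=1).tolist(),
        "fused_loss": (-log_probs[rows, targets]).tolist(),
    }

batch = np.array([[1.2, 4.7, 3.1, 0.5, 2.9],
                  [10.0, 92.0, 30.0, -50.0, 40.0]], dtype=np.float32)
report = underflow_report(batch, np.array([2, 3]))
for key, value in report.items():
    print(f"{key}: {value}")
```

For this two-row batch, the second row is flagged (target 3, which is also its argmin), min_prob is 0.0, and the fused loss stays finite at 142.0.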
Lesson
log(softmax(x)) and log_softmax(x) are algebraically identical. They are not numerically identical. The former materializes exp(x_i - max(x)) and only then takes the log, so a value that underflowed to 0 becomes -inf. The latter computes x_i - logsumexp(x) directly, where both terms are finite and the subtraction happens in log-space.
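The log-space route can be written out with m = max_j x_j, the same max shift the run procedure's softmax uses. Every term in the last expression is finite. The difference x_i - m is ordinary subtraction, and the sum always contains the term e^0 = 1, so its log lies between 0 and log C for C classes.

```latex
\log \operatorname{softmax}(x)_i
  = \log \frac{e^{x_i}}{\sum_{j=1}^{C} e^{x_j}}
  = x_i - \log \sum_{j=1}^{C} e^{x_j}
  = (x_i - m) - \log \sum_{j=1}^{C} e^{x_j - m},
  \qquad m = \max_j x_j .
```

For the huge example (m = 92, i = 3), this gives (-50 - 92) - log(1 + e^-82 + e^-62 + e^-142 + e^-52) ≈ -142, whose negation is the loss of 142. The terms that underflow only perturb a sum that is already at least 1, so they cannot push the result to -inf.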
Phase 7's autograd and Phase 9's MLP cross-entropy will both reuse this. Always prefer the fused form. If you ever see -log(p) in a notebook, treat it as a code smell unless p is provably bounded away from 0.
References
- The PyTorch torch.nn.functional.log_softmax docs explicitly recommend it over log(softmax(...)) for this reason.
- Bishop, Pattern Recognition and Machine Learning, §4.3.4 (multiclass logistic regression): derivation of the softmax cross-entropy and its gradient.