Break — Compute cross-entropy via log(softmax(x)) instead of log_softmax

One of the oldest traps: after a stable softmax, the probability of an unlikely class can be 0 (underflow). Taking the log of 0 gives -inf, and the whole batch produces nan in the loss.

Target: cross-entropy implementation for the §A13 5-tense classifier.

Hypothesis

The learner predicts: "Computing cross-entropy as -log(softmax(x)[target]) instead of the fused -log_softmax(x)[target] will silently work on small logits, then produce inf / nan whenever any target has a softmax probability that underflows to 0 in fp32 (i.e., its logit is more than ~104 below the max; from ~87 below the max the probability is already subnormal and losing precision)."
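The underflow threshold in the hypothesis can be checked directly. The following is a minimal sketch, assuming NumPy on a platform with standard IEEE-754 float32 behavior; `exp32` is a hypothetical helper name introduced only for this illustration.

```python
import numpy as np

# Natural logs of two float32 limits:
#   smallest normal value    2**-126  -> about -87.3
#   smallest subnormal value 2**-149  -> about -103.3
ln_min_normal = float(np.log(np.finfo(np.float32).tiny))
ln_min_subnormal = -149 * float(np.log(2.0))

def exp32(x: float) -> np.float32:
    """Evaluate exp in float32, as a float32 softmax numerator would."""
    return np.exp(np.float32(x))

print(f"ln(min normal)    = {ln_min_normal:.2f}")
print(f"ln(min subnormal) = {ln_min_subnormal:.2f}")

# gap = max logit minus the logit in question
for gap in (80.0, 88.0, 100.0, 110.0, 142.0):
    print(f"gap={gap:6.1f}  exp(-gap)={exp32(-gap)!r}")
```

Gaps well past the subnormal limit should print as exactly 0.0, which is the case the hypothesis is about.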

The break

In your cross-entropy implementation:

 def cross_entropy(logits: np.ndarray, target: int) -> float:
-    log_probs = logits - logsumexp(logits)
-    return -log_probs[target]
+    probs = softmax(logits)
+    return -np.log(probs[target])

Run procedure

Two test cases, one in the safe zone, one over the cliff:

uv run python -c "
import numpy as np
from scipy.special import logsumexp

def softmax(x):
    m = x.max()
    e = np.exp(x - m)
    return e / e.sum()

def ce_unsafe(logits, target):
    p = softmax(logits)
    return -np.log(p[target])

def ce_safe(logits, target):
    return -(logits[target] - logsumexp(logits))

# Case 1: small logits, target = past simple (index 2)
small = np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32)
print(f'small  unsafe={ce_unsafe(small, 2):.6f}  safe={ce_safe(small, 2):.6f}')

# Case 2: huge spread, target = the unlikely class (index 3)
huge = np.array([10.0, 92.0, 30.0, -50.0, 40.0], dtype=np.float32)
print(f'huge   unsafe={ce_unsafe(huge, 3):.6f}  safe={ce_safe(huge, 3):.6f}')
"

Expected failure mode

small  unsafe=1.9453   safe=1.9453           <-- match in safe zone (digits truncated)
huge   unsafe=inf      safe=142.000000       <-- underflow disaster

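The probability of the target class in huge is exp(-50 - 92) / Σ ≈ exp(-142) ≈ 2.1e-62, far below the smallest positive fp32 value (about 1.4e-45), so it underflows to exactly 0. log(0) = -inf, so the cross-entropy is +inf. In backprop, the -1/p factor is infinite while the softmax Jacobian row is 0, and inf × 0 = nan. Training collapses at step N.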
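To see where the nan gradients come from, the chain rule can be applied by hand. This is a sketch under stated assumptions: `grad_unsafe` and `grad_fused` are hypothetical names, and the manual chain rule stands in for what an autograd engine would do with the unfused graph.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

def one_hot(n: int, target: int, dtype) -> np.ndarray:
    v = np.zeros(n, dtype=dtype)
    v[target] = 1.0
    return v

def grad_unsafe(logits: np.ndarray, target: int) -> np.ndarray:
    """Chain rule through p = softmax(x), then L = -log(p[target])."""
    p = softmax(logits)
    t = one_hot(len(p), target, p.dtype)
    with np.errstate(divide="ignore", invalid="ignore"):
        dL_dp = -1.0 / p[target]        # -inf when p[target] == 0
        dp_dx = p[target] * (t - p)     # all zeros when p[target] == 0
        return dL_dp * dp_dx            # inf * 0 -> nan

def grad_fused(logits: np.ndarray, target: int) -> np.ndarray:
    """Closed-form gradient of the fused loss: softmax(x) - one_hot."""
    p = softmax(logits)
    return p - one_hot(len(p), target, p.dtype)

small = np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32)
huge = np.array([10.0, 92.0, 30.0, -50.0, 40.0], dtype=np.float32)
print("small:", grad_unsafe(small, 2), grad_fused(small, 2))
print("huge: ", grad_unsafe(huge, 3), grad_fused(huge, 3))
```

On the small logits the two gradients should agree to rounding error. On the huge logits only the closed-form gradient stays finite.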

Diagnostic

From training logs alone:

  1. loss == inf or loss == nan at step N is the loud symptom.
  2. The failing class is the target of the failing example, and its logit sits far below the row maximum (often, but not necessarily, the argmin). Print logits[target] - logits.max() for that example.
  3. Print the smallest softmax probability per batch. If it is exactly 0.0 in fp32, the underflow has already happened, and the loss becomes inf as soon as that class is the target.
  4. Cross-check against scipy.special.log_softmax or torch.nn.functional.log_softmax. Both are fused and stable.
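The checks above could be bundled into one helper. The following is a rough sketch, not part of the course code: `diagnose_batch` and its return keys are hypothetical names, and it assumes logits of shape (batch, classes) with integer targets.

```python
import numpy as np
from scipy.special import log_softmax

def diagnose_batch(logits: np.ndarray, targets: np.ndarray) -> dict:
    """logits: (batch, classes) float array; targets: (batch,) class indices."""
    rows = np.arange(len(targets))
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    probs = e / e.sum(axis=1, keepdims=True)
    with np.errstate(divide="ignore"):
        unsafe = -np.log(probs[rows, targets])           # the broken path
    safe = -log_softmax(logits, axis=1)[rows, targets]   # fused reference
    return {
        "min_prob": float(probs.min()),                  # 0.0 means underflow
        "target_gap": shifted[rows, targets],            # logit[target] - row max
        "bad_rows": np.flatnonzero(~np.isfinite(unsafe)),
        "unsafe": unsafe,
        "safe": safe,
    }

batch = np.array([[1.2, 4.7, 3.1, 0.5, 2.9],
                  [10.0, 92.0, 30.0, -50.0, 40.0]], dtype=np.float32)
report = diagnose_batch(batch, np.array([2, 3]))
print(report["min_prob"], report["bad_rows"], report["target_gap"])
```

Rows listed in `bad_rows` are the ones where the unfused loss is non-finite. The `safe` column gives the value the loss should have had.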

Lesson

log(softmax(x)) and log_softmax(x) are algebraically identical. They are not numerically identical: the former goes through a value exp(x_i - max) that can underflow; the latter goes through x_i - logsumexp(x), where both terms are finite and the subtraction is in log-space.
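The identity behind the fused form can be written out as follows, with m = max_j x_j and K the number of classes (both symbols are introduced here for illustration):

```latex
\begin{aligned}
\log \operatorname{softmax}(x)_i
  &= \log \frac{e^{x_i - m}}{\sum_{j=1}^{K} e^{x_j - m}} \\
  &= (x_i - m) - \log \sum_{j=1}^{K} e^{x_j - m},
  \qquad 1 \le \sum_{j=1}^{K} e^{x_j - m} \le K .
\end{aligned}
```

The sum always contains the term e^0 = 1, so its log lies between 0 and log K and never sees a zero argument. The unfused path instead takes the log of the single ratio in the first line, which can round to 0 before the log is applied.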

Phase 7's autograd and Phase 9's MLP cross-entropy will both reuse this. Always prefer the fused form. If you ever see -log(p) in a notebook, treat it as a code smell unless p is provably bounded away from 0.

References

  • The PyTorch torch.nn.functional.log_softmax docs explicitly recommend it over log(softmax(...)) for this reason.
  • Bishop, Pattern Recognition and Machine Learning, §4.3.4 — derivation of stable softmax-cross-entropy.