04 — Log-sum-exp and numerical stability

The math is beautiful, but exponentials overflow. This is the final piece: how to compute softmax and cross-entropy without the numerical details leaving you with NaN.

The problem

The softmax of logits \(z \in \mathbb{R}^V\) is:

\[\sigma(z)_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}.\]

Two failure modes in float32:

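  1. Overflow. \(\exp(z_i)\) overflows to \(+\infty\) as soon as \(z_i\) exceeds roughly 88.7 (the log of the largest float32, about \(3.4 \times 10^{38}\)).
  2. Underflow. If \(z_i\) falls below roughly \(-104\), \(\exp(z_i)\) underflows to 0 (precision already degrades in the subnormal range below about \(-87\)); the corresponding \(p_i\) is then exactly 0 even though its true value is tiny but positive, and \(\log p_i = -\infty\).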

Both happen routinely. Modern transformer logits can be in the hundreds (especially after temperature scaling) and the negative log-likelihood loss computed via naïve softmax + log will silently produce NaN or 0-gradient when this happens.
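Both failure modes are easy to reproduce. The snippet below is a minimal sketch, assuming NumPy and float32 inputs; `naive_softmax` is a hypothetical helper written only for this illustration.

```python
import numpy as np

def naive_softmax(z):
    """Textbook softmax with no shift; unsafe for large |z|."""
    e = np.exp(z)
    return e / e.sum()

# Silence the overflow / invalid-value warnings so the output stays readable.
with np.errstate(over="ignore", invalid="ignore"):
    big = np.exp(np.float32(89.0))     # just above the float32 overflow threshold
    tiny = np.exp(np.float32(-110.0))  # well below the smallest float32 subnormal
    p = naive_softmax(np.array([1000.0, 1001.0, 1002.0], dtype=np.float32))

print(big)   # inf
print(tiny)  # 0.0
print(p)     # every entry is nan, because inf / inf is undefined
```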

The log-sum-exp identity

For any constant \(c \in \mathbb{R}\):

\[\log \sum_j \exp(z_j) = c + \log \sum_j \exp(z_j - c).\]
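The identity follows in one line by factoring \(e^{c}\) out of the sum and using \(\log(ab) = \log a + \log b\):

\[\log \sum_j \exp(z_j) = \log\Big(e^{c} \sum_j \exp(z_j - c)\Big) = c + \log \sum_j \exp(z_j - c).\]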

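Choose \(c = \max_j z_j\). Then every shifted exponent \(z_j - c\) is at most 0, so the largest term inside the sum is \(\exp(0) = 1\) (no overflow) and the sum is at least 1 (its log is finite). The smallest term is \(\exp(z_{\min} - z_{\max})\), which can underflow to 0, but a term that small contributes negligibly to the sum anyway.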

Stable log-sum-exp:

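import numpy as np

def logsumexp(z):
    c = z.max()                             # shift so the largest exponent is 0
    return c + np.log(np.exp(z - c).sum())  # the sum is at least 1, so the log is finite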
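As a quick sanity check, the sketch below compares the unshifted computation, the shifted one, and NumPy's own `np.logaddexp` ufunc folded over the vector. The function is restated so the snippet runs on its own; the choice of `np.logaddexp.reduce` as a cross-check is an assumption of this example, not something the lab requires.

```python
import numpy as np

def logsumexp(z):
    c = z.max()
    return c + np.log(np.exp(z - c).sum())

z = np.array([1000.0, 1001.0, 1002.0])

with np.errstate(over="ignore"):
    naive = np.log(np.exp(z).sum())  # exp overflows even in float64, so this is inf

stable = logsumexp(z)                # 1002 + log(e^-2 + e^-1 + 1), about 1002.4076
reference = np.logaddexp.reduce(z)   # pairwise log(exp(a) + exp(b)) folded over z

print(naive, stable, reference)
```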

Stable log-softmax

\[\log \sigma(z)_i = z_i - \log \sum_j \exp(z_j) = (z_i - c) - \log \sum_j \exp(z_j - c).\]

def log_softmax(z):
    c = z.max()
    z_shifted = z - c
    return z_shifted - np.log(np.exp(z_shifted).sum())

Borja implements this in lab 02. Compare it to the naïve \(\log(\sigma(z))\): the two should agree on well-conditioned inputs and disagree on adversarial inputs like z = [1000, 1001, 1002], where the stable version returns the correct values and the naïve one returns NaN.
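The comparison can be sketched as follows. `naive_log_softmax` is a hypothetical name for the two-step version, and the "easy" input is an arbitrary choice for illustration.

```python
import numpy as np

def log_softmax(z):
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

def naive_log_softmax(z):
    e = np.exp(z)               # overflows for large logits
    return np.log(e / e.sum())  # inf / inf gives nan

easy = np.array([0.5, -1.0, 2.0])
hard = np.array([1000.0, 1001.0, 1002.0])

# Well-conditioned input: both versions agree to floating-point tolerance.
agree = np.allclose(log_softmax(easy), naive_log_softmax(easy))

# Adversarial input: only the shifted version survives.
with np.errstate(over="ignore", invalid="ignore"):
    naive_hard = naive_log_softmax(hard)
stable_hard = log_softmax(hard)

print(agree)        # True
print(naive_hard)   # all nan
print(stable_hard)  # approximately [-2.4076, -1.4076, -0.4076]
```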

Stable cross-entropy from logits

For a single example with true label \(y^*\):

\[\mathcal{L} = -\log \sigma(z)_{y^*} = -z_{y^*} + \log \sum_j \exp(z_j) = -z_{y^*} + \text{logsumexp}(z).\]

Crucially: do not compute softmax first and then take log of the result. Always go logits → log-softmax in one step. NumPy implementation:

def cross_entropy_logits(z, y_star):
    log_probs = log_softmax(z)
    return -log_probs[y_star]
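To see why the two-step route is dangerous even when the softmax itself is computed with the max shift, consider the sketch below. The float32 input and the helper `softmax` are assumptions made for this illustration.

```python
import numpy as np

def log_softmax(z):
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

def cross_entropy_logits(z, y_star):
    return -log_softmax(z)[y_star]

def softmax(z):
    e = np.exp(z - z.max())  # shifted, so no overflow
    return e / e.sum()

z = np.array([0.0, 200.0], dtype=np.float32)
y_star = 0  # the true class has the much smaller logit

with np.errstate(divide="ignore"):
    # exp(-200) underflows to 0 in float32, so the probability is exactly 0
    # and its log is -inf.
    two_step = -np.log(softmax(z)[y_star])

one_step = cross_entropy_logits(z, y_star)  # stays in log space: 200.0

print(two_step, one_step)
```

The true loss is \(200 + \log(1 + e^{-200})\), which rounds to 200; the one-step version recovers it while the two-step version returns inf.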

Vectorised over a batch of size \(B\):

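def cross_entropy_batch(Z, Y):  # Z: (B, V) logits, Y: (B,) integer labels
    B = Z.shape[0]                    # batch size, needed for the row index below
    c = Z.max(axis=1, keepdims=True)  # per-row max, shape (B, 1)
    Z_shifted = Z - c
    log_probs = Z_shifted - np.log(np.exp(Z_shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(B), Y].mean()  # pick the true-label log-prob per row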
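A useful check on the vectorised version is that it matches the mean of the per-example losses. In this sketch the functions are restated so it runs standalone, the batch size is read from `Z.shape[0]`, and the random logits (scaled up to make them large) are an arbitrary choice.

```python
import numpy as np

def cross_entropy_logits(z, y_star):
    z_shifted = z - z.max()
    return -(z_shifted - np.log(np.exp(z_shifted).sum()))[y_star]

def cross_entropy_batch(Z, Y):
    Z_shifted = Z - Z.max(axis=1, keepdims=True)
    log_probs = Z_shifted - np.log(np.exp(Z_shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(Z.shape[0]), Y].mean()

rng = np.random.default_rng(0)
Z = rng.normal(scale=50.0, size=(4, 7))  # B = 4 examples, V = 7 classes
Y = rng.integers(0, 7, size=4)           # one integer label per example

batch_loss = cross_entropy_batch(Z, Y)
loop_loss = np.mean([cross_entropy_logits(z, y) for z, y in zip(Z, Y)])

print(np.isclose(batch_loss, loop_loss))  # True
```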

The gradient (foreshadowing Phase 07)

We'll derive this in detail in Phase 07, but it's worth seeing now: the gradient of CE-from-logits with respect to the logits is:

\[\frac{\partial \mathcal{L}}{\partial z_i} = \sigma(z)_i - \mathbb{1}[i = y^*] = q_i - \delta_{i, y^*}.\]

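Here \(q = \sigma(z)\) is the predicted distribution. The expression is elegant and, provided \(\sigma(z)\) itself is computed with the max shift, numerically stable: no division by tiny probabilities, no logs of small numbers. It is also the reason most deep-learning frameworks fuse log_softmax + nll_loss into a single op (cross_entropy_loss), both for speed and for stability.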
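Ahead of the Phase 07 derivation, the formula can be checked numerically with central finite differences. This is a minimal sketch in float64; `analytic_grad`, `numeric_grad`, the step size, and the test point are all choices made for this example.

```python
import numpy as np

def log_softmax(z):
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

def loss(z, y_star):
    return -log_softmax(z)[y_star]

def analytic_grad(z, y_star):
    g = np.exp(log_softmax(z))  # softmax, computed via the stable log-softmax
    g[y_star] -= 1.0            # subtract the one-hot target
    return g

def numeric_grad(z, y_star, eps=1e-6):
    g = np.zeros_like(z)
    for i in range(z.size):
        step = np.zeros_like(z)
        step[i] = eps
        # Central difference in coordinate i.
        g[i] = (loss(z + step, y_star) - loss(z - step, y_star)) / (2 * eps)
    return g

z = np.array([2.0, -1.0, 0.5, 3.0])
y_star = 2

ga = analytic_grad(z, y_star)
gn = numeric_grad(z, y_star)
print(np.allclose(ga, gn, atol=1e-6))  # True
```

Because the softmax sums to 1 and the one-hot target sums to 1, the gradient entries sum to zero, which is a cheap extra assertion for a test suite.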

The label-smoothing variant

A common variant uses soft labels: instead of \(\mathbb{1}[i = y^*]\), target \(p_i = (1 - \alpha) \mathbb{1}[i = y^*] + \alpha / V\) for some smoothing \(\alpha \in [0, 1)\). The CE becomes:

\[\mathcal{L} = -(1 - \alpha) \log q_{y^*} - \frac{\alpha}{V} \sum_i \log q_i.\]

Borja doesn't have to implement label smoothing in Phase 05 (it lands in Phase 18 as an optional regularisation), but it's worth seeing the form.
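For reference, the smoothed loss can be computed straight from logits in the same way as the hard-label loss. This is a sketch only; `smoothed_cross_entropy` and its default `alpha=0.1` are hypothetical choices, not the Phase 18 implementation.

```python
import numpy as np

def log_softmax(z):
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

def smoothed_cross_entropy(z, y_star, alpha=0.1):
    log_q = log_softmax(z)
    V = z.shape[0]
    # -(1 - alpha) * log q_{y*}  -  (alpha / V) * sum_i log q_i
    return -(1.0 - alpha) * log_q[y_star] - (alpha / V) * log_q.sum()

z = np.array([2.0, -1.0, 0.5, 3.0])
y_star = 3

# Build the soft target explicitly and compare against -sum_i p_i log q_i.
p = np.full(z.shape[0], 0.1 / z.shape[0])
p[y_star] += 1.0 - 0.1
explicit = -(p * log_softmax(z)).sum()

smoothed = smoothed_cross_entropy(z, y_star, alpha=0.1)
hard = smoothed_cross_entropy(z, y_star, alpha=0.0)  # reduces to plain cross-entropy

print(np.isclose(smoothed, explicit))  # True
```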

Numerical pitfalls catalogue

| Bug | Symptom | Cure |
|---|---|---|
| Naïve softmax with large logits | inf or NaN in probabilities | Stable log-softmax (above) |
| Cross-entropy computed as -(p * log(q)).sum() with q containing zeros | inf loss | Compute from logits, not probs |
| Forgetting to mask padded positions in sequence CE | Loss includes padding tokens | Multiply per-token loss by attention mask before averaging |
| Forgetting to .detach() the target when target has gradients (rare but real) | Wrong gradient | Use one-hot or index labels, not soft labels with gradients |
| Computing per-batch loss as sum instead of mean | Loss scales with batch size; LR-coupled | Average by token-count, not by sequence-count |
| Using FP16 throughout cross-entropy | Underflow in log-softmax | Promote to FP32 for the log-softmax + loss; keep FP16 for the matmul |

The last item is a classic — mixed precision training (Phase 18) requires this exact upcasting pattern.
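Three of the cures in the catalogue (upcasting, masking, and averaging by token count) can be combined in one function. The sketch below is an illustration in NumPy, assuming logits of shape (B, T, V) and a 0/1 padding mask; `masked_token_cross_entropy` is a hypothetical name, and it assumes at least one unmasked token.

```python
import numpy as np

def masked_token_cross_entropy(Z, Y, mask):
    """Z: (B, T, V) logits in any float dtype; Y: (B, T) integer labels;
    mask: (B, T) with 1 for real tokens and 0 for padding."""
    Z32 = Z.astype(np.float32)  # upcast before the log-softmax
    Z_shifted = Z32 - Z32.max(axis=-1, keepdims=True)
    log_probs = Z_shifted - np.log(np.exp(Z_shifted).sum(axis=-1, keepdims=True))
    # Gather the log-probability of the true label at every position.
    token_loss = -np.take_along_axis(log_probs, Y[..., None], axis=-1)[..., 0]
    mask = mask.astype(np.float32)
    # Average over real tokens only, not over sequences or padded slots.
    return (token_loss * mask).sum() / mask.sum()

rng = np.random.default_rng(0)
Z = rng.normal(size=(2, 3, 5)).astype(np.float16)  # pretend the matmul ran in FP16
Y = np.array([[1, 4, 0], [2, 0, 0]])
mask = np.array([[1, 1, 1], [1, 0, 0]])            # second sequence has 2 padded slots

loss = masked_token_cross_entropy(Z, Y, mask)

# Changing the logits at padded positions must leave the loss unchanged.
Z_perturbed = Z.copy()
Z_perturbed[1, 1:, :] = 5.0
loss_perturbed = masked_token_cross_entropy(Z_perturbed, Y, mask)

print(loss, loss_perturbed)
```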

Test discipline

Borja's lab 02 must include:

  1. Property test: log_softmax(z) ≈ log_softmax(z + c) for any scalar \(c\) (shift invariance).
  2. Reference test: compare against scipy.special.log_softmax for a battery of inputs.
  3. Stress test: inputs like z = [1000, 1001, 1002] and z = [-1000, -999, -998] produce finite, sane outputs.

These tests should land in tests/test_logsoftmax.py and stay green for the rest of the project. Phase 19 (training-dynamics debugging) will reference this test discipline.
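One possible shape for those three tests is sketched below. The function names, tolerances, and random inputs are assumptions; `log_softmax` is restated inline here, whereas the real test file would import it from the lab module. The SciPy check is skipped if SciPy is not installed.

```python
import numpy as np

def log_softmax(z):
    z_shifted = z - z.max()
    return z_shifted - np.log(np.exp(z_shifted).sum())

def test_shift_invariance():
    rng = np.random.default_rng(0)
    for _ in range(100):
        z = rng.normal(scale=10.0, size=16)
        c = rng.normal(scale=100.0)
        assert np.allclose(log_softmax(z), log_softmax(z + c), atol=1e-9)

def test_matches_scipy_reference():
    try:
        from scipy.special import log_softmax as scipy_log_softmax
    except ImportError:
        return  # SciPy unavailable; nothing to compare against
    rng = np.random.default_rng(1)
    for scale in (0.1, 1.0, 10.0, 1000.0):
        z = rng.normal(scale=scale, size=32)
        assert np.allclose(log_softmax(z), scipy_log_softmax(z))

def test_extreme_inputs_stay_finite():
    for z in (np.array([1000.0, 1001.0, 1002.0]),
              np.array([-1000.0, -999.0, -998.0])):
        out = log_softmax(z)
        assert np.all(np.isfinite(out))
        assert np.isclose(np.exp(out).sum(), 1.0)  # still a valid distribution
```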

What this file does NOT cover

  • Numerical analysis of higher-order derivatives (rarely needed).
  • Tropical / softmax-related semiring algebra (theoretically beautiful, practically not load-bearing).
  • Mixed-precision implementation details (handled in Phase 18).

Next: ../lab/00-entropy-by-hand.md