Skip to content

English · Español

Break — Calcular entropía cruzada vía log(softmax(x)) en lugar de log_softmax

🇪🇸 Una de las trampas más antiguas: tras un softmax estable la probabilidad de una clase improbable puede ser 0 (underflow). Tomar log de 0 da -inf, y todo el batch produce nan en la loss.

Objetivo: implementación de entropía cruzada para el clasificador §A13 de 5 tiempos.

Hipótesis

El aprendiz predice: "Calcular la entropía cruzada como -log(softmax(x)[target]) en lugar de la forma fusionada log_softmax(x)[target] funcionará silenciosamente con logits pequeños, y luego producirá -inf / nan cada vez que algún target tenga una probabilidad softmax que subfluya a 0 en fp32 (es decir, su logit esté más de ~88 por debajo del máximo)."

El break

En tu implementación de entropía cruzada:

 def cross_entropy(logits: np.ndarray, target: int) -> float:
-    log_probs = logits - logsumexp(logits)
-    return -log_probs[target]
+    probs = softmax(logits)
+    return -np.log(probs[target])

Procedimiento de ejecución

Dos casos de prueba, uno en la zona segura, otro al borde del precipicio:

uv run python -c "
import numpy as np
from scipy.special import logsumexp

def softmax(x):
    m = x.max()
    e = np.exp(x - m)
    return e / e.sum()

def ce_unsafe(logits, target):
    p = softmax(logits)
    return -np.log(p[target])

def ce_safe(logits, target):
    return -(logits[target] - logsumexp(logits))

# Case 1: small logits, target = past simple (index 2)
small = np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32)
print(f'small  unsafe={ce_unsafe(small, 2):.6f}  safe={ce_safe(small, 2):.6f}')

# Case 2: huge spread, target = the unlikely class (index 3)
huge = np.array([10.0, 92.0, 30.0, -50.0, 40.0], dtype=np.float32)
print(f'huge   unsafe={ce_unsafe(huge, 3):.6f}  safe={ce_safe(huge, 3):.6f}')
"

Modo de fallo esperado

small  unsafe=2.0871   safe=2.0871           <-- match in safe zone
huge   unsafe=inf      safe=142.000000       <-- underflow disaster

La probabilidad de la clase target en huge es exp(-50 - 92) / Σ ≈ exp(-142) ≈ 1.6e-62, que subfluye a exactamente 0 en fp32. log(0) = -inf. La entropía cruzada se va a +inf. La retropropagación (backpropagation) obtiene gradientes nan. El entrenamiento colapsa en el paso N.

Diagnóstico

Sólo a partir de los logs de entrenamiento:

  1. loss == inf o loss == nan en el paso N es el síntoma sonoro.
  2. El índice de clase del fallo es el del logit relativo más negativo. Imprime argmin(logits) del ejemplo fallido.
  3. Imprime la probabilidad softmax más pequeña por batch. Si es exactamente 0.0 en fp32, estás a un underflow de una entropía cruzada infinita.
  4. Comprueba cruzando contra scipy.special.log_softmax o torch.nn.functional.log_softmax. Ambas están fusionadas y son estables.

Lección

log(softmax(x)) y log_softmax(x) son algebraicamente idénticos. No son numéricamente idénticos: el primero pasa por un valor exp(x_i - max) que puede subfluir; el segundo pasa por x_i - logsumexp(x), donde ambos términos son finitos y la resta es en log-space.

El autograd de la Fase 7 y la entropía cruzada del MLP de la Fase 9 reutilizarán esto. Prefiere siempre la forma fusionada. Si alguna vez ves -log(p) en un notebook, trátalo como un code smell salvo que p esté demostrablemente acotado lejos de 0.

Referencias

  • Los docs de torch.nn.functional.log_softmax de PyTorch lo recomiendan explícitamente sobre log(softmax(...)) por esta razón.
  • Bishop, Pattern Recognition and Machine Learning, §4.3.4 — derivación del softmax-entropía cruzada estable.