English · Español
Break — Calcular entropía cruzada vía log(softmax(x)) en lugar de log_softmax¶
🇪🇸 Una de las trampas más antiguas: tras un softmax estable la probabilidad de una clase improbable puede ser 0 (underflow). Tomar
logde 0 da-inf, y todo el batch producenanen la loss.
Objetivo: implementación de entropía cruzada para el clasificador §A13 de 5 tiempos.
Hipótesis¶
El aprendiz predice: "Calcular la entropía cruzada como -log(softmax(x)[target]) en lugar de la forma fusionada log_softmax(x)[target] funcionará silenciosamente con logits pequeños, y luego producirá -inf / nan cada vez que algún target tenga una probabilidad softmax que subfluya a 0 en fp32 (es decir, su logit esté más de ~88 por debajo del máximo)."
El break¶
En tu implementación de entropía cruzada:
def cross_entropy(logits: np.ndarray, target: int) -> float:
- log_probs = logits - logsumexp(logits)
- return -log_probs[target]
+ probs = softmax(logits)
+ return -np.log(probs[target])
Procedimiento de ejecución¶
Dos casos de prueba, uno en la zona segura, otro al borde del precipicio:
uv run python -c "
import numpy as np
from scipy.special import logsumexp
def softmax(x):
m = x.max()
e = np.exp(x - m)
return e / e.sum()
def ce_unsafe(logits, target):
p = softmax(logits)
return -np.log(p[target])
def ce_safe(logits, target):
return -(logits[target] - logsumexp(logits))
# Case 1: small logits, target = past simple (index 2)
small = np.array([1.2, 4.7, 3.1, 0.5, 2.9], dtype=np.float32)
print(f'small unsafe={ce_unsafe(small, 2):.6f} safe={ce_safe(small, 2):.6f}')
# Case 2: huge spread, target = the unlikely class (index 3)
huge = np.array([10.0, 92.0, 30.0, -50.0, 40.0], dtype=np.float32)
print(f'huge unsafe={ce_unsafe(huge, 3):.6f} safe={ce_safe(huge, 3):.6f}')
"
Modo de fallo esperado¶
small unsafe=2.0871 safe=2.0871 <-- match in safe zone
huge unsafe=inf safe=142.000000 <-- underflow disaster
La probabilidad de la clase target en huge es exp(-50 - 92) / Σ ≈ exp(-142) ≈ 1.6e-62, que subfluye a exactamente 0 en fp32. log(0) = -inf. La entropía cruzada se va a +inf. La retropropagación (backpropagation) obtiene gradientes nan. El entrenamiento colapsa en el paso N.
Diagnóstico¶
Sólo a partir de los logs de entrenamiento:
loss == infoloss == nanen el paso N es el síntoma sonoro.- El índice de clase del fallo es el del logit relativo más negativo. Imprime
argmin(logits)del ejemplo fallido. - Imprime la probabilidad softmax más pequeña por batch. Si es exactamente
0.0en fp32, estás a un underflow de una entropía cruzada infinita. - Comprueba cruzando contra
scipy.special.log_softmaxotorch.nn.functional.log_softmax. Ambas están fusionadas y son estables.
Lección¶
log(softmax(x)) y log_softmax(x) son algebraicamente idénticos. No son numéricamente idénticos: el primero pasa por un valor exp(x_i - max) que puede subfluir; el segundo pasa por x_i - logsumexp(x), donde ambos términos son finitos y la resta es en log-space.
El autograd de la Fase 7 y la entropía cruzada del MLP de la Fase 9 reutilizarán esto. Prefiere siempre la forma fusionada. Si alguna vez ves -log(p) en un notebook, trátalo como un code smell salvo que p esté demostrablemente acotado lejos de 0.
Referencias¶
- Los docs de
torch.nn.functional.log_softmaxde PyTorch lo recomiendan explícitamente sobrelog(softmax(...))por esta razón. - Bishop, Pattern Recognition and Machine Learning, §4.3.4 — derivación del softmax-entropía cruzada estable.