English · Español
Lab 00 — Softmax online en Python puro¶
Objetivo: implementar la recurrencia del softmax online y verificar que coincide con el softmax por lotes sobre datos sintéticos, tanto en fp32 como en fp16.
Tiempo estimado: 2–3 horas.
Prerrequisito: theory 01 leída y la recurrencia rederivable de memoria.
Lo que produces¶
Un directorio experiments/27-online-softmax/ que contenga:
online_softmax.py— implementación en Python puro (NumPy). Dos funciones:softmax_batched(formulación clásica) ysoftmax_online_chunked(chunks, V_chunks, m_init=-inf, ℓ_init=0, O_init=0).test_equivalence.py— verifica que las dos producen salidas idénticas sobre entradas aleatorias.results.json— medidas del max-abs-error entre {fp32, fp16, bf16} y entre tamaños de chunk {1, 4, 16, 64}.manifest.json.README.md— interpretación.
No hay entregable en src/ para este lab; la implementación es código pedagógico deliberadamente desechable.
TODOs¶
Bloque A — implementar softmax por lotes como referencia¶
-
softmax_batched(s: ndarray, V: ndarray) -> ndarray: devuelvesoftmax(s) @ V(s1D,V2D de forma(N, d)). Usa la forma numéricamente estable.
Bloque B — implementar la versión online¶
- Firma:
softmax_online(chunks: list[ndarray], V_chunks: list[ndarray]) -> ndarray. - Bucle sobre chunks; mantén
m, ℓ, Osegún la recurrencia de theory 01. - Devuelve
O / ℓfinal.
Bloque C — verificar la equivalencia¶
Para cada (N, d, dtype, chunk_size) en una rejilla pequeña:
- Genera
s ∈ ℝ^N,V ∈ ℝ^{N × d}aleatorios. - Calcula las versiones por lotes y online.
- Registra
max_abs_error = max|O_batched - O_online|.
Rangos esperados:
- fp32: < 1e-6 (sólo redondeo).
- fp16: < 1e-3.
- bf16: < 5e-3.
Bloque D — sensibilidad al tamaño de chunk¶
Para fp16:
- Barre chunk_size ∈ {1, 4, 16, 64, 256}. Plotea max_abs_error vs chunk_size.
- Esperado: el error es aproximadamente constante entre tamaños de chunk (la recurrencia es exacta algebraicamente; sólo importa el redondeo, y el redondeo está acotado por
O(N)independientemente del tamaño de chunk). - Si el error varía bruscamente con chunk_size, tu reescalado
αtiene un bug.
Bloque E — entradas patológicas¶
Prueba la recurrencia sobre:
-
s = [60, 0, 0, 0](un valor enorme seguido de otros minúsculos). En fp16,exp(60)desborda — pero el softmax online debería manejarlo gracias a la sustracción del máximo móvil. -
s = [-100, -100, -100, 0](uno casi-cero, el resto muy negativos). Verifica quesoftmax_onlinedéO ≈ V[3]. -
s = todo ceros. Atención uniforme; la salida esmean(V).
Bloque E' — entradas realistas del corpus de verbos¶
El vocabulario del corpus de verbos es pequeño (~600 formas), así que los logits de atención sobre una secuencia de 64 tokens tienen una distribución muy picuda tras unas pocas épocas de entrenamiento (el modelo está seguro de cada verbo en contexto).
- Genera
smuestreando deN(0, 5)(una distribución picuda-pero-no-patológica que imita los logits de atención post-entrenamiento sobre un vocab pequeño). - Establece
Vcon forma(64, 64)(N = 64coincide con la longitud de secuencia del corpus de verbos;d = 64es una dimensión de cabeza típica para MiniGPT). - Verifica que
softmax_onlinecon chunk_size=16 coincide con el softmax por lotes dentro de la tolerancia fp16. - Comenta en
README.md: ¿cómo ayuda o perjudica la distribución picuda a la estabilidad de la recurrencia online?
Bloque F — interpretar en README.md¶
Tres preguntas:
- ¿Cuál es el peor error fp16 que observaste? ¿Está bajo
1e-3? Si no, ¿de dónde vino el error adicional? - ¿El tamaño de chunk afecta a la precisión? Si observaste sensibilidad, ¿por qué?
- ¿Qué ocurre en fp16 cuando
scontiene un valor > 11? (exp(11) ≈ half_max.) ¿Maneja la recurrencia online el caso, o necesitas acumuladores fp32 param, ℓ, O?
Condiciones de parada¶
- Los cinco archivos commiteados.
- max_abs_error fp32 <
1e-6; fp16 <1e-3. - README responde las tres preguntas del Bloque F.
Errores típicos¶
- Off-by-one en el bucle de chunks. Si los chunks no teselan perfectamente
s, el manejo del último chunk puede descartar entradas. Usa una lista de chunks en Python en lugar de aritmética de índices para mayor claridad. - Reescalado
αaplicado a lo que no toca. TantoℓcomoOnecesitanα. Si olvidas uno, los errores se acumulan cuadráticamente con N. exp(m - m')subdesborda. Cuandom_new ≫ m_old,αpuede subdesbordar. Esto es matemáticamente fino (α → 0sólo significa que las contribuciones antiguas quedan eclipsadas por las nuevas), pero si calculasℓ * αen fp16 yℓera pequeño, puedes perderlo del todo. Usa acumuladores fp32 para el estado móvil en implementaciones fp16.
Cuándo consultar solutions/¶
Tras cumplir todas las condiciones de parada. solutions/00-online-softmax-ref.md (apertura de fase) compara estructura y números.
Siguiente lab: lab/01-flash-bytes.md.