English · Español
01 — La recurrencia del online softmax¶
🇪🇸 La clave matemática para Flash: poder calcular
softmax(s) @ Vpor trozos sin haber visto todosantes. Mantienes el máximo corriente y la suma corriente, y al añadir un trozo nuevo, reescalas lo anterior porexp(m_viejo - m_nuevo). Una línea de álgebra, todo el resto del fenómeno depende de ella.
El softmax clásico¶
Para un vector s ∈ ℝ^N, el softmax numéricamente estable es:
La resta de m mantiene los argumentos de exp ≤ 0, así que no hay overflow. Esta es la formulación numéricamente estable; la versión inestable (simplemente exp(s) / sum(exp(s))) hace overflow para cualquier s_i > log(float_max) ≈ 88 (fp32) o > log(half_max) ≈ 11 (fp16).
Para la attention con Q, K de dimensión de cabeza d, un único valor pre-softmax s_i = (Q[i, :] · K[j, :]) / √d puede fácilmente superar 11 en fp16. El softmax estable (es decir, restar el máximo) es obligatorio, no opcional, en attention fp16.
El problema: la formulación estándar requiere el vector completo s para calcular m = max(s). Eso impide el cómputo en streaming.
El setup: attention en streaming¶
En FlashAttention, no tenemos la fila completa de s = Q[i, :] @ K^T ∈ ℝ^N de golpe. La tenemos tile-a-tile: trozos s_1, s_2, ..., s_{N/B_c} cada uno de longitud B_c. Para cada trozo s_k, queremos actualizar un O[i, :] ∈ ℝ^d corriente (la salida parcial de attention) tal que, tras consumir todos los trozos, O[i, :] = softmax(s) @ V.
La pregunta: ¿podemos actualizar O correctamente usando sólo el trozo actual y una pequeña cantidad de estado corriente?
Sí. Así.
La recurrencia¶
Mantén tres piezas de estado corriente para la fila de salida i:
m ∈ ℝ— máximo corriente desvisto hasta ahora.ℓ ∈ ℝ— denominador corrientesum(exp(s_seen - m)).O ∈ ℝ^d— salida no normalizada corrientesum_j exp(s_j - m) · V_j.
Cuando llega el trozo s_new ∈ ℝ^{B_c} (con el correspondiente V_new ∈ ℝ^{B_c × d}):
-
Nuevo máximo: $$ m' = \max(m, \max(s_{\text{new}})) $$
-
Reescala el estado viejo al nuevo máximo: $$ \alpha = \exp(m - m') $$ El
ℓyOviejos se calcularon relativos almviejo. Para ponerlos al mismo nivel que el trozo nuevo (que calcularemos relativo am'), multiplica ambos porα: $$ ℓ \leftarrow \alpha \cdot ℓ \qquad O \leftarrow \alpha \cdot O $$ -
Añade la contribución del nuevo trozo: $$ p_{\text{new}} = \exp(s_{\text{new}} - m') \in \mathbb{R}^{B_c} $$ $$ ℓ \leftarrow ℓ + \sum p_{\text{new}} $$ $$ O \leftarrow O + p_{\text{new}} \cdot V_{\text{new}} \quad \text{(producto matriz-vector sobre } B_c \text{ términos)} $$
-
Actualiza
m: $$ m \leftarrow m' $$
Al final, divide una vez:
Esa es la recurrencia entera. Seis líneas de pseudocódigo; un cómputo corriente O(N·d); matemáticamente idéntico al softmax todo-a-la-vez salvo redondeo fp.
Demostración de corrección¶
Afirmación: tras procesar todos los trozos, O / ℓ = softmax(s) @ V.
Sea s = [s_1, ..., s_K] los trozos concatenados (cada uno de longitud B_c). Sea m_global = max(s). Por construcción, tras todos los trozos, m = m_global (el máximo corriente acumula correctamente).
Tras todos los trozos:
- ℓ = sum_{j=1..N} exp(s_j - m_global).
- O = sum_{j=1..N} exp(s_j - m_global) · V_j.
El resultado todo-a-la-vez es O_all = (sum_j exp(s_j - m_global) · V_j) / sum_j exp(s_j - m_global) = O / ℓ. Mismo valor. □
El único paso sutil es la reescala. Supón que hemos procesado los trozos 1..k y estamos a punto de procesar el trozo k+1. Justo antes de procesar k+1, nuestro estado es:
m = max(s_1, ..., s_k). Llámalom_k.ℓ = sum_{j ∈ primeros k trozos} exp(s_j - m_k).O = sum_{j ∈ primeros k trozos} exp(s_j - m_k) · V_j.
Tras observar el trozo k+1, el máximo verdadero es m_{k+1} = max(m_k, max(s_{k+1})). Para rebasar ℓ al nuevo máximo:
ℓ_rebased = sum_{j ∈ primeros k trozos} exp(s_j - m_{k+1})
= sum_{j ∈ primeros k trozos} exp(s_j - m_k + m_k - m_{k+1})
= exp(m_k - m_{k+1}) × sum_{j ∈ primeros k trozos} exp(s_j - m_k)
= α × ℓ
donde α = exp(m_k - m_{k+1}) ≤ 1. Misma álgebra para O. Luego añade las contribuciones del trozo k+1 (que ya están calculadas relativas a m_{k+1}). □
Propiedades numéricas¶
Tres observaciones:
- Sin overflow. Cada argumento de
expes ≤ 0 por construcción (siempre restamos el máximo actual o nuevo antes de exponenciar). Seguro en fp16. - Sin cancelación catastrófica en la reescala.
α = exp(m_k - m_{k+1}) ≤ 1está acotado; multiplicarℓyOpor él sólo puede hacer underflow a cero simsaltó más de~88(fp32) o~11(fp16). Para valores de attention, esto es raro pero posible — la discusión de estabilidad del paper de Flash lo maneja vía acumuladores fp32 param, ℓ, O. - El orden de los trozos no importa. El resultado final es invariante al orden de los trozos (módulo redondeo). Esto es lo que hace que Flash funcione dentro de un kernel tileado: los tiles pueden procesarse en cualquier orden que el scheduler prefiera.
Un ejemplo trabajado¶
Vector s = [1, 2, 3, 10], V = [[1], [1], [1], [1]] (así cada V de trozo es un vector 1×1; la respuesta debería converger a softmax(s) @ V = 1.0 puesto que todas las entradas de V son 1).
Procesar en dos trozos: s_1 = [1, 2], s_2 = [3, 10].
Tras el trozo 1:
- m = 2
- p_1 = exp([1-2, 2-2]) = [exp(-1), 1] ≈ [0.368, 1]
- ℓ = 1.368
- O = 0.368 + 1 = 1.368
Tras el trozo 2:
- max(s_2) = 10, así que m' = max(2, 10) = 10.
- α = exp(2 - 10) = exp(-8) ≈ 3.35e-4
- Reescala: ℓ = 0.000458, O = 0.000458.
- p_2 = exp([3-10, 10-10]) = [exp(-7), 1] ≈ [9.12e-4, 1]
- ℓ = 0.000458 + 0.000912 + 1 = 1.00137
- O = 0.000458 + 0.000912 + 1 = 1.00137
Final: O / ℓ = 1.00137 / 1.00137 = 1.0. ✓
Compara con todo-a-la-vez: softmax([1, 2, 3, 10]) ≈ [1.23e-4, 3.34e-4, 9.08e-4, 0.9985]. Producto punto con [1, 1, 1, 1] = 1.0. Misma respuesta.
Lo que esto permite¶
Una vez tenemos la recurrencia del online softmax, el resto de Flash es "sólo" tiling. Procesamos la matriz de attention S = QK^T un tile a la vez, sin materializar nunca la cosa entera en HBM. Para cada fila de salida de O, mantenemos una cantidad minúscula de estado (m, ℓ, O) y la actualizamos conforme los tiles de K, V van llegando.
La recurrencia es la razón matemática por la que Flash es correcto. El tiling (siguiente archivo de teoría) es la razón algorítmica por la que es rápido.
Problemas de práctica¶
Soluciones al abrir la fase en solutions/01-online-softmax-ref.md. Inténtalo sin código.
- Calcula el online softmax sobre
s = [0, 5]conV = [[2], [3]]troceado comos_1 = [0],s_2 = [5]. Muestra los cuatro pasos. - La recurrencia requiere
α = exp(m_k - m_{k+1}). ¿Bajo qué condicionesαhace underflow a 0 en fp16? ¿Qué falla enℓsi ocurre? (Pista: nada falla — la matemática sigue siendo correcta en el límite.) - Muestra que procesar los trozos en orden inverso produce el mismo
O/ℓfinal (módulo redondeo fp). Esboza el argumento; no simules. - Supón que paralelizas la recurrencia entre
Philos, cada uno manejando1/Pde los trozos, luego reduces al final. ¿Cuál es la operación de reducción? Muestra que es asociativa.
Recap de un párrafo¶
La recurrencia del online softmax mantiene un máximo corriente m, un denominador corriente ℓ, y una salida no normalizada corriente O mientras procesa trozos de s y V uno a uno. Cuando llega un trozo nuevo, reescala los ℓ, O viejos por α = exp(m_old - m_new) para alinearlos con el nuevo máximo, luego añade las contribuciones del trozo nuevo. La respuesta final es O / ℓ. El álgebra es una línea, el cómputo es O(N·d) sin materializar nunca el softmax completo, y el resultado es exacto salvo redondeo de coma flotante. Esta es la clave matemática que hace posible FlashAttention — sin ella, no podrías procesar attention en tiles.
Siguiente: theory/02-flash-attention.md.