English · Español
02 — De ingenuo a tiled: la ruta de optimización de un kernel¶
🇪🇸 Esta página rastrea la trayectoria de optimización de un único kernel — la softmax fusionada sobre el vocabulario gramatical de ~600 formas — desde el primer borrador ingenuo hasta una versión que alcanza ≥30% del peak de cuBLAS /
F.softmax. Cada paso del camino sube el dot en el roofline de Fase 23 y deja una huella mensurable. La lección no es la softmax en particular; es la secuencia de movimientos que sirve para cualquier kernel.
Esta página es el manual de optimización para un kernel. El kernel por el que pasamos es la softmax fusionada sobre la fila de logits del MiniGPT gramatical — shape (B, V) con \(V \approx 600\) a partir del vocabulario §A13 de ~600 formas conjugadas. Borja escribirá cada una de las cuatro versiones en los labs 01–02 y colocará los cuatro puntos en el roofline en el lab 03.
Los números exactos son ilustrativos; lo crítico es la transición — qué compra cada movimiento y qué cuesta en complejidad de código.
El operador¶
Softmax row-wise con el truco de estabilidad numérica:
Tres pasadas lógicas sobre la fila:
- pasada de max — leer \(x_0 \dots x_{V-1}\), encontrar \(m\).
- pasada de sum-exp — leer de nuevo, calcular \(s = \sum \exp(x_k - m)\).
- pasada de normalización — leer una tercera vez, escribir \(y_i = \exp(x_i - m) / s\).
En NumPy esto es una sola línea:
El compilador puede o no fusionar las tres pasadas. En GPU, nosotros las fusionamos — explícitamente — y esa es la optimización.
Aritmética del working set¶
Por fila (fp32, \(V = 600\)):
- Bytes leídos: \(V \cdot 4 = 2400\) B
- Bytes escritos: \(V \cdot 4 = 2400\) B
- FLOPs: ~\(5V = 3000\) FLOPs (un max, \(V\) exps, una suma-recíproco, \(V\) multiplicaciones)
- Intensidad: \(5V / 4V = 1.25\) FLOPs/byte → memory-bound en cualquier GPU (\(I_\text{crit} \geq 4\) para fp64; \(\geq 156\) para Tensor Cores fp16).
El "mejor teórico" es la fracción del ancho de banda HBM que este kernel sostiene. Para \(V = 600\) fp32 y \(B\) en los cientos, ese es nuestro techo.
Versión 1: Ingenua (un hilo por elemento)¶
__global__ void softmax_naive(const float* x, float* y, int V) {
int row = blockIdx.x;
int col = threadIdx.x;
// Pass 1: max (every thread computes the same max — wasteful)
float m = -INFINITY;
for (int k = 0; k < V; ++k) m = fmaxf(m, x[row * V + k]);
// Pass 2: sum-exp
float s = 0.0f;
for (int k = 0; k < V; ++k) s += expf(x[row * V + k] - m);
// Pass 3: write
if (col < V) y[row * V + col] = expf(x[row * V + col] - m) / s;
}
Qué falla:
- Cada hilo relee la fila tres veces. 3× tráfico de memoria.
- Cada hilo calcula el max y la suma independientemente. \(V\)× cómputo redundante.
- Los hilos más allá de \(V\) en el bloque se quedan ociosos. Lanzamiento desperdiciado.
- Sin SMEM. La fila se lee desde HBM cada pasada.
Pero funciona. Ejecútalo, confirma corrección frente a np.softmax a 1e-5, luego optimiza. Nunca tunees un kernel incorrecto.
Esperado: quizás 1–3% del ancho de banda HBM. Mucho margen de mejora.
Versión 2: Coalesced + SMEM¶
Paso 1: cargar la fila en SMEM una vez, con un hilo por elemento, coalesced.
Paso 2: reducir en SMEM para max y sum (el patrón de reducción en árbol de theory/01).
Paso 3: cada hilo escribe su salida.
__global__ void softmax_smem(const float* x, float* y, int V) {
extern __shared__ float row[];
int r = blockIdx.x;
int t = threadIdx.x;
// 1. Coalesced load.
for (int k = t; k < V; k += blockDim.x) row[k] = x[r * V + k];
__syncthreads();
// 2. Reduce for max (tree reduction in SMEM).
// ... (max-reduce omitted; see lab 02)
__shared__ float m;
if (t == 0) {
float mm = -INFINITY;
for (int k = 0; k < V; ++k) mm = fmaxf(mm, row[k]);
m = mm;
}
__syncthreads();
// 3. Compute exp in place, reduce for sum.
for (int k = t; k < V; k += blockDim.x) row[k] = expf(row[k] - m);
__syncthreads();
__shared__ float s;
if (t == 0) {
float ss = 0.0f;
for (int k = 0; k < V; ++k) ss += row[k];
s = ss;
}
__syncthreads();
// 4. Normalize and write (coalesced).
for (int k = t; k < V; k += blockDim.x) y[r * V + k] = row[k] / s;
}
Qué cambió:
- Una lectura coalesced desde HBM a SMEM.
- Una escritura coalesced de vuelta a HBM.
- SMEM mantiene la fila; las pasadas subsiguientes pegan en SMEM a ~10× del ancho de banda HBM.
- Aún tiene reducciones seriales del thread-0 (los bloques
if (t == 0)) — fácil de arreglar a continuación.
Esperado: ~10–20% del pico de ancho de banda HBM. Un salto limpio.
Versión 3: Reducción paralela + online-softmax (una pasada)¶
Dos movimientos más:
- Reemplazar las reducciones seriales del thread-0 con reducciones en árbol a través del bloque (
__syncthreads()en un bucle que halva). Todos los hilos participan. - Fusionar max y sum en una sola pasada con la recurrencia de online-softmax: mantener un max corriente \(m\) y una suma-de-exps corriente \(s\), actualizar ambos cuando un nuevo elemento supere \(m\):
Este es el mismo truco que usa flash attention (Fase 27). Dos lecturas HBM colapsan en una.
Resultado: una sola pasada sobre la fila, después una pasada de normalizar-y-escribir — un viaje de ida y vuelta a HBM menos.
Esperado: ~30–50% del pico de ancho de banda HBM en \(V = 600\). Esta es la versión que alcanza el objetivo de ≥30% de cuBLAS.
Versión 4: Triton (lab 03)¶
El mismo algoritmo, en Triton. La fuente Python es ~30 líneas (comparado con ~80 para el CUDA C tuneado). El autotuner de Triton barre tamaños de bloque; tú especificas el algoritmo, el autotuner encuentra los parámetros.
Esperado en \(V = 600\): 80–95% de la versión CUDA C tuneada a mano. El 5–20% restante es el coste de la generalidad.
La escalera, resumida¶
| Versión | Pasadas HBM | Reducciones | % del pico | Líneas de código |
|---|---|---|---|---|
| Ingenuo | 3 lecturas × \(V\) hilos = \(3V\) efectivo | Ninguna (serial en cada hilo) | 1–3% | ~10 |
| Coalesced + SMEM | 1 lectura + 1 escritura | Serial en thread 0 | 10–20% | ~25 |
| Paralelo + online | 1 lectura + 1 escritura | Árbol en el bloque | 30–50% | ~50 |
| Triton (autotuneado) | Igual | Igual | 25–45% | ~30 (Python) |
La versión CUDA C tuneada alcanza ≥30% de F.softmax. Triton aterriza cerca por detrás. Ambas van al gráfico de roofline.
Qué hizo cada movimiento, en lenguaje de roofline¶
- Coalescing: subió el ancho de banda alcanzable (más bytes por transacción de memoria).
- Cacheo en SMEM: eliminó tráfico HBM redundante, subiendo la intensidad efectiva.
- Reducción paralela: eliminó stalls del scheduler (sin sección serial
if (tid == 0)). - Online softmax (fusionar pasadas): eliminó un viaje de ida y vuelta a HBM — directamente halvó los bytes movidos por fila.
Cada uno de estos es un concepto de la Fase 23 hecho concreto. Ese es el sentido de la fase.
Problemas de drill¶
- Para \(V = 600\) fp32, \(B = 1024\), ¿cuál es el read+write HBM por fila en bytes? ¿Tráfico HBM total para el batch? A 1.55 TB/s (A100 SXM4 40GB), ¿cuánto tarda solo por los bytes?
- ¿Y si \(V = 8192\) (un transformer con vocabulario real)? ¿La estrategia SMEM sigue cabiendo (la SMEM por bloque es ~100 KB máx)? Si no, ¿qué cambia?
- ¿Por qué online-softmax ayuda más a fp16 que a fp32? (Pista: el ancho de banda HBM es por byte; halvar bytes → halvar tiempo. El pico de cómputo fp16 ya iba sobrado.)
- ¿Dónde ayudarían los Tensor Cores en este kernel? (Pista: no ayudarían — softmax es elementwise + reducciones, no matmul. Los Tensor Cores ayudarían si fusionáramos softmax con el GEMM del LM-head — eso es tema de la Fase 27.)
Nota sobre el fallback de CPU (para desarrollo local)¶
La máquina local de Borja no tiene CUDA. La ruta de fallback de CPU es justo el código NumPy mostrado en §1; el dispatcher en src/minikernel/dispatch.py (lab 02) decide entre el kernel CUDA C, el kernel Triton y np.exp(x - x.max(...)). Los tests de equivalencia numérica corren localmente contra la referencia NumPy; los tests de rendimiento corren solo en la GPU en la nube.
Esto mantiene la iteración barata — los bugs de nivel algorítmico afloran en CPU en milisegundos. La GPU en la nube solo se enciende cuando el algoritmo es correcto.
Lo que ahora deberías ser capaz de hacer¶
- Esbozar las cuatro versiones del kernel de softmax de memoria; explicar qué cuesta y qué compra cada movimiento.
- Predecir aproximadamente dónde aterriza cada punto en el roofline.
- Aplicar la misma escalera a un operador nuevo (p. ej., layernorm, RMSNorm) — los movimientos son reutilizables.
- Decidir si el cuello de botella de un kernel es el ancho de banda HBM, conflictos de bancos de SMEM, presión de registros o stalls del scheduler.
Lo que esta página NO cubre¶
- Detalles del tiling de GEMM (estilo cuBLAS). Un kernel GEMM tiene una escalera de optimización diferente (tile en registro + tile en SMEM + Tensor Core). Ruta alternativa de kernel de la Fase 24; el kernel por defecto aquí es softmax.
- Flash attention. Fase 27. El truco de online-softmax mencionado aquí es uno de los bloques de construcción de flash attention pero solo uno.
- Evitar conflictos de bancos. Brevemente relevante para reducciones en SMEM; el lab 02 lo demuestra con un perfil.
- Espacios de autotuning. La superficie de autotune de Triton es
theory/03.
Siguiente: theory/03-triton.md — el lenguaje de kernels DSL en Python, cómo funciona su autotune, cuándo bate a CUDA C y cuándo no.