02 — From Naive to Tiled: the Optimization Path of One Kernel
This page traces the optimization path of a single kernel, the fused softmax over the grammar vocabulary of ~600 forms, from the first naive draft to a version that reaches ≥30% of the cuBLAS / F.softmax peak. Each step along the way raises the dot on the Phase 23 roofline and leaves a measurable trace. The lesson is not softmax in particular; it is the sequence of moves, which carries over to any kernel.
This page is the optimization manual for one kernel. The kernel we walk through is the fused softmax over the grammar MiniGPT's logit row: shape (B, V) with \(V \approx 600\) from §A13's vocabulary of ~600 conjugated forms. Borja will write each of the four versions in labs 01–02 and place the four dots on the roofline in lab 03.
The exact numbers are illustrative; what's load-bearing is the transition — what each move buys, and what it costs in code complexity.
The operator
Row-wise softmax with the numerical-stability trick:
\[ y_i = \frac{\exp(x_i - m)}{\sum_{k=0}^{V-1} \exp(x_k - m)}, \qquad m = \max_{0 \le k < V} x_k \]
Three logical passes over the row:
- max pass — read \(x_0 \dots x_{V-1}\), find \(m\).
- sum-exp pass — read again, compute \(s = \sum \exp(x_k - m)\).
- normalize pass — read a third time, write \(y_i = \exp(x_i - m) / s\).
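In NumPy the operator is essentially one line: `np.exp(x - x.max(axis=-1, keepdims=True))`, divided by its row sum. NumPy evaluates each operation eagerly as its own pass over memory, so nothing is fused there. An optimizing compiler may or may not fuse the passes. On the GPU we fuse the three passes explicitly, and that fusion is the optimization.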
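Below is a runnable version of that reference. It assumes `x` is a float32 array of shape (B, V). The name `softmax_ref` is ours, not the labs'. Each statement is a separate pass over memory, which is the behavior the fused GPU kernel avoids.

```python
import numpy as np

def softmax_ref(x: np.ndarray) -> np.ndarray:
    """Row-wise numerically stable softmax over the last axis."""
    m = x.max(axis=-1, keepdims=True)          # max pass
    e = np.exp(x - m)                          # exp(x_k - m), never overflows
    return e / e.sum(axis=-1, keepdims=True)   # sum pass, then normalize pass

x = np.random.default_rng(0).standard_normal((4, 600)).astype(np.float32)
y = softmax_ref(x)
print(y.shape, np.allclose(y.sum(axis=-1), 1.0, atol=1e-5))  # (4, 600) True
```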
Working-set arithmetic
Per row (fp32, \(V = 600\)):
- Bytes read: \(V \cdot 4 = 2400\) B
- Bytes written: \(V \cdot 4 = 2400\) B
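- FLOPs: ~\(5V = 3000\) (\(V\) max comparisons, \(V\) subtractions, \(V\) exps, \(V\) adds for the sum, \(V\) multiplies by \(1/s\), plus one reciprocal)
- Intensity: \(5V / (4V + 4V) = 0.625\) FLOPs/byte, counting both bytes read and bytes written → memory-bound on any GPU (\(I_\text{crit} \geq 4\) for fp64; \(\geq 156\) for fp16 Tensor Cores).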
The "theoretical best" is therefore set by HBM bandwidth, not compute. Each version is scored by the fraction of peak HBM bandwidth it sustains. For \(V = 600\) fp32 and \(B\) in the hundreds, that bandwidth roof is our ceiling.
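The same arithmetic fits in a small Python helper. It counts both the bytes read and the bytes written, because roofline intensity divides FLOPs by all DRAM traffic. `softmax_row_cost` is a hypothetical name. The 5-FLOPs-per-element count is the approximation used in the list above.

```python
def softmax_row_cost(V: int, bytes_per_elem: int = 4) -> dict:
    """Per-row traffic and FLOPs for a fused softmax (one read + one write)."""
    bytes_read = V * bytes_per_elem
    bytes_written = V * bytes_per_elem
    # V max-compares + V subtracts + V exps + V adds + V multiplies ~ 5V
    flops = 5 * V
    intensity = flops / (bytes_read + bytes_written)
    return {"read": bytes_read, "written": bytes_written,
            "flops": flops, "intensity": intensity}

cost = softmax_row_cost(600)
print(cost)
# {'read': 2400, 'written': 2400, 'flops': 3000, 'intensity': 0.625}
```

Halving the element size doubles the intensity. With fp16 (`bytes_per_elem=2`) the same row comes out at 1.25 FLOPs/byte, still well below the ridge points quoted above, so the kernel stays memory-bound either way.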
Version 1: Naive (one thread per element)
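```cuda
#include <math.h>  // INFINITY, expf, fmaxf

// Launch: one block per row, blockDim.x >= V (and <= 1024), e.g. <<<B, 608>>> for V = 600.
__global__ void softmax_naive(const float* x, float* y, int V) {
    int row = blockIdx.x;
    int col = threadIdx.x;

    // Pass 1: max (every thread computes the same max: wasteful)
    float m = -INFINITY;
    for (int k = 0; k < V; ++k) m = fmaxf(m, x[row * V + k]);

    // Pass 2: sum-exp (again repeated by every thread)
    float s = 0.0f;
    for (int k = 0; k < V; ++k) s += expf(x[row * V + k] - m);

    // Pass 3: each in-range thread writes one output element
    if (col < V) y[row * V + col] = expf(x[row * V + col] - m) / s;
}
```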
What's wrong:
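- Every thread walks the whole row in the max pass and again in the sum pass, then loads its own element a third time. That is about \(2V + 1\) loads per thread where one load per element would do.
- Every thread computes the max and the sum independently: \(V\)× redundant compute.
- Threads with `col >= V` (when the block size is rounded up past \(V\)) still run both loops but write nothing.
- No SMEM. Every pass goes back to global memory (HBM, minus whatever the L1/L2 caches absorb).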
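But it works. Run it and confirm it matches a NumPy reference to within 1e-5 before optimizing. NumPy has no `np.softmax`, so use `scipy.special.softmax` or a hand-written NumPy version as the reference. Never tune a wrong kernel.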
Expected: hits maybe 1–3% of HBM bandwidth. Plenty of room to improve.
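Here is a sketch of the check-before-tuning step. It assumes the kernel's output has already been copied back to the host as a NumPy array. `check_softmax` is a hypothetical helper. The 1e-5 tolerance comes from the text. Comparing against a float64 reference keeps the reference's own rounding out of the error budget.

```python
import numpy as np

def softmax_ref(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def check_softmax(y_kernel: np.ndarray, x: np.ndarray, atol: float = 1e-5) -> float:
    """Raise if the kernel output deviates from the reference; return the max error."""
    ref = softmax_ref(x.astype(np.float64))          # float64 reference
    err = float(np.max(np.abs(y_kernel.astype(np.float64) - ref)))
    if err > atol:
        raise AssertionError(f"max abs error {err:.3e} exceeds {atol:.0e}")
    return err

x = np.random.default_rng(1).standard_normal((8, 600)).astype(np.float32)
y_kernel = softmax_ref(x)   # stand-in for the GPU output copied back to the host
print(check_softmax(y_kernel, x) < 1e-5)   # True
```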
Version 2: Coalesced + SMEM
Step 1: load the row into SMEM once, with one thread per element, coalesced.
Step 2: reduce in SMEM for max and sum (the tree-reduction pattern from theory/01).
Step 3: each thread writes its output.
```cuda
#include <math.h>  // INFINITY, expf, fmaxf

// Launch: one block per row with V floats of dynamic SMEM,
// e.g. softmax_smem<<<B, 256, V * sizeof(float)>>>(x, y, V);
__global__ void softmax_smem(const float* x, float* y, int V) {
    extern __shared__ float row[];  // the whole row lives in SMEM
    int r = blockIdx.x;
    int t = threadIdx.x;

    // 1. Coalesced load: consecutive threads read consecutive addresses.
    for (int k = t; k < V; k += blockDim.x) row[k] = x[r * V + k];
    __syncthreads();

    // 2. Max, computed serially by thread 0 (the tree max-reduce is in lab 02).
    __shared__ float m;
    if (t == 0) {
        float mm = -INFINITY;
        for (int k = 0; k < V; ++k) mm = fmaxf(mm, row[k]);
        m = mm;
    }
    __syncthreads();

    // 3. Exponentiate in place, then sum serially in thread 0.
    for (int k = t; k < V; k += blockDim.x) row[k] = expf(row[k] - m);
    __syncthreads();
    __shared__ float s;
    if (t == 0) {
        float ss = 0.0f;
        for (int k = 0; k < V; ++k) ss += row[k];
        s = ss;
    }
    __syncthreads();

    // 4. Normalize and write back (coalesced).
    for (int k = t; k < V; k += blockDim.x) y[r * V + k] = row[k] / s;
}
```
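You can check on the CPU why the strided loop `for (int k = t; k < V; k += blockDim.x)` is coalesced. The NumPy emulation below assumes a 256-thread block; the kernel itself does not fix the block size. It lists the indices touched in each iteration. Every element is visited exactly once. Within an iteration, consecutive threads touch consecutive addresses, so a warp of 32 threads reads 128 contiguous bytes of fp32 values. `strided_schedule` is a hypothetical name.

```python
import numpy as np

def strided_schedule(V: int, block_dim: int) -> list[np.ndarray]:
    """Indices touched in each iteration of `for (k = t; k < V; k += block_dim)`."""
    steps = []
    for base in range(0, V, block_dim):
        k = base + np.arange(block_dim)   # thread t handles index base + t
        steps.append(k[k < V])            # threads with k >= V skip this iteration
    return steps

steps = strided_schedule(600, 256)
print([len(s) for s in steps])   # [256, 256, 88]
```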
What changed:
- One coalesced read from HBM into SMEM.
- One coalesced write back to HBM.
- SMEM holds the row; subsequent passes hit SMEM at ~10× HBM bandwidth.
- Still has thread-0 serial reductions (the `if (t == 0)` blocks), which are easy to fix next.
Expected: ~10–20% of peak HBM bandwidth. A clean step up.
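A tree reduction removes the serial thread-0 loops. The NumPy emulation below models the shared-memory halving loop. It assumes a power-of-two block size and pads with the operator's identity (\(-\infty\) for max, 0 for sum). `tree_reduce` is a hypothetical name. Each thread first folds its strided slice (at most 3 elements for 256 threads and \(V = 600\)). The block then finishes in \(\log_2 256 = 8\) rounds instead of 600 serial steps in thread 0.

```python
import numpy as np

def tree_reduce(values: np.ndarray, block_dim: int, op, identity: float) -> tuple[float, int]:
    """Emulate a block-wide SMEM tree reduction; block_dim must be a power of two."""
    assert block_dim > 0 and (block_dim & (block_dim - 1)) == 0, "power of two required"
    # Step 1: thread t folds elements t, t + block_dim, ... into its own slot.
    partial = np.full(block_dim, identity, dtype=np.float64)
    for k, v in enumerate(values):
        partial[k % block_dim] = op(partial[k % block_dim], v)
    # Step 2: halving loop. Thread t < stride combines slots t and t + stride;
    # on the GPU every round is followed by __syncthreads().
    stride, rounds = block_dim // 2, 0
    while stride > 0:
        partial[:stride] = op(partial[:stride], partial[stride:2 * stride])
        stride //= 2
        rounds += 1
    return float(partial[0]), rounds

row = np.random.default_rng(2).standard_normal(600)
m, rounds = tree_reduce(row, 256, np.maximum, -np.inf)
s, _ = tree_reduce(np.exp(row - m), 256, np.add, 0.0)
print(m == row.max(), rounds)   # True 8
```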
Version 3: Parallel reduction + online-softmax (one pass)
Two more moves:
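- Replace the thread-0 serial reductions with tree reductions across the block (`__syncthreads()` in a halving loop). All threads participate.
- Fuse max and sum into one pass with the online-softmax recurrence. Keep a running max \(m\) and a running sum of exponentials \(s\). For each new element \(x_k\), update the max and rescale the old sum to it:
\[ m' = \max(m, x_k), \qquad s' = s\, e^{m - m'} + e^{x_k - m'} \]
After the last element, \(m\) is the row max and \(s = \sum_k \exp(x_k - m)\). These are exactly the two quantities the separate passes produced.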
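This is the same trick Flash-Attention uses (Phase 27). The max pass and the sum pass collapse into one pass over the row.
Result: one pass computes \(m\) and \(s\), then a normalize-and-write pass finishes the row. With the row staged in SMEM, the saving is an SMEM pass. When the row is not staged (for example, because it is too large for SMEM), the saving is a full HBM read.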
Expected: ~30–50% of peak HBM bandwidth on \(V = 600\). This is the version that hits the ≥30%-of-cuBLAS target.
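The two moves compose. Each thread folds its strided elements into a private \((m, s)\) pair with the online recurrence. The block then merges pairs with \(m = \max(m_a, m_b)\), \(s = s_a e^{m_a - m} + s_b e^{m_b - m}\). That merge is associative, so it can take the place of `max` or `+` in the halving loop. The NumPy emulation below covers one row and assumes 256 threads. `online_pair`, `merge`, and `softmax_online_block` are hypothetical names, not lab code.

```python
import math
import numpy as np

def online_pair(xs) -> tuple[float, float]:
    """One thread's single pass: running max m and running sum of exp(x - m)."""
    m, s = -math.inf, 0.0
    for x in xs:
        m_new = max(m, float(x))
        s = s * math.exp(m - m_new) + math.exp(float(x) - m_new)  # rescale old sum
        m = m_new
    return m, s

def merge(a, b):
    """Combine two (m, s) partials. Associative, so it can drive a tree reduction."""
    m = max(a[0], b[0])
    if m == -math.inf:            # both partials empty: avoid exp(-inf + inf) = nan
        return m, 0.0
    return m, a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m)

def softmax_online_block(row: np.ndarray, block_dim: int = 256) -> np.ndarray:
    """Emulate Version 3 for one row: strided online pass, tree merge, normalize."""
    # Thread t reads row[t], row[t + block_dim], ... exactly once.
    pairs = [online_pair(row[t::block_dim]) for t in range(block_dim)]
    # Halving-loop merge of the per-thread pairs (block_dim a power of two).
    while len(pairs) > 1:
        half = len(pairs) // 2
        pairs = [merge(pairs[i], pairs[i + half]) for i in range(half)]
    m, s = pairs[0]
    return np.exp(row - m) / s    # normalize-and-write pass

row = np.random.default_rng(3).standard_normal(600)
y = softmax_online_block(row)
ref = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
print(np.allclose(y, ref))   # True
```

The rescaling factor \(e^{m - m'}\) is the key design choice. It lets a partial sum computed against an old max be corrected later, which makes both the streaming update and the merge exact.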
Version 4: Triton (lab 03)
Same algorithm, in Triton. The Python source is ~30 lines (compared to ~80 for the tuned CUDA C). Triton's autotuner sweeps block sizes; you specify the algorithm, the autotuner finds the params.
Expected on \(V = 600\): 80–95% of the hand-tuned CUDA C version. The remaining 5–20% is the cost of generality.
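An autotuner's core loop is simple: run the kernel under each candidate configuration, time it, and cache the winner per input shape. Below is a framework-free sketch in plain Python. `autotune`, the chunked NumPy stand-in kernel, and the candidate sizes are illustrative assumptions, not Triton's internals. In Triton itself, the candidates are declared with the `@triton.autotune` decorator and a list of `triton.Config` objects, and the sweep reruns when the arguments named in its `key` change.

```python
import time
import numpy as np

def chunked_softmax(x: np.ndarray, block_rows: int) -> np.ndarray:
    """Stand-in 'kernel': row-wise softmax, block_rows rows per chunk."""
    y = np.empty_like(x)
    for i in range(0, x.shape[0], block_rows):
        xb = x[i:i + block_rows]
        e = np.exp(xb - xb.max(axis=-1, keepdims=True))
        y[i:i + block_rows] = e / e.sum(axis=-1, keepdims=True)
    return y

_best_config: dict = {}   # cache: input shape -> winning config

def autotune(kernel, x: np.ndarray, candidates, reps: int = 3):
    """Time every candidate once per input shape; cache and return the fastest."""
    key = x.shape
    if key not in _best_config:
        timings = {}
        for c in candidates:
            kernel(x, c)                       # warm-up run, not timed
            start = time.perf_counter()
            for _ in range(reps):
                kernel(x, c)
            timings[c] = (time.perf_counter() - start) / reps
        _best_config[key] = min(timings, key=timings.get)
    return _best_config[key]

x = np.random.default_rng(4).standard_normal((1024, 600)).astype(np.float32)
best = autotune(chunked_softmax, x, candidates=(16, 64, 256, 1024))
y = chunked_softmax(x, best)
```

Because results are cached by shape, the sweep costs time only on the first call for each shape. Later calls go straight to the winner.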
The ladder, summarized
| Version | HBM passes | Reductions | % of peak | Lines of code |
|---|---|---|---|---|
| Naive | 2 full-row reads per thread (max, sum) + its own element ≈ \(2V\) row reads | None (serial in each thread) | 1–3% | ~10 |
| Coalesced + SMEM | 1 read + 1 write | Serial in thread 0 | 10–20% | ~25 |
| Parallel + online | 1 read + 1 write | Tree in block | 30–50% | ~50 |
| Triton (autotuned) | Same | Same | 25–45% | ~30 (Python) |
The CUDA C tuned version reaches ≥30% of F.softmax. Triton lands close behind. Both go on the roofline plot.
What each move did, in roofline language
- Coalescing: raised attainable bandwidth (more bytes per memory transaction).
- SMEM caching: eliminated redundant HBM traffic, raising effective intensity.
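- Parallel reduction: removed scheduler stalls (no serial `if (tid == 0)` section where one thread works while the rest of the block waits at the barrier).
- Online softmax (fusing passes): removed one pass over the row. When the row is not staged in SMEM, that pass is a full HBM read.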
Every one of these is a Phase-23 concept made concrete. That's the point of the phase.
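Placing a dot needs three numbers per version: bytes moved, FLOPs, and measured time. The sketch below assumes the minimum traffic of one read plus one write per element, and uses the 1.55 TB/s A100 figure from the drills. `roofline_dot` is a hypothetical helper. The 10 µs timing in the usage line is a made-up input, not a measurement.

```python
def roofline_dot(V: int, B: int, seconds: float,
                 peak_bw: float = 1.55e12, bytes_per_elem: int = 4) -> dict:
    """Roofline coordinates for one softmax launch over a (B, V) batch."""
    bytes_moved = 2 * B * V * bytes_per_elem   # minimum traffic: one read + one write
    flops = 5 * B * V                          # ~5 FLOPs per element
    return {
        "intensity": flops / bytes_moved,                    # x-axis, FLOPs/byte
        "gflops": flops / seconds / 1e9,                     # y-axis, achieved GFLOP/s
        "frac_of_bw_roof": bytes_moved / seconds / peak_bw,  # fraction of peak HBM bandwidth
    }

dot = roofline_dot(V=600, B=1024, seconds=10e-6)   # hypothetical 10 µs timing
print(dot["intensity"], round(dot["frac_of_bw_roof"], 3))   # 0.625 0.317
```

Using the minimum traffic for every version keeps the dots comparable. Counting each version's actual traffic instead would move the naive dots to a lower effective intensity.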
Drill problems
- For \(V = 600\) fp32, \(B = 1024\), what's the per-row HBM read+write in bytes? Total HBM traffic for the batch? At 1.55 TB/s (A100 SXM4 40GB), how long is that just for the bytes?
- What if \(V = 8192\) (a real-vocab transformer)? Does the SMEM strategy still fit (SMEM per block is ~100 KB max)? If not, what changes?
- Why does online-softmax help fp16 more than fp32? (Hint: HBM bandwidth is per-byte; halved bytes → halved time. The fp16 compute peak was already plenty.)
- Where would Tensor Cores help in this kernel? (Hint: they wouldn't — softmax is element-wise + reductions, no matmul. Tensor Cores would help if we fused softmax with the LM-head GEMM — that's a Phase-27 topic.)
CPU-fallback note (for local development)
Borja's local machine has no CUDA. The CPU fallback path is just the NumPy code shown in §1; the dispatcher in src/minikernel/dispatch.py (lab 02) decides between the CUDA C kernel, the Triton kernel, and np.exp(x - x.max(...)). Numerical-equivalence tests run locally against the NumPy reference; performance tests run only on the cloud GPU.
This keeps iteration cheap — algorithm-level bugs surface on CPU in milliseconds. The cloud GPU is only spun up when the algorithm is correct.
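Below is a minimal sketch of the dispatch idea; it is not the contents of src/minikernel/dispatch.py. Backends are tried in preference order, and NumPy is the fallback that is always available. The registry, the probes, and the backend names are hypothetical. The GPU probes here simply report "unavailable", which is the situation on a machine without CUDA.

```python
from typing import Callable, Optional
import numpy as np

def numpy_softmax(x: np.ndarray) -> np.ndarray:
    """CPU fallback: the NumPy reference."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# (name, availability probe, implementation) in preference order.
# The GPU probes are placeholders that report "unavailable"; a real
# dispatcher would check for a CUDA device and load the compiled kernels.
BACKENDS: list[tuple[str, Callable[[], bool], Optional[Callable]]] = [
    ("cuda_c", lambda: False, None),
    ("triton", lambda: False, None),
    ("numpy", lambda: True, numpy_softmax),
]

def softmax(x: np.ndarray) -> tuple[str, np.ndarray]:
    """Run the first available backend; return its name and the output."""
    for name, available, impl in BACKENDS:
        if impl is not None and available():
            return name, impl(x)
    raise RuntimeError("no softmax backend available")

name, y = softmax(np.zeros((2, 600), dtype=np.float32))
print(name)   # numpy
```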
What you should now be able to do
- Sketch the four versions of the softmax kernel from memory; explain what each move costs and what it buys.
- Predict roughly where each dot lands on the roofline.
- Apply the same ladder to a new operator (e.g., layernorm, RMSNorm) — the moves are reusable.
- Decide whether a kernel's bottleneck is HBM bandwidth, SMEM bank conflicts, register pressure, or scheduler stalls.
What this page does NOT cover
- GEMM tiling specifics (cuBLAS-style). A GEMM kernel has a different optimization ladder (register tile + SMEM tile + Tensor Core). That is Phase 24's alternative-kernel path; the default kernel here is softmax.
- Flash-Attention. Phase 27. The online-softmax trick mentioned here is one of Flash-Attention's building blocks but only one.
- Bank-conflict avoidance. Briefly relevant for SMEM reductions; lab 02 demonstrates with a profile.
- Autotuning spaces. Triton's autotune surface is covered in theory/03.
Next: theory/03-triton.md — the Python-DSL kernel language, how its autotune works, when it beats CUDA C and when it doesn't.