Lab 02: Tuned Fused-Softmax Kernel (≥30% of cuBLAS)
Goal: apply the optimization ladder from `theory/02` to the naive kernel from lab 01. Climb from ~1% of peak HBM to ≥30% of `torch.nn.functional.softmax` performance at \(B = 512, V = 600\). Capture an `ncu` profile of the final version. This is the lab where Phase 24's DoD perf target is met.

Estimated time: 6–10 hours (kernel tuning is iterative).

Prereq:
- `lab/01-naive-kernel.md` complete (correct naive baseline exists).
- `nsight-compute` (`ncu`) installed on cloud GPU.
What you produce
A directory `experiments/24-tuned-kernel/` and updated `src/minikernel/`:
- `src/minikernel/softmax_smem.cu`: coalesced + SMEM version.
- `src/minikernel/softmax_fused.cu`: parallel-reduce + online-softmax (the ≥30%-of-cuBLAS version).
- `src/minikernel/softmax.py`: public-facing `softmax(x)` that dispatches to fused → smem → naive → numpy in order of availability.
- `tests/test_softmax_tuned.py`: equivalence tests for both new kernels.
- `experiments/24-tuned-kernel/bench.py`: head-to-head of naive vs smem vs fused vs `F.softmax`, across a \(B\) sweep.
- `experiments/24-tuned-kernel/ncu_report.txt`: annotated `ncu` profile of the fused version.
- `experiments/24-tuned-kernel/manifest.json`.
- `experiments/24-tuned-kernel/README.md`: 3 paragraphs covering what each move bought, the `ncu` interpretation, and the residual gap to cuBLAS / `F.softmax`.
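The dispatch order for `softmax(x)` can be sketched as a fallback chain. This is a minimal sketch, not the lab's reference implementation: the loader stub, the `LOADERS` registry, and `pick_backend` are hypothetical names, and the last-resort fallback is written in pure Python (standing in for the numpy path) so the block runs without dependencies.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Row = Sequence[float]
Kernel = Callable[[Row], List[float]]


def reference_softmax(row: Row) -> List[float]:
    """Last-resort fallback: numerically stable 3-pass softmax on one row."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def _load_unavailable() -> Optional[Kernel]:
    """Hypothetical loader stub: stands in for a failed compile or a missing GPU."""
    return None


# Hypothetical registry, fastest first. A real loader would build and wrap the .cu file.
LOADERS: List[Tuple[str, Callable[[], Optional[Kernel]]]] = [
    ("fused", _load_unavailable),
    ("smem", _load_unavailable),
    ("naive", _load_unavailable),
    ("reference", lambda: reference_softmax),
]


def pick_backend(loaders=LOADERS) -> Tuple[str, Kernel]:
    """Return (name, kernel) for the first backend whose loader yields a callable."""
    for name, load in loaders:
        try:
            kernel = load()
        except Exception:
            kernel = None  # treat any load failure as "not available"
        if kernel is not None:
            return name, kernel
    raise RuntimeError("no softmax backend available")


def softmax(row: Row) -> List[float]:
    """Public entry point: run the fastest backend that loaded."""
    _, kernel = pick_backend()
    return kernel(row)
```

Returning the backend name alongside the kernel makes it easy for `bench.py` to record which path was actually measured.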
TODOs
Block A: version 2 (coalesced + SMEM)
- Per `theory/02` §"Version 2": one block per row; threads cooperatively load the row into `extern __shared__ float row[]`, sync, then thread 0 reduces serially (kept simple). Write back coalesced.
- Choose block size: \(\geq V\) rounded up to a power of 2 (so 1024 for \(V = 600\)). Launch with `<<<B, 1024, V * sizeof(float)>>>`.
- Test correctness as in lab 01.
- Bench: expect 10–20% of peak HBM. Document the jump from naive.
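The launch arithmetic for Block A can be checked on the host before touching the kernel. The sketch below is illustrative: `next_pow2` and `launch_config` are hypothetical helper names, and the 1024-thread cap is assumed to be the per-block limit of the target GPU.

```python
def next_pow2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def launch_config(B: int, V: int, max_threads: int = 1024, elem_bytes: int = 4) -> dict:
    """Grid, block, and dynamic shared-memory size for one block per row."""
    block = next_pow2(V)
    if block > max_threads:
        # Assumption: rows wider than the block limit need a strided load loop instead.
        raise ValueError(f"V={V} needs more than {max_threads} threads per block")
    return {
        "grid": B,                      # one block per row
        "block": block,                 # threads per block
        "smem_bytes": V * elem_bytes,   # third launch parameter
        "idle_threads": block - V,      # tail threads with no element to load
    }


print(launch_config(512, 600))
```

For \(B = 512, V = 600\) this gives a 1024-thread block with 2400 bytes of dynamic shared memory and 424 idle tail threads per block, so the kernel needs a `tid < V` guard on every load and store.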
Block B: version 3 (parallel reduction + online softmax)
- Replace the `if (tid == 0)` serial max/sum with tree reductions across the block (the canonical pattern from `phase-23/theory/03`).
- Fuse the max-pass and sum-pass using online softmax (`theory/02` §"Version 3"). One pass over the row reads the data and accumulates both \(m\) and \(s\).
- Test correctness. Note: online softmax can drift numerically vs the 3-pass version at fp32, so the tolerance may need to be 1e-4 instead of 1e-5. Document the tolerance you use.
- Bench: expect 30–50% of peak HBM. This is the DoD-relevant version.
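The two moves in Block B can be prototyped on the host before writing CUDA. The sketch below emulates one thread block sequentially: each "thread" keeps a running \((m, s)\) pair over its strided elements, then the pairs are combined with a tree reduction. It is a reference model under stated assumptions, not the lab's kernel: `merge`, `online_softmax_block`, and `three_pass_softmax` are hypothetical names, and Python floats are fp64, so this checks the algebra rather than the fp32 drift.

```python
import math
from typing import List, Sequence, Tuple

NEG_INF = float("-inf")
State = Tuple[float, float]  # (running max m, running sum s of exp(x - m))


def merge(a: State, b: State) -> State:
    """Combine two partial (m, s) states by rescaling both sums to the new max."""
    m1, s1 = a
    m2, s2 = b
    m = max(m1, m2)
    if m == NEG_INF:  # both states are empty
        return (NEG_INF, 0.0)
    return (m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m))


def online_softmax_block(row: Sequence[float], block: int = 1024) -> List[float]:
    """Emulate one block; `block` must be a power of two."""
    # Phase 1: per-thread online update. Folding in one element x is merge(state, (x, 1)).
    states: List[State] = []
    for tid in range(block):
        state: State = (NEG_INF, 0.0)
        for i in range(tid, len(row), block):
            state = merge(state, (row[i], 1.0))
        states.append(state)
    # Phase 2: tree reduction. On the GPU each step ends with __syncthreads().
    stride = block // 2
    while stride > 0:
        for tid in range(stride):
            states[tid] = merge(states[tid], states[tid + stride])
        stride //= 2
    m, s = states[0]
    return [math.exp(v - m) / s for v in row]


def three_pass_softmax(row: Sequence[float]) -> List[float]:
    """Baseline: max pass, sum pass, normalize pass."""
    m = max(row)
    s = sum(math.exp(v - m) for v in row)
    return [math.exp(v - m) / s for v in row]
```

Because `merge` is associative up to rounding, the same function serves both the per-thread update and the cross-thread reduction; the order in which partial states are merged is what produces the small fp32 differences the tolerance has to absorb.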
Block C: ncu profile
- On the cloud GPU: `ncu --set full --section MemoryWorkloadAnalysis --section ComputeWorkloadAnalysis --section Occupancy --target-processes all python bench.py`. Save the report to `ncu_report.ncu-rep` and a text export to `ncu_report.txt`.
- Read the report. Identify:
  - Achieved occupancy vs theoretical (from compute capability).
  - HBM throughput vs peak (from device spec).
  - L1 / SMEM hit rate.
  - Stall reasons (Memory Throttle, Memory Dependency, Execution Dependency, etc.). Exact stall names vary by `ncu` version; recent versions label waits on global-memory loads as Long Scoreboard.
- Write a 1-paragraph annotation in `README.md`. Identify the dominant stall reason. If it's not a memory stall such as "Memory Throttle" (i.e., the kernel is not HBM-bound), something is wrong (the kernel is supposed to be memory-bound).
Block D: compare to cuBLAS / F.softmax
- In `bench.py`, also time `torch.nn.functional.softmax(x, dim=-1)` at the same \((B, V)\). PyTorch dispatches to cuDNN's softmax or the JIT inductor (depending on torch version + warm-up).
- Compute `tuned_kernel_time / F_softmax_time`. Target: ≤ 3.0 (i.e., your kernel runs at least ⅓ as fast). If you hit ≤ 1.5, great: a fused custom kernel sometimes beats a generic one at small \(V\).
- Document the gap. Don't grind to close it; understand it.
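The ratio and the DoD check reduce to a few lines once both timing series exist. This is a minimal sketch with a hypothetical helper name (`dod_check`); the timings in the usage line are made-up placeholders to show the arithmetic, not measurements.

```python
from statistics import median
from typing import Dict, Sequence, Union


def dod_check(tuned_us: Sequence[float],
              ref_us: Sequence[float],
              max_ratio: float = 3.0) -> Dict[str, Union[float, bool]]:
    """Compare median kernel time against the median F.softmax time."""
    ratio = median(tuned_us) / median(ref_us)
    return {
        "ratio": ratio,                       # tuned_kernel_time / F_softmax_time
        "fraction_of_ref_perf": 1.0 / ratio,  # 1/3 corresponds to ratio 3.0
        "dod_met": ratio <= max_ratio,
    }


# Placeholder timings in microseconds, for illustration only.
print(dod_check([30.0, 31.0, 29.0], [12.0, 12.5, 11.5]))
```

Using the median rather than the mean keeps a single slow iteration (for example a clock ramp-up) from moving the ratio.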
Block E: manifest
{
"experiment": "24-tuned-kernel",
"date": "YYYY-MM-DD",
"seed": 42,
"gpu": {"model": null, "compute_capability": null, "hbm_peak_gbs": null},
"versions": {"python": "3.11.x", "cupy": null, "torch": null, "ncu": null},
"kernels": {
"naive": {"median_us_at_B512": null, "fraction_of_peak": null},
"smem": {"median_us_at_B512": null, "fraction_of_peak": null},
"fused": {"median_us_at_B512": null, "fraction_of_peak": null},
"F_softmax_ref": {"median_us_at_B512": null}
},
"results_summary": {
"fused_vs_F_softmax_ratio": null,
"dod_30pct_met": null,
"dominant_stall_reason": null
}
}
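The `fraction_of_peak` fields can be derived from the median time and the device's peak bandwidth. The sketch below assumes the minimum possible traffic for a fused softmax (each element read once and written once); the helper names are hypothetical, and the 600 GB/s peak and 10.24 µs timing in the usage line are placeholders, so take the real values from the device spec and from `bench.py`.

```python
def min_bytes_moved(B: int, V: int, elem_bytes: int = 4) -> int:
    """Lower bound on DRAM traffic: one read plus one write of the B x V matrix."""
    return 2 * B * V * elem_bytes


def achieved_gbs(median_us: float, B: int, V: int, elem_bytes: int = 4) -> float:
    """Effective bandwidth in GB/s (1 GB = 1e9 bytes) for a given median time."""
    seconds = median_us * 1e-6
    return min_bytes_moved(B, V, elem_bytes) / seconds / 1e9


def fraction_of_peak(median_us: float, B: int, V: int,
                     hbm_peak_gbs: float, elem_bytes: int = 4) -> float:
    """Value for the manifest's fraction_of_peak field."""
    return achieved_gbs(median_us, B, V, elem_bytes) / hbm_peak_gbs


# Placeholder inputs, for illustration only.
print(min_bytes_moved(512, 600))                  # 2457600 bytes
print(fraction_of_peak(10.24, 512, 600, 600.0))   # close to 0.4
```

A kernel that re-reads the row from global memory moves more than this lower bound, so its effective-bandwidth figure understates the traffic it actually generates; the `ncu` memory section reports the real number.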
Constraints
- Correctness first. Don't try Block B before Block A passes correctness.
- One change at a time. Going from naive → SMEM → fused → online covers four stages (three moves). Bench between each: knowing what each move bought is the lesson.
- fp32 throughout. fp16 / bf16 is optional Block F (below); not required for DoD.
- Pin the seed. All measurements reproducible.
Optional Block F: fp16 path
- Repeat the fused kernel with fp16 inputs and an fp32 accumulator. Tolerance vs `F.softmax` (fp16): 1e-2.
- Bench: expect a 1.5–2× speedup (the kernel is memory-bound, so halved bytes ≈ halved time).
- If the dispatcher routes by dtype, add an fp16 branch. If not, keep fp16 in a separate `softmax_fused_fp16.cu`.
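The 1e-2 tolerance can be sanity-checked on the host by emulating fp16 storage with the standard library's half-precision `struct` format. This is a rough model under stated assumptions: the helper names are hypothetical, and the accumulator here is a Python float (fp64) rather than the fp32 accumulator the kernel would use, so it isolates the error from fp16 input and output rounding only.

```python
import math
import struct
from typing import List, Sequence


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def softmax_wide(xs: Sequence[float]) -> List[float]:
    """Stable softmax with no output rounding."""
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    s = sum(exps)
    return [e / s for e in exps]


def softmax_fp16_io(row: Sequence[float]) -> List[float]:
    """fp16 inputs and outputs with a wider accumulator in between."""
    xs = [to_fp16(v) for v in row]               # what the kernel would load
    return [to_fp16(p) for p in softmax_wide(xs)]  # what the kernel would store


row = [4.0 * math.sin(i) for i in range(600)]
fp16_inputs = [to_fp16(v) for v in row]
err = max(abs(a - b) for a, b in zip(softmax_fp16_io(row), softmax_wide(fp16_inputs)))
print(err)
```

Output probabilities are at most 1, and fp16 carries an 11-bit significand, so output rounding alone stays well under the 1e-2 budget; most of the tolerance is headroom for differences in how the reference rounds intermediate values.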
Stop conditions
Done when:
- SMEM version passes correctness; benchmark data point recorded.
- Fused (online + parallel) version passes correctness; benchmark data point recorded.
- Fused vs `F.softmax` at \(B=512, V=600\): ratio ≤ 3.0 (i.e., ≥33% of `F.softmax` perf, which meets the DoD).
- `ncu_report.txt` committed with an annotation identifying the dominant stall reason.
- `manifest.json` committed.
- `README.md` documents what each optimization move bought (with numbers).
Pitfalls
- Online-softmax numerical drift. The recurrence updates \(m\) and \(s\) in lockstep; ordering of operations matters. Standard implementation: see Milakov & Gimelshein 2018 ("Online Normalizer Calculation for Softmax"). Match the paper's order exactly to match reference behavior.
- SMEM bank conflicts. There are 32 banks; the access pattern `row[tid]` with tid spanning 0..1023 maps to bank `tid % 32`, so the 32 threads of a warp hit 32 distinct banks and there are no conflicts. For \(V = 600 < 1024\), the tail threads simply idle. For larger \(V\), padding may be needed. Not relevant at grammar scale.
- `F.softmax` faster than your kernel because it fuses with the GEMM upstream. PyTorch's `inductor` sometimes fuses `lm_head + softmax` into one kernel. If you're comparing against an inductor-fused reference, you're comparing apples to oranges. Call `torch.nn.functional.softmax` directly, with `@torch.compile(mode='reduce-overhead')` off.
- Register spill from too-large block size. `--ptxas-options=-v` reports register usage. If > 64 regs/thread and occupancy is low, lower block size.
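The bank-conflict and register pitfalls both come down to small integer arithmetic that can be checked on the host. The sketch below is a simplified model with hypothetical helper names. It assumes 4-byte shared-memory words and 32 banks, and a register file of 65536 32-bit registers per SM (an assumption to verify against your GPU's spec); it ignores register allocation granularity.

```python
from collections import Counter
from typing import List

NUM_BANKS = 32
WARP_SIZE = 32
REGS_PER_SM = 65536  # assumption: check the device spec for your GPU


def warp_indices(warp: int, stride: int = 1) -> List[int]:
    """Word indices touched by one warp when thread tid accesses row[tid * stride]."""
    return [(warp * WARP_SIZE + lane) * stride for lane in range(WARP_SIZE)]


def max_bank_conflict(indices: List[int]) -> int:
    """Largest number of distinct words that fall in one bank (1 means conflict-free)."""
    per_bank = Counter()
    for word in set(indices):  # threads reading the same word are served together
        per_bank[word % NUM_BANKS] += 1
    return max(per_bank.values())


def max_regs_per_thread(block_size: int) -> int:
    """Register budget per thread if one full block must fit on an SM."""
    return REGS_PER_SM // block_size


print(max_bank_conflict(warp_indices(0)))             # row[tid]: conflict-free
print(max_bank_conflict(warp_indices(0, stride=32)))  # stride 32: every lane, one bank
print(max_bank_conflict(warp_indices(0, stride=33)))  # padded stride: conflict-free
print(max_regs_per_thread(1024))
```

Under the register-file assumption above, a 1024-thread block leaves exactly 64 registers per thread, which is one way to read the 64 regs/thread threshold; halving the block size doubles the budget.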
When to consult solutions/
After all stop conditions met. The reference walks through the exact sequence of moves (and the numbers each hits on an A10).
Next lab: lab/03-triton-and-pytorch.md.