

Lab 02 — Tuned Fused-Softmax Kernel (≥30% of cuBLAS)

Goal: apply the optimization ladder from theory/02 to the naive kernel from lab 01. Climb from ~1% of peak HBM to ≥30% of torch.nn.functional.softmax performance at \(B = 512, V = 600\). Capture an ncu profile of the final version. This is the lab where Phase 24's DoD perf target is met.

Estimated time: 6–10 hours (kernel tuning is iterative).

Prereq: lab/01-naive-kernel.md complete (a correct naive baseline exists), and Nsight Compute (ncu) installed on the cloud GPU.


What you produce

A directory experiments/24-tuned-kernel/ and updated src/minikernel/:

  • src/minikernel/softmax_smem.cu — coalesced + SMEM version.
  • src/minikernel/softmax_fused.cu — parallel-reduce + online-softmax (the ≥30%-of-cuBLAS version).
  • src/minikernel/softmax.py — public-facing softmax(x) that dispatches to fused → smem → naive → numpy in order of availability.
  • tests/test_softmax_tuned.py — equivalence tests for both new kernels.
  • experiments/24-tuned-kernel/bench.py: head-to-head of naive vs smem vs fused vs F.softmax across a \(B\) sweep.
  • experiments/24-tuned-kernel/ncu_report.txt — annotated ncu profile of the fused version.
  • experiments/24-tuned-kernel/manifest.json.
  • experiments/24-tuned-kernel/README.md — 3 paragraphs: what each move bought, the ncu interpretation, the residual gap to cuBLAS / F.softmax.

TODOs

Block A — version 2 (coalesced + SMEM)

  • Per theory/02 §"Version 2": one block per row, threads cooperatively load the row into extern __shared__ float row[], sync, then reduce serially in thread 0 (kept simple). Write back coalesced.
  • Choose the block size: \(V\) rounded up to the next power of 2 (1024 for \(V = 600\)). Launch with <<<B, 1024, V * sizeof(float)>>>.
  • Test correctness as in lab 01.
  • Bench: expect 10–20% of peak HBM. Document the jump from naive.
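Block A and the manifest call for launch arithmetic and a fraction-of-peak number, and both can be computed as below. The sketch rests on stated assumptions: fp32 elements, exactly one read and one write of the \(B \times V\) matrix (SMEM traffic not counted), and a peak bandwidth you take from the device spec. The helper names are hypothetical.

```python
"""Back-of-envelope numbers for version 2 (assumptions noted inline)."""

FLOAT_BYTES = 4  # fp32


def next_pow2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def launch_config(B: int, V: int, max_threads: int = 1024):
    """(grid, block, dynamic SMEM bytes) for one-block-per-row version 2.

    Assumes V <= max_threads, so each thread loads at most one element.
    """
    threads = next_pow2(V)
    if threads > max_threads:
        raise ValueError("V too large for one element per thread; add a strided loop")
    return B, threads, V * FLOAT_BYTES


def fraction_of_peak(B: int, V: int, median_us: float, peak_gbs: float) -> float:
    """Effective DRAM bandwidth divided by peak, counting one read + one write of the matrix."""
    bytes_moved = 2 * B * V * FLOAT_BYTES
    achieved_gbs = bytes_moved / (median_us * 1e-6) / 1e9
    return achieved_gbs / peak_gbs
```

Take a hypothetical median of 8.192 µs at \(B = 512, V = 600\) against an assumed 600 GB/s peak. In that case fraction_of_peak(512, 600, 8.192, 600.0) is about 0.5. Total traffic at this shape is only about 2.5 MB, so fixed launch overhead can be a sizable share of each measured time.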

Block B — version 3 (parallel reduction + online softmax)

  • Replace the if (tid == 0) serial max/sum with tree reductions across the block (the canonical pattern from phase-23/theory/03).
  • Fuse the max-pass and sum-pass using online softmax (theory/02 §"Version 3"). One pass over the row reads & accumulates both \(m\) and \(s\).
  • Test correctness — note: online softmax can drift numerically vs the 3-pass at fp32; tolerance may need to be 1e-4 instead of 1e-5. Document.
  • Bench: expect 30–50% of peak HBM. This is the DoD-relevant version.
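The arithmetic of version 3 can be modeled in plain Python before you write any CUDA. This sketch is not the kernel. Each simulated thread folds a block-strided slice of the row into a running pair \((m, s)\) with the online recurrence. The per-thread pairs are then combined by a sequential-addressing tree reduction. The reduction uses the pairwise rule \((m_a, s_a) \oplus (m_b, s_b) = (M,\; s_a e^{m_a - M} + s_b e^{m_b - M})\) with \(M = \max(m_a, m_b)\), which is the combine step for the parallel form of online softmax. Idle tail threads hold the identity \((-\infty, 0)\). All names are hypothetical.

```python
"""Pure-Python model of version 3's math (online softmax + tree reduction).

Not a CUDA kernel: Python loops stand in for threads, and a list stands in
for the shared-memory reduction buffer.
"""
import math
from typing import List, Tuple

Pair = Tuple[float, float]  # (running max m, running sum s of exp(x - m))
IDENTITY: Pair = (-math.inf, 0.0)


def online_update(p: Pair, x: float) -> Pair:
    """One step of the online recurrence: rescale s whenever the max grows."""
    m, s = p
    m_new = max(m, x)
    # exp(-inf) == 0.0, so starting from IDENTITY needs no special case
    return m_new, s * math.exp(m - m_new) + math.exp(x - m_new)


def merge(a: Pair, b: Pair) -> Pair:
    """Combine two partial (m, s) pairs; associative, so usable in a tree."""
    m = max(a[0], b[0])
    if m == -math.inf:
        return IDENTITY  # both sides empty; avoids inf - inf = nan
    return m, a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m)


def tree_reduce(vals: List[Pair]) -> Pair:
    """Sequential-addressing tree reduction over a power-of-two 'block'."""
    n = len(vals)
    assert n & (n - 1) == 0, "block size must be a power of two"
    buf = list(vals)  # models the shared-memory array
    stride = n // 2
    while stride > 0:
        for tid in range(stride):  # only threads tid < stride are active
            buf[tid] = merge(buf[tid], buf[tid + stride])
        stride //= 2  # on the GPU, __syncthreads() separates the steps
    return buf[0]


def softmax_row_online(row: List[float], block: int = 1024) -> List[float]:
    """Per-thread online pass, block-wide merge, then normalize (non-empty row)."""
    partial = [IDENTITY] * block
    for tid in range(block):
        for j in range(tid, len(row), block):  # thread tid owns j = tid, tid + block, ...
            partial[tid] = online_update(partial[tid], row[j])
    m, s = tree_reduce(partial)
    return [math.exp(x - m) / s for x in row]
```

In float64 the online path and the 3-pass reference agree to roundoff. This model therefore cannot tell you whether the fp32 kernel needs 1e-4 or 1e-5; that has to be measured on the GPU.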

Block C — ncu profile

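  • On the cloud GPU: ncu --set full --section MemoryWorkloadAnalysis --section ComputeWorkloadAnalysis --section Occupancy --target-processes all -o ncu_report python bench.py. The -o flag writes ncu_report.ncu-rep. Export a text view with ncu --import ncu_report.ncu-rep --page details > ncu_report.txt. Profiling every launch in bench.py is slow, so consider narrowing it with --kernel-name and --launch-count.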
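  • Read the report. Identify:
      • Achieved occupancy vs theoretical. ncu derives the theoretical value from block size, register and SMEM use, and the per-SM limits of the compute capability.
      • DRAM (HBM or GDDR, depending on the GPU) throughput vs peak, from the device spec.
      • L1/TEX hit rate and shared-memory bank conflicts.
      • Warp stall reasons. ncu names include Long Scoreboard, Memory Throttle, Barrier, and Wait.
  • Write a 1-paragraph annotation in README.md that identifies the dominant stall reason. The kernel is supposed to be memory-bound, so a memory stall should dominate. That is usually Long Scoreboard (warps waiting on global loads), or Memory Throttle if the memory pipeline queues are saturated. If something else dominates (e.g., Barrier from the reduction's __syncthreads), the kernel is not yet limited by memory traffic: find out why.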
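ncu reports theoretical occupancy directly, but a rough estimate makes the number easier to interpret. The sketch below takes the per-SM limits as parameters. It ignores allocation granularity and driver-reserved shared memory, so ncu remains authoritative; theoretical_occupancy is a hypothetical helper. The example limits are for compute capability 8.6, the A10's architecture: 1536 resident threads, 16 resident blocks, 64K registers, and an assumed 100 KB of shared memory per SM.

```python
"""Rough theoretical-occupancy estimate (simplified; ncu's value is authoritative).

Ignores register/SMEM allocation granularity and driver-reserved SMEM.
"""
import math


def theoretical_occupancy(threads_per_block: int, regs_per_thread: int, smem_per_block: int,
                          max_threads_sm: int, max_blocks_sm: int,
                          regs_sm: int, smem_sm: int) -> float:
    """Resident warps / max resident warps per SM, limited by the tightest resource."""
    warps_per_block = math.ceil(threads_per_block / 32)
    limits = [
        max_threads_sm // threads_per_block,                  # thread slots
        max_blocks_sm,                                        # block slots
        regs_sm // (regs_per_thread * warps_per_block * 32),  # register file
        smem_sm // smem_per_block if smem_per_block else max_blocks_sm,  # shared memory
    ]
    blocks = min(limits)
    return blocks * warps_per_block / (max_threads_sm // 32)


# Assumed per-SM limits for compute capability 8.6 (check the CUDA programming guide).
SM86 = dict(max_threads_sm=1536, max_blocks_sm=16, regs_sm=65536, smem_sm=100 * 1024)
```

Under these limits a 1024-thread block leaves room for only one resident block per SM. That gives 32 of 48 warps (about 67%) even with modest register use, while 256-thread blocks can reach 100%. If regs_per_thread × threads_per_block exceeds the register file, the estimate drops to 0, which corresponds to a failed launch.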

Block D — compare to cuBLAS / F.softmax

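  • In bench.py, also time torch.nn.functional.softmax(x, dim=-1) at the same \((B, V)\). In eager mode PyTorch runs its own native CUDA softmax kernel from ATen; TorchInductor only generates a kernel if the call goes through torch.compile. (cuBLAS has no softmax routine; F.softmax is the library baseline here.)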
  • Compute tuned_kernel_time / F_softmax_time. Target: ≤ 3.0 (i.e., your kernel runs at least ⅓ as fast). If you reach ≤ 1.5, great. At small \(V\) a specialized fused kernel can even beat the generic one (ratio < 1).
  • Document the gap. Don't grind to close it; understand it.
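Below is a sketch of the timing and ratio bookkeeping for Block D. The function names are hypothetical, and the sync hook is an assumption about how you wire it up. For CUDA work, pass something like torch.cuda.synchronize so the clock stops after the kernel finishes rather than after the launch returns. For microsecond-scale kernels, CUDA events (torch.cuda.Event(enable_timing=True)) are a more precise alternative.

```python
"""Median-of-N timing and the Block D ratio (a sketch, not bench.py itself)."""
import statistics
import time
from typing import Callable, Optional


def median_us(fn: Callable[[], object], warmup: int = 10, iters: int = 100,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall time per call in microseconds, syncing after every call."""
    sync = sync or (lambda: None)
    for _ in range(warmup):  # compile / cache / clock warm-up, not measured
        fn()
    sync()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        sync()  # wait for queued device work before stopping the clock
        samples.append((time.perf_counter() - t0) * 1e6)
    return statistics.median(samples)


def block_d_verdict(tuned_us: float, ref_us: float, max_ratio: float = 3.0) -> dict:
    """Ratio tuned / reference and whether it meets the target."""
    ratio = tuned_us / ref_us
    return {"fused_vs_F_softmax_ratio": ratio, "dod_met": ratio <= max_ratio}
```

With the ratio defined as tuned / reference, block_d_verdict(30.0, 12.0) gives a ratio of 2.5, which meets the ≤ 3.0 target.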

Block E — manifest

{
  "experiment": "24-tuned-kernel",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "gpu": {"model": null, "compute_capability": null, "hbm_peak_gbs": null},
  "versions": {"python": "3.11.x", "cupy": null, "torch": null, "ncu": null},
  "kernels": {
    "naive": {"median_us_at_B512": null, "fraction_of_peak": null},
    "smem":  {"median_us_at_B512": null, "fraction_of_peak": null},
    "fused": {"median_us_at_B512": null, "fraction_of_peak": null},
    "F_softmax_ref": {"median_us_at_B512": null}
  },
  "results_summary": {
    "fused_vs_F_softmax_ratio": null,
    "dod_30pct_met": null,
    "dominant_stall_reason": null
  }
}
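The results_summary fields can be derived from the timings rather than typed by hand. This assumes the manifest above has already been loaded into a dict of that shape. finalize_manifest and write_manifest are hypothetical helpers, and the numbers in any test run are placeholders. dominant_stall_reason still comes from reading the ncu report.

```python
"""Filling the Block E manifest's derived fields (a sketch)."""
import json
from pathlib import Path


def finalize_manifest(m: dict, dod_ratio: float = 3.0) -> dict:
    """Derive results_summary fields from the per-kernel median timings."""
    k = m["kernels"]
    fused_us = k["fused"]["median_us_at_B512"]
    ref_us = k["F_softmax_ref"]["median_us_at_B512"]
    ratio = fused_us / ref_us
    m["results_summary"]["fused_vs_F_softmax_ratio"] = round(ratio, 3)
    m["results_summary"]["dod_30pct_met"] = ratio <= dod_ratio  # stop condition 3
    return m


def write_manifest(m: dict, path: Path) -> None:
    """Write pretty-printed JSON, creating the experiment directory if needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(m, indent=2) + "\n")
```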

Constraints

  • Correctness first. Don't try Block B before Block A passes correctness.
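  • One change at a time. Going naive → SMEM → parallel reduction → online softmax is a sequence of separate moves. Block B bundles the last two, so bench the parallel-reduction-only kernel before adding online softmax. Knowing what each move bought is the lesson.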
  • fp32 throughout. fp16 / bf16 is optional Block F (below); not required for DoD.
  • Pin the seed so that inputs, and therefore measurements, are reproducible.

Optional Block F — fp16 path

  • Repeat the fused kernel with fp16 inputs, fp32 accumulator. Tolerance vs F.softmax(fp16): 1e-2.
  • Bench: expect a 1.5–2× speedup. The kernel is memory-bound, so halving the bytes roughly halves the time.
  • If the dispatcher routes by dtype, add fp16 branch. If not, keep fp16 in a separate softmax_fused_fp16.cu.
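You can sanity-check the 1e-2 tolerance before touching the GPU by emulating the precision recipe on the CPU. struct's 'e' and 'f' format codes round a Python float to IEEE half and single precision. This is an approximation: exp and the final division still run in float64 before rounding. The function names are hypothetical.

```python
"""CPU emulation of 'fp16 inputs, fp32 accumulator' softmax (approximate)."""
import math
import struct


def to_f16(x: float) -> float:
    """Round to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def softmax_fp16_in_fp32_acc(row):
    xs = [to_f16(v) for v in row]  # inputs stored as fp16
    m = max(xs)
    s = 0.0
    for v in xs:
        s = to_f32(s + to_f32(math.exp(v - m)))  # running sum kept in fp32
    return [to_f16(to_f32(math.exp(v - m)) / s) for v in xs]  # output stored as fp16


def bytes_moved(B: int, V: int, elem_bytes: int) -> int:
    """One read + one write of the B x V matrix."""
    return 2 * B * V * elem_bytes
```

bytes_moved(512, 600, 2) is exactly half of bytes_moved(512, 600, 4). That halving is where the expected speedup of a memory-bound kernel comes from.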

Stop conditions

Done when:

  1. SMEM version passes correctness; bench data point recorded.
  2. Fused (online + parallel) version passes correctness; bench data point recorded.
  3. Fused vs F.softmax at \(B=512, V=600\): ratio ≤ 3.0 (i.e., ≥33% of F.softmax perf — meets DoD).
  4. ncu_report.txt committed with annotation identifying dominant stall reason.
  5. manifest.json committed.
  6. README.md documents what each optimization move bought (with numbers).

Pitfalls

  • Online-softmax numerical drift. The recurrence updates \(m\) and \(s\) in lockstep, and the order of operations matters. Standard implementation: see Milakov & Gimelshein 2018, "Online normalizer calculation for softmax". Match the paper's order exactly to match reference behavior.
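  • SMEM bank conflicts. SMEM has 32 banks of 4-byte words, so row[tid] maps thread tid to bank tid % 32. The 32 threads of a warp therefore hit 32 distinct banks, which is conflict-free for any \(V\); the idle tail threads at \(V = 600\) play no role. Conflicts come from strided patterns, such as an interleaved-addressing tree reduction (index 2 * stride * tid). Sequential addressing (tid + stride) avoids them. Padding matters for 2D tiles, not here. Not relevant at grammar scale.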
  • F.softmax looking faster because it fuses with the GEMM upstream. PyTorch's Inductor sometimes fuses lm_head + softmax into one kernel. Comparing against an Inductor-fused reference is comparing apples to oranges. Time torch.nn.functional.softmax directly in eager mode, with no torch.compile wrapper (including mode='reduce-overhead').
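  • Register pressure at large block sizes. A 1024-thread block on an SM with 64K 32-bit registers can use at most 64 registers per thread. Above that, the launch fails ("too many resources requested for launch") unless __launch_bounds__ makes the compiler spill to local memory. --ptxas-options=-v reports registers per thread and spill bytes. If usage exceeds 64 regs/thread or occupancy is low, lower the block size.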
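A small model helps check an access pattern against the bank rule. It assumes 32 banks of 4-byte words, which is the layout on current NVIDIA GPUs for 32-bit accesses. Lanes reading the same word count as a broadcast, not a conflict. conflict_degree is a hypothetical helper, and each pattern is written as a function from lane id to float index in the SMEM array.

```python
"""Worst-case shared-memory bank conflict degree for one warp (simplified model)."""
from collections import defaultdict
from typing import Callable

NUM_BANKS = 32  # assumed: 32 banks, 4-byte words


def conflict_degree(index_of: Callable[[int], int], warp_size: int = 32) -> int:
    """Max number of distinct words any single bank must serve (1 = conflict-free)."""
    words_per_bank = defaultdict(set)
    for lane in range(warp_size):
        word = index_of(lane)  # float index accessed by this lane
        # a set, because several lanes reading the same word is a broadcast
        words_per_bank[word % NUM_BANKS].add(word)
    return max(len(words) for words in words_per_bank.values())
```

row[tid] and sequential addressing both give degree 1. Interleaved addressing at stride 1 (index 2 * tid) gives 2, and index 32 * tid serializes the whole warp (32). When reading a column of a 2D tile, padding each 32-float row to 33 floats is the usual fix for that last case.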

When to consult solutions/

After all stop conditions are met. The reference walks through the exact sequence of moves (and the numbers each hits on an A10).


Next lab: lab/03-triton-and-pytorch.md.