Lab 03: Triton Port + PyTorch MiniGPT (the Framework Lands)
Goal: rewrite the fused softmax in Triton (~30 lines), autotune it, and place its dot on the roofline alongside the CUDA C versions. Then the framework lands: port the Phase-17 MiniGPT to PyTorch (`torch_minigpt.py`), verify byte-equivalence at fp32, and slot the Triton kernel into the LM-head softmax.
Estimated time: 4–8 hours (split: 2–3 h Triton, 2–4 h PyTorch port + integration).
Prereq: `lab/02-tuned-kernel.md` complete. Fused CUDA C kernel hitting ≥30% of `F.softmax`. Triton installed (`uv pip install triton`).
What you produce
Two artifacts plus updates to src/minimodel/:
- `src/minikernel/softmax_triton.py`: Triton kernel + autotune block.
- `src/minimodel/torch_minigpt.py`: PyTorch port of the Phase-17 MiniGPT (NumPy → `torch.nn.Module`).
- `tests/test_torch_minigpt.py`: byte-equivalence to the NumPy version at fp32.
- `experiments/24-triton-and-pytorch/bench.py`: adds the Triton dot to the roofline plot.
- `experiments/24-triton-and-pytorch/roofline.png`: four dots (naive, smem, fused, triton) plus the `F.softmax` reference line.
- `experiments/24-triton-and-pytorch/manifest.json`
- `experiments/24-triton-and-pytorch/README.md`: interpretation of where Triton lands relative to CUDA, how the port worked, and the kernel-swap experience.
TODOs
Block A: Triton softmax
- Per `theory/03`: write `softmax_kernel` with `@triton.jit` and `@triton.autotune` over `BLOCK ∈ {256, 512, 1024, 2048}` and `num_warps ∈ {2, 4, 8}`.
- Wrap with a `softmax(x)` Python function that handles `(B, V)` inputs by launching `(B,)` programs.
- Test against the NumPy reference and against the CUDA C tuned version (both should agree to 1e-4).
- Run autotune once with a few representative shapes (`(64, 600)`, `(512, 600)`, `(4096, 600)`); cache the chosen config.
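The correctness check above needs a trusted reference. The following is a minimal sketch in dependency-free Python, assuming the target is a row-wise, max-subtracted softmax and that the kernel loads each row in a single tile (which requires `BLOCK >= V`). The helper names `softmax_row`, `softmax_row_padded`, and `max_abs_diff` are hypothetical; in the lab itself NumPy plays the reference role. The padded variant mimics a masked load that fills out-of-range lanes with `-inf`, which contributes `exp(-inf) = 0` to the row sum.

```python
import math
import random

NEG_INF = float("-inf")

def softmax_row(row):
    """Numerically stable softmax of one row: subtract the max before exp."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_row_padded(row, block):
    """Emulate one single-tile program: pad the row to `block` lanes with -inf."""
    if block < len(row):
        raise ValueError("single-tile sketch needs block >= row length")
    lanes = list(row) + [NEG_INF] * (block - len(row))
    # Padded lanes become exp(-inf) = 0.0, so they do not change the sum.
    return softmax_row(lanes)[: len(row)]

def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two 2-D lists."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))

rng = random.Random(42)
x = [[rng.gauss(0.0, 3.0) for _ in range(600)] for _ in range(4)]  # (B=4, V=600)
ref = [softmax_row(r) for r in x]
emulated = [softmax_row_padded(r, 1024) for r in x]
diff = max_abs_diff(ref, emulated)
assert diff <= 1e-4  # same tolerance as the lab's agreement check
```

The same `max_abs_diff` comparison applies when the second operand is the Triton or CUDA C output copied back to the host.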
Block B: bench Triton, plot four-dot roofline
- Add Triton to the `bench.py` from lab 02. Time 100 launches (after 3 warm-ups + the autotune sweep).
- Compute the fraction of HBM peak.
- Generate `roofline.png`: x = intensity (FLOPs/byte), y = TFLOPS, with HBM slope + compute ceilings (per `phase-23/theory/04`). Place naive, smem, fused, triton, `F.softmax` dots.
- Expected: triton 80–95% of fused CUDA C, both below `F.softmax` (which fuses upstream ops it can; we're not doing that fusion).
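The fraction-of-peak and roofline coordinates reduce to a few lines of arithmetic. The sketch below assumes the minimum traffic model for a fused softmax (each fp32 element read once and written once) and a placeholder count of 5 FLOPs per element. Both are assumptions; use the counts from `phase-23/theory/04` and the real HBM peak of your GPU. All function names are hypothetical.

```python
def softmax_traffic_bytes(B, V, bytes_per_elem=4):
    """Minimum HBM traffic for a fused softmax: read x once, write y once."""
    return 2 * B * V * bytes_per_elem

def achieved_bandwidth_gbs(B, V, elapsed_us):
    """Achieved bandwidth in GB/s from a per-launch time in microseconds."""
    return softmax_traffic_bytes(B, V) / (elapsed_us * 1e-6) / 1e9

def frac_of_hbm_peak(B, V, elapsed_us, peak_gbs):
    """Fraction of the HBM peak bandwidth (peak_gbs in GB/s)."""
    return achieved_bandwidth_gbs(B, V, elapsed_us) / peak_gbs

def roofline_point(B, V, elapsed_us, flops_per_elem=5):
    """(intensity in FLOPs/byte, achieved TFLOPS) for one dot on the plot.

    flops_per_elem is a placeholder assumption, not a measured count.
    """
    flops = flops_per_elem * B * V
    intensity = flops / softmax_traffic_bytes(B, V)
    tflops = flops / (elapsed_us * 1e-6) / 1e12
    return intensity, tflops

# Illustrative numbers only: 10 us per launch, 1000 GB/s peak.
frac = frac_of_hbm_peak(512, 600, elapsed_us=10.0, peak_gbs=1000.0)
point = roofline_point(512, 600, elapsed_us=10.0)
```

Kernels that make several passes over the row (for example the naive version) move more bytes than this model assumes, so their fraction computed this way understates the bandwidth they actually consumed.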
Block C: PyTorch port of MiniGPT
This is the first PyTorch code in the codebase.
- `src/minimodel/torch_minigpt.py`: define `GrammarMiniGPT(nn.Module)` with the same layer counts as the §A13 grammar MiniGPT: `L = 4` blocks, `H = 4` heads, `d = 64`, `d_h = 16`, `V ≈ 600`. Submodules: `nn.Embedding`, attention blocks (use `nn.MultiheadAttention` or a from-scratch `nn.Linear` + softmax for transparency), `nn.LayerNorm`, FFN (`nn.Linear` × 2 + GeLU), `nn.Linear(d, V)` LM head.
- Load weights from the Phase-17 NumPy MiniGPT. Map each `np.ndarray` to a `torch.Tensor` (same shape, fp32). After loading, verify that each layer's weight has the expected shape (check `.shape`; `data_ptr()` only returns the storage address).
- Test (`tests/test_torch_minigpt.py`): generate a random input `x` of shape `(2, 16)` (token ids), seed 42. Run the NumPy model and the PyTorch model (in `eval()` mode) on it. Assert `np.allclose(y_np, y_pt.detach().numpy(), atol=1e-5, rtol=1e-5)`; the `detach()` is needed because the output of a module with trainable parameters carries gradient tracking. This tolerance is what the lab calls byte-equivalence at fp32 on CPU.
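The weight-loading step is easier to audit with an explicit name map and a shape check. The sketch below works on plain dictionaries of shapes so it runs without torch. Every parameter name in it is hypothetical, and the FFN hidden width of `4 * d = 256` is an assumption; substitute the real Phase-17 names and dimensions.

```python
# Hypothetical parameter names on both sides; replace with the real ones.
NAME_MAP = {
    "tok_emb": "embed.weight",
    "block0.attn.Wq": "blocks.0.attn.q_proj.weight",
    "block0.ffn.W1": "blocks.0.ffn.fc1.weight",
    "lm_head.W": "lm_head.weight",
}

def check_port(np_shapes, torch_shapes, name_map):
    """Return a list of problems; an empty list means every shape lines up."""
    problems = []
    for np_name, pt_name in name_map.items():
        if np_name not in np_shapes:
            problems.append(f"missing NumPy param: {np_name}")
        elif pt_name not in torch_shapes:
            problems.append(f"missing PyTorch param: {pt_name}")
        elif tuple(np_shapes[np_name]) != tuple(torch_shapes[pt_name]):
            problems.append(
                f"shape mismatch {np_name} {tuple(np_shapes[np_name])} "
                f"-> {pt_name} {tuple(torch_shapes[pt_name])}"
            )
    # Any PyTorch parameter without a source would keep its random init.
    unmapped = set(torch_shapes) - set(name_map.values())
    problems.extend(f"PyTorch param never loaded: {n}" for n in sorted(unmapped))
    return problems

# Shapes for d = 64, V = 600; the 256-wide FFN is an assumption.
np_shapes = {
    "tok_emb": (600, 64),
    "block0.attn.Wq": (64, 64),
    "block0.ffn.W1": (256, 64),
    "lm_head.W": (600, 64),
}
torch_shapes = {pt: np_shapes[np_name] for np_name, pt in NAME_MAP.items()}
problems = check_port(np_shapes, torch_shapes, NAME_MAP)
```

In the real port the two shape dictionaries would come from the NumPy parameter dict and from `model.state_dict()`. A shape check cannot detect a transposed square matrix, so the numeric comparison in the test remains the actual contract.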
Block D: slot the Triton softmax into the LM head
- In `torch_minigpt.py`, the final layer logically does `F.softmax(lm_head(x), dim=-1)`. Replace the `F.softmax` with `triton_softmax(...)` when running on CUDA (gate behind `if x.is_cuda`).
- On CPU (Borja's laptop): falls back to `F.softmax`. Tests still pass.
- On cloud GPU: uses the Triton kernel. Generated logits agree with the CPU path to 1e-3.
- Verify that a forward pass through the grammar MiniGPT with the custom kernel produces the expected token distribution. E.g., feed `"Yesterday I"`-tokenized input; the top-5 logits should still be past-simple verb forms (per the Phase-17 model's training). The kernel swap shouldn't change the prediction.
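The agreement and top-k checks above can be expressed as one small report. This is a sketch on plain Python lists, assuming the last-position distribution has already been copied to the host from both paths; `top_k` and `kernel_swap_report` are hypothetical helpers and the probabilities are illustrative, not model output. Because softmax is monotonic, ranking the probabilities gives the same order as ranking the logits. Since the Block A wrapper takes `(B, V)` inputs, a `(B, T, V)` LM-head output would presumably be flattened to `(B*T, V)` before the kernel call.

```python
def top_k(values, k):
    """Indices of the k largest entries, best first (ties go to the lower index)."""
    return sorted(range(len(values)), key=lambda i: (-values[i], i))[:k]

def kernel_swap_report(probs_cpu, probs_gpu, k=5, atol=1e-3):
    """Compare the CPU-path and Triton-path distributions for one position."""
    max_diff = max(abs(a - b) for a, b in zip(probs_cpu, probs_gpu))
    return {
        "within_tolerance": max_diff <= atol,
        "max_abs_diff": max_diff,
        "top1_changed": top_k(probs_cpu, 1) != top_k(probs_gpu, 1),
        "topk_same_set": set(top_k(probs_cpu, k)) == set(top_k(probs_gpu, k)),
    }

# Illustrative distributions over a 5-token vocabulary.
probs_cpu = [0.05, 0.60, 0.20, 0.10, 0.05]
probs_gpu = [0.0501, 0.5998, 0.2001, 0.10, 0.05]
report = kernel_swap_report(probs_cpu, probs_gpu, k=3)
```

The `top1_changed` value maps directly onto the manifest field `kernel_swap_changes_top1_token`.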
Block E: manifest
{
  "experiment": "24-triton-and-pytorch",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "gpu": {"model": null, "compute_capability": null},
  "versions": {"python": "3.11.x", "torch": null, "triton": null, "cupy": null},
  "softmax_kernels": {
    "naive": {"us_at_B512": null, "frac_of_F_softmax": null},
    "smem": {"us_at_B512": null, "frac_of_F_softmax": null},
    "fused": {"us_at_B512": null, "frac_of_F_softmax": null},
    "triton": {"us_at_B512": null, "frac_of_F_softmax": null, "autotune_picked": {"BLOCK": null, "num_warps": null}}
  },
  "torch_port": {
    "byte_equivalence_at_fp32_cpu": "passed | failed",
    "max_abs_diff_to_numpy_reference": null,
    "kernel_swap_changes_top1_token": null
  }
}
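One way to fill the `softmax_kernels` section from benchmark timings is sketched below. It assumes `frac_of_F_softmax` means `F.softmax` time divided by kernel time (so a kernel twice as slow scores 0.5), which matches the "≥30% of `F.softmax`" phrasing in the prerequisites but is still an interpretation. The helper `kernel_entry` is hypothetical, and every timing and autotune value is a placeholder, not a measurement.

```python
import json

def kernel_entry(us, ref_us, autotune_picked=None):
    """One `softmax_kernels` record.

    Assumes frac_of_F_softmax = F.softmax time / kernel time.
    """
    entry = {"us_at_B512": us, "frac_of_F_softmax": round(ref_us / us, 3)}
    if autotune_picked is not None:
        entry["autotune_picked"] = autotune_picked
    return entry

# Placeholder timings in microseconds, NOT measurements.
timings_us = {"naive": 40.0, "smem": 20.0, "fused": 10.0, "triton": 12.5}
f_softmax_us = 5.0

kernels = {name: kernel_entry(us, f_softmax_us) for name, us in timings_us.items()}
# Placeholder autotune choice; record whatever the sweep actually picked.
kernels["triton"] = kernel_entry(
    timings_us["triton"], f_softmax_us, {"BLOCK": 1024, "num_warps": 4}
)

manifest_fragment = json.dumps({"softmax_kernels": kernels}, indent=2)
```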
Constraints
- PyTorch only here, only this lab. Don't retroactively port Phase 1–22 code.
- Faithful port, not a redesign. Layer-for-layer; same numerics. Phase 25 may redesign.
- fp32 byte-equivalence on CPU is the contract. fp32 on CUDA might drift to 1e-5; fp16 even more. Document tolerances.
- The custom kernel must not change downstream predictions (top-1 token unchanged for the same input). If it does, there's a bug in the kernel or the swap.
Stop conditions
Done when:
- Triton kernel passes correctness; dot on roofline.
- PyTorch MiniGPT byte-equivalent to NumPy MiniGPT at fp32 CPU (`atol=1e-5`).
- Custom Triton kernel slotted into the PyTorch model; forward pass on CUDA produces top-1 = top-1 of the CPU path for the §A13 demo prompt `"Yesterday I"`.
- `roofline.png` committed.
- `manifest.json` committed.
- `learners/borja/profile.md` updated: "PyTorch internalized at Phase 24" (required DoD item).
Pitfalls
- Weight ordering mismatch. The Phase-17 NumPy weights use the `(out, in)` convention, the same layout as PyTorch's `nn.Linear.weight`, so they should copy over directly. But if you accidentally wrote `(in, out)` somewhere in Phase 17, the port appears to work (a square `d × d` matrix loads without any shape error) but produces garbage. Diagnose by checking each layer's output independently.
- `F.softmax` numerical mismatch. PyTorch's softmax may use a different reduction order than NumPy's. fp32 CPU should still agree to 1e-7 (both are deterministic single-threaded), but multi-threaded PyTorch BLAS calls can drift. Test with `torch.set_num_threads(1)`.
- Triton autotune cache poisoning. Old results in `~/.triton/cache/` persist across runs. If you change the kernel signature, the cache may serve stale code. Call `triton.runtime.cache.clear()` if your Triton version provides it, or delete the directory.
- `model.eval()` forgotten. Models with Dropout or BatchNorm layers behave differently in train vs eval. The grammar MiniGPT doesn't use either, but always set `eval()` for inference comparisons regardless.
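The weight-ordering pitfall is easy to reproduce on a tiny square matrix. The sketch below uses plain lists and hypothetical helpers `linear` and `transpose`: both layouts have the same shape, so nothing fails at load time, yet the outputs differ.

```python
def linear(weight, x):
    """y = W @ x for W stored as (out, in), the nn.Linear.weight layout."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def transpose(m):
    """Swap rows and columns of a 2-D list."""
    return [list(col) for col in zip(*m)]

W = [[1.0, 2.0],
     [3.0, 4.0]]  # (out=2, in=2): square, like the d x d attention projections
x = [1.0, -1.0]

y_correct = linear(W, x)           # weight used in its intended (out, in) layout
y_wrong = linear(transpose(W), x)  # same shape, but stored as (in, out)

# Same shape, different numbers: only a per-layer output comparison catches this.
same_shape = len(W) == len(transpose(W)) and len(W[0]) == len(transpose(W)[0])
```

Comparing the output of each layer against the NumPy model, one layer at a time, localizes the first place where the two paths diverge.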
When to consult `solutions/`
After all stop conditions met. The reference shows the canonical Triton kernel, the NumPy→PyTorch port mapping table, and the kernel-swap glue code.
Next: PHASE_24_REPORT.md. The phase closes with the four-dot roofline as the headline plot.