Lab 03 — Triton Port + PyTorch MiniGPT (the Framework Lands)

Goal: rewrite the fused softmax in Triton (~30 lines), autotune it, place its dot on the roofline alongside the CUDA C versions. Then — the framework lands — port Phase-17 MiniGPT to PyTorch (torch_minigpt.py), verify byte-equivalence at fp32, and slot the Triton kernel into the LM-head softmax.

Estimated time: 4–8 hours (split: 2–3 h Triton, 2–4 h PyTorch port + integration).

Prereq: lab/02-tuned-kernel.md complete, with the fused CUDA C kernel hitting ≥30% of F.softmax. Triton installed (uv pip install triton).


What you produce

A Triton kernel, a PyTorch port under src/minimodel/, a test, and the experiment outputs:

  • src/minikernel/softmax_triton.py — Triton kernel + autotune block.
  • src/minimodel/torch_minigpt.py — PyTorch port of the Phase-17 MiniGPT (NumPy → torch.nn.Module).
  • tests/test_torch_minigpt.py — byte-equivalence to the NumPy version at fp32.
  • experiments/24-triton-and-pytorch/bench.py — adds Triton dot to the roofline plot.
  • experiments/24-triton-and-pytorch/roofline.png — four dots: naive, smem, fused, triton + F.softmax reference line.
  • experiments/24-triton-and-pytorch/manifest.json.
  • experiments/24-triton-and-pytorch/README.md — interpretation: where Triton lands relative to CUDA, how the port worked, the kernel-swap experience.

TODOs

Block A — Triton softmax

  • Per theory/03: write softmax_kernel with @triton.jit and @triton.autotune over BLOCK ∈ {256, 512, 1024, 2048} and num_warps ∈ {2, 4, 8}.
  • Wrap with a softmax(x) Python function that handles (B, V) inputs by launching (B,) programs.
  • Test against NumPy reference and against the CUDA C tuned version (both should agree to 1e-4).
  • Run autotune once with a few representative shapes ((64, 600), (512, 600), (4096, 600)); cache the chosen config.
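
The correctness test in this block needs a trusted baseline to compare against. The following is a minimal sketch of one, assuming NumPy serves as the reference; `softmax_reference` and `check_against_reference` are hypothetical names, and the stand-in `candidate` would be replaced by the Triton (or CUDA C) output copied back to host memory.

```python
import numpy as np

def softmax_reference(x: np.ndarray) -> np.ndarray:
    """Numerically stable row-wise softmax over the last axis of a (B, V) array."""
    x = x.astype(np.float32, copy=False)
    shifted = x - x.max(axis=-1, keepdims=True)  # subtract the row max so exp() cannot overflow
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def check_against_reference(candidate: np.ndarray, x: np.ndarray, tol: float = 1e-4) -> float:
    """Return the max abs difference; raise if it exceeds the lab's 1e-4 tolerance."""
    diff = float(np.abs(candidate - softmax_reference(x)).max())
    if diff > tol:
        raise AssertionError(f"max abs diff {diff:.3e} exceeds {tol:.0e} for shape {x.shape}")
    return diff

rng = np.random.default_rng(42)
for shape in [(64, 600), (512, 600), (4096, 600)]:
    x = rng.standard_normal(shape).astype(np.float32)
    # Stand-in candidate (assumption): swap in the kernel output moved to host memory.
    candidate = softmax_reference(x)
    check_against_reference(candidate, x)
```

Running the same check over all three representative shapes catches masking bugs that only show up when the row count changes.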

Block B — bench Triton, plot four-dot roofline

  • Add Triton to the bench.py from lab 02. Time 100 launches (after 3 warm-ups + the autotune sweep).
  • Compute fraction of HBM peak.
  • Generate roofline.png: x = intensity (FLOPs/byte), y = TFLOPS, with HBM slope + compute ceilings (per phase-23/theory/04). Place naive, smem, fused, triton, F.softmax dots.
  • Expected: triton 80–95% of fused CUDA C, both below F.softmax (which fuses upstream ops it can; we're not doing that fusion).
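
The roofline coordinates for one dot follow from the timing alone. The sketch below shows the arithmetic under stated assumptions: each element is read once and written once in fp32, and softmax costs roughly 5 FLOPs per element (compare, subtract, exp, add, divide). Both figures are assumptions, as is the function name `softmax_roofline_point`; use the counts from phase-23/theory/04 if they differ. The 10 µs timing and 1000 GB/s peak in the usage line are placeholders, not measurements.

```python
def softmax_roofline_point(batch, vocab, elapsed_us, peak_hbm_gbs,
                           flops_per_elem=5, bytes_per_elem=4):
    """Return intensity, achieved throughput, and fraction of HBM peak for one timing."""
    if elapsed_us <= 0 or peak_hbm_gbs <= 0:
        raise ValueError("elapsed_us and peak_hbm_gbs must be positive")
    elems = batch * vocab
    bytes_moved = 2 * elems * bytes_per_elem  # assumption: one read plus one write per element
    flops = flops_per_elem * elems            # assumption: ~5 FLOPs per element
    seconds = elapsed_us * 1e-6
    achieved_gbs = bytes_moved / seconds / 1e9
    return {
        "intensity_flops_per_byte": flops / bytes_moved,  # x coordinate
        "achieved_tflops": flops / seconds / 1e12,        # y coordinate
        "achieved_gbs": achieved_gbs,
        "frac_of_hbm_peak": achieved_gbs / peak_hbm_gbs,
    }

# Placeholder inputs for illustration only.
point = softmax_roofline_point(batch=512, vocab=600, elapsed_us=10.0, peak_hbm_gbs=1000.0)
```

Under this traffic model every softmax variant has the same intensity, so the dots line up vertically and differ only in achieved throughput.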

Block C — PyTorch port of MiniGPT

This is the first PyTorch code in the codebase.

  • src/minimodel/torch_minigpt.py: define GrammarMiniGPT(nn.Module) with the same layer counts as the §A13 grammar MiniGPT — L = 4 blocks, H = 4 heads, d = 64, d_h = 16, V ≈ 600. Submodules: nn.Embedding, attention blocks (use nn.MultiheadAttention or a from-scratch nn.Linear+softmax for transparency), nn.LayerNorm, FFN (nn.Linear × 2 + GeLU), nn.Linear(d, V) LM head.
  • Load weights from the Phase-17 NumPy MiniGPT. Map each np.ndarray to a torch.Tensor (same shape, fp32). Verify each layer's weight has the expected .shape and dtype after load (data_ptr() only returns the storage address, so it cannot confirm shape).
  • Test (tests/test_torch_minigpt.py): generate a random input x of shape (2, 16) (token ids), seed 42. Run both models on it, with the PyTorch model in eval() mode and under torch.no_grad() so that y_pt.numpy() works. Assert np.allclose(y_np, y_pt.numpy(), atol=1e-5, rtol=1e-5). This tolerance is what the lab calls byte-equivalence at fp32 on CPU.
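
The equivalence assertion can also produce the two numbers the manifest asks for. This is a rough sketch, assuming both model outputs are already NumPy arrays; `equivalence_report` is a hypothetical helper, and the synthetic arrays below stand in for the real outputs.

```python
import numpy as np

def equivalence_report(y_np, y_pt, atol=1e-5, rtol=1e-5):
    """Compare NumPy and PyTorch outputs; return the manifest's torch_port fields."""
    y_np = np.asarray(y_np, dtype=np.float32)
    y_pt = np.asarray(y_pt, dtype=np.float32)
    if y_np.shape != y_pt.shape:
        raise ValueError(f"shape mismatch: {y_np.shape} vs {y_pt.shape}")
    passed = bool(np.allclose(y_np, y_pt, atol=atol, rtol=rtol))
    return {
        "byte_equivalence_at_fp32_cpu": "passed" if passed else "failed",
        "max_abs_diff_to_numpy_reference": float(np.abs(y_np - y_pt).max()),
    }

# Synthetic stand-ins (assumption): real inputs would be the two models' outputs
# for the seeded (2, 16) token batch, with V = 600 as the last dimension.
y_np = np.zeros((2, 16, 600), dtype=np.float32)
y_pt = y_np + np.float32(5e-6)
report = equivalence_report(y_np, y_pt)
```

Recording the max absolute difference alongside the pass/fail flag makes it easier to see how close a failing port is to the tolerance.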

Block D — slot the Triton softmax into the LM head

  • In torch_minigpt.py, the final layer logically does F.softmax(lm_head(x), dim=-1). Replace the F.softmax with triton_softmax(...) when running on CUDA (gate behind if x.is_cuda).
  • On CPU (Borja's laptop): falls back to F.softmax. Tests still pass.
  • On cloud GPU: uses Triton kernel. Output probabilities agree with the CPU path to 1e-3.
  • Verify a forward pass through grammar MiniGPT with the custom kernel produces the expected token distribution. E.g., feed "Yesterday I"-tokenized input; the top-5 logits should still be past-simple verb forms (per the Phase-17 model's training). The kernel swap shouldn't change the prediction.
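
The "prediction unchanged" check reduces to comparing top-k indices between the two paths. A minimal sketch, assuming each path's distribution for the last position has been copied into a 1-D NumPy array; `topk_ids` and `kernel_swap_report` are hypothetical names, and the vectors below are made-up values, not model output.

```python
import numpy as np

def topk_ids(scores, k=5):
    """Indices of the k largest entries, highest first."""
    scores = np.asarray(scores)
    return np.argsort(-scores, kind="stable")[:k]

def kernel_swap_report(probs_cpu, probs_gpu, k=5, atol=1e-3):
    """Check that the kernel swap leaves the prediction unchanged."""
    cpu_top, gpu_top = topk_ids(probs_cpu, k), topk_ids(probs_gpu, k)
    return {
        "kernel_swap_changes_top1_token": bool(cpu_top[0] != gpu_top[0]),
        "topk_sets_match": set(cpu_top.tolist()) == set(gpu_top.tolist()),
        "within_tolerance": bool(np.allclose(probs_cpu, probs_gpu, atol=atol, rtol=0.0)),
    }

# Made-up distributions for illustration; the GPU vector carries a small perturbation.
probs_cpu = np.array([0.04, 0.40, 0.25, 0.15, 0.10, 0.06], dtype=np.float32)
probs_gpu = probs_cpu + np.array([1e-4, -1e-4, 0, 0, 0, 0], dtype=np.float32)
report = kernel_swap_report(probs_cpu, probs_gpu)
```

Comparing top-k as sets rather than ordered lists tolerates near-ties reordering within the 1e-3 budget while still flagging a changed top-1 token.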

Block E — manifest

{
  "experiment": "24-triton-and-pytorch",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "gpu": {"model": null, "compute_capability": null},
  "versions": {"python": "3.11.x", "torch": null, "triton": null, "cupy": null},
  "softmax_kernels": {
    "naive":  {"us_at_B512": null, "frac_of_F_softmax": null},
    "smem":   {"us_at_B512": null, "frac_of_F_softmax": null},
    "fused":  {"us_at_B512": null, "frac_of_F_softmax": null},
    "triton": {"us_at_B512": null, "frac_of_F_softmax": null, "autotune_picked": {"BLOCK": null, "num_warps": null}}
  },
  "torch_port": {
    "byte_equivalence_at_fp32_cpu": "passed | failed",
    "max_abs_diff_to_numpy_reference": null,
    "kernel_swap_changes_top1_token": null
  }
}
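
The version and date fields can be filled in programmatically rather than by hand. This is a sketch only, assuming the packages are installed under the distribution names "torch", "triton", and "cupy" (CuPy wheels are often published under CUDA-specific names, in which case the lookup returns None); `installed_version` and `build_manifest` are hypothetical helpers, and the kernel timings still have to be added from bench.py.

```python
import json
import platform
from datetime import date
from importlib import metadata

def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None

def build_manifest(seed=42):
    """Fill the fields that can be discovered without running the benchmark."""
    return {
        "experiment": "24-triton-and-pytorch",
        "date": date.today().isoformat(),
        "seed": seed,
        "versions": {
            "python": platform.python_version(),
            "torch": installed_version("torch"),
            "triton": installed_version("triton"),
            "cupy": installed_version("cupy"),  # assumption: distribution name is "cupy"
        },
    }

manifest = build_manifest()
manifest_json = json.dumps(manifest, indent=2)
```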

Constraints

  • PyTorch only here, only this lab. Don't retroactively port Phase 1–22 code.
  • Faithful port, not a redesign. Layer-for-layer; same numerics. Phase 25 may redesign.
  • fp32 byte-equivalence on CPU is the contract. fp32 on CUDA might drift to 1e-5; fp16 even more. Document tolerances.
  • The custom kernel must not change downstream predictions (top-1 token unchanged for the same input). If it does, there's a bug in the kernel or the swap.

Stop conditions

Done when:

  1. Triton kernel passes correctness; dot on roofline.
  2. PyTorch MiniGPT byte-equivalent to NumPy MiniGPT at fp32 CPU (atol=1e-5).
  3. Custom Triton kernel slotted into PyTorch model; forward pass on CUDA produces top-1 = top-1 of CPU path for the §A13 demo prompt "Yesterday I".
  4. roofline.png committed.
  5. manifest.json committed.
  6. learners/borja/profile.md updated: "PyTorch internalized at Phase 24" — required DoD item.

Pitfalls

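  • Weight ordering mismatch. PyTorch's nn.Linear.weight is stored as (out, in). If the Phase-17 NumPy weights follow the same (out, in) convention, they can be copied over directly. If you accidentally wrote (in, out) somewhere in Phase 17, a non-square weight will typically fail a shape check on load, but a square one (d × d) loads cleanly, so the port appears to work but produces garbage. Diagnose by checking each layer's output independently.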
  • F.softmax numerical mismatch. PyTorch's softmax may use a different reduction order than NumPy's. fp32 CPU should still agree to 1e-7 (both are deterministic single-threaded), but multi-threaded PyTorch BLAS calls can drift. Test with torch.set_num_threads(1).
  • Triton autotune cache poisoning. Old autotune results in ~/.triton/cache/ persist across runs. If you change the kernel signature, the cache may serve stale code. Delete the directory, or point the TRITON_CACHE_DIR environment variable at a fresh one.
  • model.eval() forgotten. Dropout/BatchNorm-layered models behave differently in train vs eval. The grammar MiniGPT doesn't use either, but always set eval() for inference comparisons regardless.
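
The weight-ordering pitfall is easiest to see on a square layer, where a transposed weight raises no shape error. A small NumPy sketch of the per-layer check, assuming the nn.Linear convention y = x @ W.T + b; `linear_out_in` is a hypothetical helper and the weights are random, not the Phase-17 ones.

```python
import numpy as np

def linear_out_in(x, w, b):
    """nn.Linear convention: w has shape (out_features, in_features), y = x @ w.T + b."""
    return x @ w.T + b

rng = np.random.default_rng(42)
d = 64
x = rng.standard_normal((2, 16, d)).astype(np.float32)
w_out_in = rng.standard_normal((d, d)).astype(np.float32)
b = np.zeros(d, dtype=np.float32)

y_expected = linear_out_in(x, w_out_in, b)

# The same square layer stored (in, out) by mistake: identical shape, so loading succeeds.
w_in_out = w_out_in.T.copy()
y_wrong = linear_out_in(x, w_in_out, b)

layout_bug_is_silent = w_in_out.shape == w_out_in.shape
layer_diff = float(np.abs(y_expected - y_wrong).max())

# Transposing back restores the expected output.
y_fixed = linear_out_in(x, w_in_out.T, b)
```

Comparing one layer at a time like this localizes the mismatch, whereas comparing only the final logits shows that something is wrong but not where.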

When to consult solutions/

After all stop conditions met. The reference shows the canonical Triton kernel, the NumPy→PyTorch port mapping table, and the kernel-swap glue code.


Next: PHASE_24_REPORT.md. The phase closes with the four-dot roofline as the headline plot.