Skip to content

English · Español

Lab 03 — Port a Triton + MiniGPT en PyTorch (aterriza el framework)

Objetivo: reescribir la softmax fusionada en Triton (~30 líneas), autotunearla, situar su punto en el roofline junto a las versiones en CUDA C. Después — aterriza el framework — portar el MiniGPT de la Fase 17 a PyTorch (torch_minigpt.py), verificar equivalencia byte a byte en fp32 y enchufar el kernel Triton en la softmax de la cabeza LM.

Tiempo estimado: 4–8 horas (división: 2–3 h Triton, 2–4 h port a PyTorch + integración).

Prerrequisito: lab/02-tuned-kernel.md completo. Kernel fused en CUDA C alcanzando ≥30% de F.softmax. Triton instalado (uv pip install triton).


Lo que produces

Dos artefactos más actualizaciones en src/minimodel/:

  • src/minikernel/softmax_triton.py — kernel Triton + bloque de autotune.
  • src/minimodel/torch_minigpt.py — port a PyTorch del MiniGPT de la Fase 17 (NumPy → torch.nn.Module).
  • tests/test_torch_minigpt.py — equivalencia byte a byte respecto a la versión NumPy en fp32.
  • experiments/24-triton-and-pytorch/bench.py — añade el punto de Triton al gráfico del roofline.
  • experiments/24-triton-and-pytorch/roofline.png — cuatro puntos: naive, smem, fused, triton + línea de referencia F.softmax.
  • experiments/24-triton-and-pytorch/manifest.json.
  • experiments/24-triton-and-pytorch/README.md — interpretación: dónde aterriza Triton respecto a CUDA, cómo fue el port, la experiencia del intercambio de kernel.

TODOs

Bloque A — softmax en Triton

  • Según theory/03: escribe softmax_kernel con @triton.jit y @triton.autotune sobre BLOCK ∈ {256, 512, 1024, 2048} y num_warps ∈ {2, 4, 8}.
  • Envuelve con una función Python softmax(x) que maneje entradas (B, V) lanzando (B,) programas.
  • Testea contra la referencia NumPy y contra la versión tuneada en CUDA C (ambas deben concordar a 1e-4).
  • Ejecuta autotune una vez con unas pocas formas representativas ((64, 600), (512, 600), (4096, 600)); cachea la configuración elegida.

Bloque B — bench de Triton, dibujar roofline de cuatro puntos

  • Añade Triton al bench.py del lab 02. Cronometra 100 lanzamientos (tras 3 warm-ups + el sweep de autotune).
  • Calcula la fracción del pico HBM.
  • Genera roofline.png: x = intensidad (FLOPs/byte), y = TFLOPS, con pendiente HBM + techos de cómputo (según phase-23/theory/04). Sitúa los puntos naive, smem, fused, triton, F.softmax.
  • Esperado: triton 80–95% del fused de CUDA C, ambos por debajo de F.softmax (que fusiona operaciones aguas arriba que puede; nosotros no hacemos esa fusion).

Bloque C — port a PyTorch de MiniGPT

Este es el primer código PyTorch del codebase.

  • src/minimodel/torch_minigpt.py: define GrammarMiniGPT(nn.Module) con los mismos conteos de capa que el MiniGPT gramatical de §A13 — L = 4 bloques, H = 4 cabezas, d = 64, d_h = 16, V ≈ 600. Submódulos: nn.Embedding, bloques de attention (usa nn.MultiheadAttention o un nn.Linear+softmax desde cero por transparencia), nn.LayerNorm, FFN (nn.Linear × 2 + GeLU), cabeza LM nn.Linear(d, V).
  • Carga los pesos del MiniGPT NumPy de la Fase 17. Mapea cada np.ndarray a un torch.Tensor (misma forma, fp32). Verifica que el data_ptr del weight de cada capa muestra la forma correcta tras la carga.
  • Test (tests/test_torch_minigpt.py): genera una entrada aleatoria x de forma (2, 16) (ids de token), semilla 42. Ejecuta ambos modelos NumPy y PyTorch en modo eval(). Aserta np.allclose(y_np, y_pt.numpy(), atol=1e-5, rtol=1e-5). Equivalente byte a byte en fp32 CPU.

Bloque D — enchufar la softmax Triton en la cabeza LM

  • En torch_minigpt.py, la capa final hace lógicamente F.softmax(lm_head(x), dim=-1). Sustituye F.softmax por triton_softmax(...) al ejecutar en CUDA (controla con if x.is_cuda).
  • En CPU (portátil de Borja): cae de vuelta a F.softmax. Los tests siguen pasando.
  • En GPU en la nube: usa el kernel Triton. Los logits generados concuerdan con la ruta CPU a 1e-3.
  • Verifica que una pasada hacia delante por el MiniGPT gramatical con el kernel custom produce la distribución de tokens esperada. Por ejemplo, alimenta "Yesterday I" tokenizado; los top-5 logits deben seguir siendo formas de verbo en past simple (según el entrenamiento del modelo de la Fase 17). El intercambio de kernel no debería cambiar la predicción.

Bloque E — manifest

{
  "experiment": "24-triton-and-pytorch",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "gpu": {"model": null, "compute_capability": null},
  "versions": {"python": "3.11.x", "torch": null, "triton": null, "cupy": null},
  "softmax_kernels": {
    "naive":  {"us_at_B512": null, "frac_of_F_softmax": null},
    "smem":   {"us_at_B512": null, "frac_of_F_softmax": null},
    "fused":  {"us_at_B512": null, "frac_of_F_softmax": null},
    "triton": {"us_at_B512": null, "frac_of_F_softmax": null, "autotune_picked": {"BLOCK": null, "num_warps": null}}
  },
  "torch_port": {
    "byte_equivalence_at_fp32_cpu": "passed | failed",
    "max_abs_diff_to_numpy_reference": null,
    "kernel_swap_changes_top1_token": null
  }
}

Restricciones

  • Solo PyTorch aquí, solo en este lab. No portes retroactivamente el código de las Fases 1–22.
  • Port fiel, no rediseño. Capa a capa; mismas numéricas. La Fase 25 puede rediseñar.
  • Equivalencia byte a byte en fp32 CPU es el contrato. fp32 en CUDA puede derivar a 1e-5; fp16 aún más. Documenta tolerancias.
  • El kernel custom no debe cambiar las predicciones aguas abajo (top-1 token sin cambios para la misma entrada). Si las cambia, hay un bug en el kernel o en el intercambio.

Condiciones de parada

Hecho cuando:

  1. El kernel Triton pasa corrección; punto en el roofline.
  2. El MiniGPT en PyTorch es equivalente byte a byte al MiniGPT en NumPy en fp32 CPU (atol=1e-5).
  3. Kernel Triton custom enchufado en el modelo PyTorch; la pasada hacia delante en CUDA produce top-1 = top-1 de la ruta CPU para el prompt demo de §A13 "Yesterday I".
  4. roofline.png commiteado.
  5. manifest.json commiteado.
  6. learners/borja/profile.md actualizado: "PyTorch interiorizado en la Fase 24" — ítem DoD obligatorio.

Escollos

  • Desajuste de orden de pesos. La convención (out, in) de NumPy difiere del nn.Linear.weight de PyTorch, que es (out, in). Coinciden — pero si por accidente escribiste (in, out) en algún sitio en la Fase 17, el port parece funcionar pero produce basura. Diagnostica revisando la salida de cada capa de forma independiente.
  • Desajuste numérico con F.softmax. La softmax de PyTorch puede usar un orden de reducción distinto al de NumPy. fp32 CPU debería seguir concordando a 1e-7 (ambos son deterministas en monohilo), pero llamadas BLAS multihilo de PyTorch pueden derivar. Testea con torch.set_num_threads(1).
  • Envenenamiento de la cache de autotune de Triton. Resultados viejos de autotune en ~/.triton/cache/ persisten entre runs. Si cambias la firma del kernel, la cache puede servir código rancio. triton.runtime.cache.clear() o borra el directorio.
  • Olvidar model.eval(). Modelos con Dropout/BatchNorm se comportan distinto en train vs eval. El MiniGPT gramatical no usa ninguno de los dos, pero pon siempre eval() para las comparaciones de inferencia.

Cuándo consultar solutions/

Tras cumplir todas las condiciones de parada. La referencia muestra el kernel Triton canónico, la tabla de mapeo del port NumPy→PyTorch y el glue code del intercambio de kernel.


Siguiente: PHASE_24_REPORT.md. La fase cierra con el roofline de cuatro puntos como gráfico cabecera.