Skip to content

English · Español

03 — Triton: CUDA para el caso del 80%

🇪🇸 Triton es un DSL de Python — un lenguaje que parece NumPy pero compila a PTX (lenguaje GPU). Su valor: el 80% de los kernels que escribirías en CUDA C los escribes en Triton con ~10× menos código, y el autotuner busca el block-size / num-warps / num-stages óptimo por ti. El 20% restante (flash attention, GEMM ultraoptimizado) sigue siendo terreno de CUDA C / CUTLASS.

Esta página presenta Triton — qué es, qué automatiza, qué no, y cómo leer su autotune. Al final puedes escribir el mismo kernel de softmax-sobre-vocabulario-gramatical de theory/02 en ~30 líneas de Triton y predecir aproximadamente dónde aterrizará en el roofline relativo a tu versión CUDA C tuneada.


Qué es Triton

Triton es un DSL tipo Python embebido en Python, desarrollado en OpenAI, compilado por un stack basado en MLIR hasta PTX (luego a SASS por ptxas). Escribes una función con @triton.jit, la decoras con configs de autotune y la llamas como una función Python normal. El compilador infiere vectorización, asignación de registros y scheduling básico de memoria; tú aportas el algoritmo y el espacio de tamaños de tile.

Un kernel Triton mínimo:

import triton
import triton.language as tl

@triton.jit
def softmax_kernel(x_ptr, y_ptr, V, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < V
    x = tl.load(x_ptr + row * V + cols, mask=mask, other=-float('inf'))
    m = tl.max(x, axis=0)
    e = tl.exp(x - m)
    s = tl.sum(e, axis=0)
    y = e / s
    tl.store(y_ptr + row * V + cols, y, mask=mask)

Eso es un kernel de fused softmax completo para una fila por "programa" (el nombre de Triton para un bloque de hilos). Para el vocab \(V \approx 600\) del MiniGPT gramatical, pondrías BLOCK=1024 (siguiente potencia de 2 ≥ 600), enmascarando la cola.

Comparado con la versión CUDA C tuneada de ~80 líneas, Triton oculta:

  • Coalescing — tl.load se coalesce automáticamente.
  • Asignación de SMEM — no escribes __shared__; el compilador decide.
  • Reducciones — tl.max y tl.sum son eficientes a nivel de warp y bloque.
  • Sincronización — sin __syncthreads().

Qué no oculta Triton:

  • Tamaño de bloque / programa. Tú eliges BLOCK.
  • Algoritmo. Online-softmax vs naive 3-pass es decisión tuya.
  • Layout de memoria (patrones de acceso row-major vs column-major).
  • Uso de registros. El uso intensivo de tl.where y aritmética compila a registros; hay spill si fuerzas demasiado.

El autotuner

La killer feature de Triton es el autotuner:

@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 256},  num_warps=4),
        triton.Config({'BLOCK': 512},  num_warps=4),
        triton.Config({'BLOCK': 1024}, num_warps=8),
        triton.Config({'BLOCK': 2048}, num_warps=8),
    ],
    key=['V'],  # re-tune when V changes
)
@triton.jit
def softmax_kernel(...): ...

Primera llamada con una V dada: el autotuner benchmarkea cada config, escoge el más rápido, cachea la elección. Las llamadas posteriores con la misma V hacen hit en el cache.

Lo que explora el autotuner:

  • BLOCK: tamaño de tile (típicamente potencias de 2).
  • num_warps: warps por bloque (1, 2, 4, 8, 16). Controla el presupuesto de occupancy.
  • num_stages: etapas del pipeline software (relevante para matmul, menos para softmax).

Lo que no explora el autotuner:

  • El algoritmo (online vs 3-pass).
  • El layout de memoria de las entradas.
  • Si usar Tensor Cores (Triton los usa cuando aplican para los matmuls tl.dot; para softmax, irrelevante).

Implicación: un mal algoritmo autotuneado sigue siendo un mal algoritmo. El autotune de Triton te salva de adivinar tamaños de tile, no de un mal diseño.

Cuándo gana Triton, cuándo gana CUDA C

Situación Elección
Kernel custom elementwise o de reducción (softmax, layernorm, RMSNorm, RoPE) Triton. Recorta el tiempo de desarrollo 10×.
GEMM-like con utilización ajustada de Tensor Cores cuBLAS / CUTLASS. El tl.dot de Triton es bueno pero no es estado del arte para las shapes más enrevesadas.
Kernel de investigación nuevo (Flash-Attention v3, MLA) CUDA C + CUTLASS. Triton puede aproximarlo pero el 10% superior de perf suele requerir control crudo.
Prototipo rápido para probar un algoritmo Triton. De calle.
Kernel de producción para la ruta más caliente CUDA C si Triton está dentro del 5%; Triton si está dentro del 15%. El coste de mantenimiento importa.

Para la softmax-sobre-vocab de la Fase 24: Triton es la herramienta adecuada. El lab 03 te hace escribirla tras la versión CUDA C tuneada del lab 02, específicamente para sentir el contraste.

Cómo compila Triton

El pipeline (no lo usas, pero conocerlo ayuda a depurar):

Triton DSL (Python AST)
   ↓ (Triton MLIR dialect)
ttir.mlir (Triton IR)
   ↓ (lowering passes: fuse, vectorize, schedule)
ttgir.mlir (Triton GPU IR)
   ↓ (PTX emission)
ptx
   ↓ (ptxas, NVIDIA's PTX→SASS assembler)
SASS / cubin

triton.compile(...) expone los IR intermedios. Si un kernel va lento, volcar ttgir te dice qué decidió el scheduler. La mayoría de usuarios de kernels nunca necesitan mirar la IR; el tuning avanzado sí.

Cross-check con ncu

Perfilar un kernel Triton usa las mismas herramientas ncu que CUDA C. Los reportes parecen idénticos — los mismos nombres de métrica (occupancy alcanzado, throughput HBM, conflictos de bancos en SMEM). Las diferencias:

  • Los nombres de kernel de Triton están mangleados — Triton añade un hash para desambiguar variantes del autotune. softmax_kernel_0d1d2c3c es normal.
  • Atribución a la fuente: ncu --source on muestra PTX; mapear de vuelta al DSL de Triton requiere la vista de fuente Triton de ncu (Triton 3.x+).

Para el entregable de la Fase 24: perfila tanto la versión CUDA C tuneada como la Triton autotuneada, coloca ambas en el mismo roofline, comenta la brecha.

Modo CPU: el interpreter de Triton para desarrollo local

Triton 3.x incluye un modo intérprete (triton.runtime.driver.set_active(...)) que corre funciones @triton.jit en CPU usando semántica NumPy. Lento pero útil para:

  • Chequeos de corrección sin GPU (la máquina de Borja).
  • Recorrer el algoritmo paso a paso en pdb.
  • CI sin runners con GPU.

Este no es el fallback de CPU de producción (src/minikernel/dispatch.py usa NumPy directamente — más simple, sin dependencia de Triton). Pero para desarrollo en el portátil de Borja, el modo intérprete de Triton es lo más cercano a "correr el código GPU localmente".

Una softmax Triton con autotune (completa)

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 256},  num_warps=2),
        triton.Config({'BLOCK': 512},  num_warps=4),
        triton.Config({'BLOCK': 1024}, num_warps=8),
    ],
    key=['V'],
)
@triton.jit
def softmax_kernel(x_ptr, y_ptr, V, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < V
    x = tl.load(x_ptr + row * V + cols, mask=mask, other=-float('inf'))
    m = tl.max(x, axis=0)
    e = tl.exp(x - m)
    s = tl.sum(e, axis=0)
    y = e / s
    tl.store(y_ptr + row * V + cols, y, mask=mask)

def softmax(x):  # Python wrapper
    B, V = x.shape
    y = torch.empty_like(x)
    softmax_kernel[(B,)](x, y, V)
    return y

Eso es todo. El lab 03 lo construye incrementalmente.

Lo que deberías ser capaz de hacer

  1. Leer cualquier kernel Triton y explicar qué hace cada línea.
  2. Decidir, para un operador nuevo, si empezar en Triton o en CUDA C.
  3. Leer un bloque @triton.autotune y predecir qué configs escogerá probablemente el tuner en tamaños comunes.
  4. Usar triton.compile(...) para volcar la IR intermedia para depurar.

Lo que esta página NO cubre

  • Las primitivas matmul de Triton en profundidad. tl.dot y su ruta hacia Tensor Cores son un tema en sí mismo; no escribimos un kernel matmul en la Fase 24 (la softmax basta).
  • Características específicas de Triton 3.x (warp specialization, async TMA en Hopper). Fase 27 si es relevante.
  • Internals de Triton (las pasadas MLIR). Fuera de alcance; eres usuario, no compiler engineer.

Siguiente: theory/04-pytorch-as-substrate.md — la primera aparición de PyTorch en este currículo.