Skip to content

English · Español

Lab 02 — Kernel de softmax fusionado tuneado (≥30% de cuBLAS)

Objetivo: aplicar la escalera de optimizaciones de theory/02 al kernel ingenuo del lab 01. Subir desde ~1% del pico HBM hasta ≥30% del rendimiento de torch.nn.functional.softmax con \(B = 512, V = 600\). Capturar un perfil ncu de la versión final. Este es el lab donde se alcanza el objetivo de rendimiento del DoD de la Fase 24.

Tiempo estimado: 6–10 horas (el tuning de kernels es iterativo).

Prerrequisito: lab/01-naive-kernel.md completo (existe la línea base ingenua correcta). nsight-compute (ncu) instalado en la GPU en la nube.


Lo que produces

Un directorio experiments/24-tuned-kernel/ y un src/minikernel/ actualizado:

  • src/minikernel/softmax_smem.cu — versión con coalescing + SMEM.
  • src/minikernel/softmax_fused.cu — versión con reducción paralela + online softmax (la versión ≥30% de cuBLAS).
  • src/minikernel/softmax.pysoftmax(x) público que despacha a fused → smem → naive → numpy por orden de disponibilidad.
  • tests/test_softmax_tuned.py — tests de equivalencia para ambos kernels nuevos.
  • experiments/24-tuned-kernel/bench.py — cara a cara: naive vs smem vs fused vs F.softmax, a lo largo del barrido de \(B\).
  • experiments/24-tuned-kernel/ncu_report.txt — perfil ncu anotado de la versión fused.
  • experiments/24-tuned-kernel/manifest.json.
  • experiments/24-tuned-kernel/README.md — 3 párrafos: qué consiguió cada movimiento, la interpretación del ncu, la brecha residual con cuBLAS / F.softmax.

TODOs

Bloque A — versión 2 (coalescing + SMEM)

  • Según theory/02 §"Version 2": un bloque por fila, los threads cargan cooperativamente la fila en extern __shared__ float row[], sincronizan, luego reducen en serie en el thread 0 (manteniéndolo simple). Escritura final con coalescing.
  • Elige el tamaño de bloque: \(\geq V\) redondeado a potencia de 2 (así 1024). Lanza con <<<B, 1024, V * sizeof(float)>>>.
  • Testea la corrección como en el lab 01.
  • Bench: espera 10–20% del pico HBM. Documenta el salto desde naive.

Bloque B — versión 3 (reducción paralela + online softmax)

  • Sustituye el max/sum serial en if (tid == 0) por reducciones en árbol a lo largo del bloque (el patrón canónico de phase-23/theory/03).
  • Fusiona la pasada max y la pasada sum usando online softmax (theory/02 §"Version 3"). Una pasada por la fila lee y acumula tanto \(m\) como \(s\).
  • Testea la corrección — nota: la online softmax puede derivar numéricamente vs la de 3 pasadas en fp32; puede que la tolerancia tenga que ser 1e-4 en vez de 1e-5. Documéntalo.
  • Bench: espera 30–50% del pico HBM. Esta es la versión relevante para el DoD.

Bloque C — perfil ncu

  • En la GPU en la nube: ncu --set full --section MemoryWorkloadAnalysis --section ComputeWorkloadAnalysis --section Occupancy --target-processes all python bench.py. Guarda el informe en ncu_report.ncu-rep y una exportación de texto en ncu_report.txt.
  • Lee el informe. Identifica:
  • Ocupación alcanzada vs teórica (según compute capability).
  • Throughput HBM vs pico (según la spec del device).
  • Hit rate de L1 / SMEM.
  • Motivos de stall (Memory Throttle, Memory Dependency, Execution Dependency, etc.).
  • Escribe una anotación de 1 párrafo en README.md. Identifica el motivo dominante de stall. Si no es "Memory Throttle" (es decir, no está limitado por HBM), algo va mal (el kernel se supone que está limitado por memoria).

Bloque D — comparar con cuBLAS / F.softmax

  • En bench.py, cronometra también torch.nn.functional.softmax(x, dim=-1) con el mismo \((B, V)\). PyTorch despacha al softmax de cuDNN o al JIT inductor (según la versión de torch + warm-up).
  • Calcula: tuned_kernel_time / F_softmax_time. Objetivo: ≤ 3.0 (es decir, tu kernel es como mínimo ⅓ de la velocidad). Si alcanzas ≤ 1.5, bien — a veces un kernel custom fusionado supera a uno genérico con \(V\) pequeño.
  • Documenta la brecha. No te obsesiones por cerrarla; entiéndela.

Bloque E — manifest

{
  "experiment": "24-tuned-kernel",
  "date": "YYYY-MM-DD",
  "seed": 42,
  "gpu": {"model": null, "compute_capability": null, "hbm_peak_gbs": null},
  "versions": {"python": "3.11.x", "cupy": null, "torch": null, "ncu": null},
  "kernels": {
    "naive": {"median_us_at_B512": null, "fraction_of_peak": null},
    "smem":  {"median_us_at_B512": null, "fraction_of_peak": null},
    "fused": {"median_us_at_B512": null, "fraction_of_peak": null},
    "F_softmax_ref": {"median_us_at_B512": null}
  },
  "results_summary": {
    "fused_vs_F_softmax_ratio": null,
    "dod_30pct_met": null,
    "dominant_stall_reason": null
  }
}

Restricciones

  • Corrección primero. No pruebes el Bloque B antes de que el Bloque A pase la corrección.
  • Un cambio cada vez. Pasar de naive → SMEM → fused → online son cuatro movimientos. Haz bench entre cada uno — saber qué te dio cada movimiento es la lección.
  • fp32 en todo. fp16 / bf16 es Bloque F opcional (más abajo); no requerido para el DoD.
  • Fija la semilla. Todas las mediciones reproducibles.

Bloque F opcional — ruta fp16

  • Repite el kernel fused con entradas fp16, acumulador fp32. Tolerancia vs F.softmax(fp16): 1e-2.
  • Bench: 1.5–2× de speedup (limitado por memoria — bytes a la mitad ≈ tiempo a la mitad).
  • Si el dispatcher rutea por dtype, añade rama fp16. Si no, mantén fp16 en un softmax_fused_fp16.cu aparte.

Condiciones de parada

Hecho cuando:

  1. La versión SMEM pasa corrección; punto del bench registrado.
  2. La versión fused (online + paralela) pasa corrección; punto del bench registrado.
  3. Fused vs F.softmax con \(B=512, V=600\): ratio ≤ 3.0 (es decir, ≥33% del rendimiento de F.softmax — cumple el DoD).
  4. ncu_report.txt commiteado con anotación que identifica el motivo dominante de stall.
  5. manifest.json commiteado.
  6. README.md documenta qué consiguió cada movimiento de optimización (con números).

Escollos

  • Deriva numérica de la online softmax. La recurrencia actualiza \(m\) y \(s\) al unísono; el orden de las operaciones importa. Implementación estándar: ver Milakov & Gimelshein 2018 ("Online Normalizer Calculation"). Iguala el orden del paper exactamente para igualar el comportamiento de referencia.
  • Conflictos de banco SMEM. 32 bancos; el patrón de acceso row[tid] con tid recorriendo 0..1023 mapea al banco tid % 32. Para \(V = 600\) < 1024, los threads de la cola están inactivos — sin conflictos. Para \(V\) mayor, puede hacer falta padding. No relevante a escala de gramática.
  • F.softmax más rápido que tu kernel porque fusiona con el GEMM aguas arriba. El inductor de PyTorch a veces fusiona lm_head + softmax en un solo kernel. Si estás comparando contra una referencia fused por inductor, comparas peras con manzanas. Usa torch.nn.functional.softmax directamente con @torch.compile(mode='reduce-overhead') desactivado.
  • Register spill por un tamaño de bloque demasiado grande. --ptxas-options=-v reporta el uso de registros. Si > 64 regs/thread y la ocupación es baja, baja el tamaño de bloque.

Cuándo consultar solutions/

Tras cumplir todas las condiciones de parada. La referencia recorre la secuencia exacta de movimientos (y los números que cada uno alcanza en una A10).


Siguiente lab: lab/03-triton-and-pytorch.md.