English · Español
Lab 02 — Kernel de softmax fusionado tuneado (≥30% de cuBLAS)¶
Objetivo: aplicar la escalera de optimizaciones de
theory/02al kernel ingenuo del lab 01. Subir desde ~1% del pico HBM hasta ≥30% del rendimiento detorch.nn.functional.softmaxcon \(B = 512, V = 600\). Capturar un perfilncude la versión final. Este es el lab donde se alcanza el objetivo de rendimiento del DoD de la Fase 24.Tiempo estimado: 6–10 horas (el tuning de kernels es iterativo).
Prerrequisito:
lab/01-naive-kernel.mdcompleto (existe la línea base ingenua correcta).nsight-compute(ncu) instalado en la GPU en la nube.
Lo que produces¶
Un directorio experiments/24-tuned-kernel/ y un src/minikernel/ actualizado:
src/minikernel/softmax_smem.cu— versión con coalescing + SMEM.src/minikernel/softmax_fused.cu— versión con reducción paralela + online softmax (la versión ≥30% de cuBLAS).src/minikernel/softmax.py—softmax(x)público que despacha a fused → smem → naive → numpy por orden de disponibilidad.tests/test_softmax_tuned.py— tests de equivalencia para ambos kernels nuevos.experiments/24-tuned-kernel/bench.py— cara a cara: naive vs smem vs fused vsF.softmax, a lo largo del barrido de \(B\).experiments/24-tuned-kernel/ncu_report.txt— perfilncuanotado de la versión fused.experiments/24-tuned-kernel/manifest.json.experiments/24-tuned-kernel/README.md— 3 párrafos: qué consiguió cada movimiento, la interpretación delncu, la brecha residual con cuBLAS /F.softmax.
TODOs¶
Bloque A — versión 2 (coalescing + SMEM)¶
- Según
theory/02§"Version 2": un bloque por fila, los threads cargan cooperativamente la fila enextern __shared__ float row[], sincronizan, luego reducen en serie en el thread 0 (manteniéndolo simple). Escritura final con coalescing. - Elige el tamaño de bloque: \(\geq V\) redondeado a potencia de 2 (así 1024). Lanza con
<<<B, 1024, V * sizeof(float)>>>. - Testea la corrección como en el lab 01.
- Bench: espera 10–20% del pico HBM. Documenta el salto desde naive.
Bloque B — versión 3 (reducción paralela + online softmax)¶
- Sustituye el max/sum serial en
if (tid == 0)por reducciones en árbol a lo largo del bloque (el patrón canónico dephase-23/theory/03). - Fusiona la pasada max y la pasada sum usando online softmax (theory/02 §"Version 3"). Una pasada por la fila lee y acumula tanto \(m\) como \(s\).
- Testea la corrección — nota: la online softmax puede derivar numéricamente vs la de 3 pasadas en fp32; puede que la tolerancia tenga que ser 1e-4 en vez de 1e-5. Documéntalo.
- Bench: espera 30–50% del pico HBM. Esta es la versión relevante para el DoD.
Bloque C — perfil ncu¶
- En la GPU en la nube:
ncu --set full --section MemoryWorkloadAnalysis --section ComputeWorkloadAnalysis --section Occupancy --target-processes all python bench.py. Guarda el informe enncu_report.ncu-repy una exportación de texto enncu_report.txt. - Lee el informe. Identifica:
- Ocupación alcanzada vs teórica (según compute capability).
- Throughput HBM vs pico (según la spec del device).
- Hit rate de L1 / SMEM.
- Motivos de stall (Memory Throttle, Memory Dependency, Execution Dependency, etc.).
- Escribe una anotación de 1 párrafo en
README.md. Identifica el motivo dominante de stall. Si no es "Memory Throttle" (es decir, no está limitado por HBM), algo va mal (el kernel se supone que está limitado por memoria).
Bloque D — comparar con cuBLAS / F.softmax¶
- En
bench.py, cronometra tambiéntorch.nn.functional.softmax(x, dim=-1)con el mismo \((B, V)\). PyTorch despacha al softmax decuDNNo al JIT inductor (según la versión de torch + warm-up). - Calcula:
tuned_kernel_time / F_softmax_time. Objetivo: ≤ 3.0 (es decir, tu kernel es como mínimo ⅓ de la velocidad). Si alcanzas ≤ 1.5, bien — a veces un kernel custom fusionado supera a uno genérico con \(V\) pequeño. - Documenta la brecha. No te obsesiones por cerrarla; entiéndela.
Bloque E — manifest¶
{
"experiment": "24-tuned-kernel",
"date": "YYYY-MM-DD",
"seed": 42,
"gpu": {"model": null, "compute_capability": null, "hbm_peak_gbs": null},
"versions": {"python": "3.11.x", "cupy": null, "torch": null, "ncu": null},
"kernels": {
"naive": {"median_us_at_B512": null, "fraction_of_peak": null},
"smem": {"median_us_at_B512": null, "fraction_of_peak": null},
"fused": {"median_us_at_B512": null, "fraction_of_peak": null},
"F_softmax_ref": {"median_us_at_B512": null}
},
"results_summary": {
"fused_vs_F_softmax_ratio": null,
"dod_30pct_met": null,
"dominant_stall_reason": null
}
}
Restricciones¶
- Corrección primero. No pruebes el Bloque B antes de que el Bloque A pase la corrección.
- Un cambio cada vez. Pasar de naive → SMEM → fused → online son cuatro movimientos. Haz bench entre cada uno — saber qué te dio cada movimiento es la lección.
- fp32 en todo. fp16 / bf16 es Bloque F opcional (más abajo); no requerido para el DoD.
- Fija la semilla. Todas las mediciones reproducibles.
Bloque F opcional — ruta fp16¶
- Repite el kernel fused con entradas fp16, acumulador fp32. Tolerancia vs
F.softmax(fp16): 1e-2. - Bench: 1.5–2× de speedup (limitado por memoria — bytes a la mitad ≈ tiempo a la mitad).
- Si el dispatcher rutea por dtype, añade rama fp16. Si no, mantén fp16 en un
softmax_fused_fp16.cuaparte.
Condiciones de parada¶
Hecho cuando:
- La versión SMEM pasa corrección; punto del bench registrado.
- La versión fused (online + paralela) pasa corrección; punto del bench registrado.
- Fused vs
F.softmaxcon \(B=512, V=600\): ratio ≤ 3.0 (es decir, ≥33% del rendimiento deF.softmax— cumple el DoD). ncu_report.txtcommiteado con anotación que identifica el motivo dominante de stall.manifest.jsoncommiteado.README.mddocumenta qué consiguió cada movimiento de optimización (con números).
Escollos¶
- Deriva numérica de la online softmax. La recurrencia actualiza \(m\) y \(s\) al unísono; el orden de las operaciones importa. Implementación estándar: ver Milakov & Gimelshein 2018 ("Online Normalizer Calculation"). Iguala el orden del paper exactamente para igualar el comportamiento de referencia.
- Conflictos de banco SMEM. 32 bancos; el patrón de acceso
row[tid]con tid recorriendo 0..1023 mapea al bancotid % 32. Para \(V = 600\) < 1024, los threads de la cola están inactivos — sin conflictos. Para \(V\) mayor, puede hacer falta padding. No relevante a escala de gramática. F.softmaxmás rápido que tu kernel porque fusiona con el GEMM aguas arriba. Elinductorde PyTorch a veces fusionalm_head + softmaxen un solo kernel. Si estás comparando contra una referencia fused por inductor, comparas peras con manzanas. Usatorch.nn.functional.softmaxdirectamente con@torch.compile(mode='reduce-overhead')desactivado.- Register spill por un tamaño de bloque demasiado grande.
--ptxas-options=-vreporta el uso de registros. Si > 64 regs/thread y la ocupación es baja, baja el tamaño de bloque.
Cuándo consultar solutions/¶
Tras cumplir todas las condiciones de parada. La referencia recorre la secuencia exacta de movimientos (y los números que cada uno alcanza en una A10).
Siguiente lab: lab/03-triton-and-pytorch.md.