English · Español
Lab 01 — El delta de bytes movidos entre flash attention y naive¶
Objetivo: derivar simbólicamente y medir empíricamente los bytes movidos por la attention naive vs flash attention. Plotear bytes vs N. Calcular el ratio de intensidad del roofline.
Tiempo estimado: 3–4 horas.
Prerrequisito: theory 02 leída; lab de softmax online commiteado.
Lo que produces¶
Un directorio experiments/27-flash-vs-naive-bytes/ que contenga:
derive_bytes.py— fórmulas cerradas de bytes movidos (sólo impresas; es un script de derivación, no un trabajo pesado de cómputo).measure_bytes.py— bytes movidos empíricamente, usando un profiler o una implementación de referencia instrumentada.bytes_vs_n.png— plot log-log de bytes movidos vs N para ambos esquemas.intensity_ratio.png— ratio de intensidad derivado flash/naive vs N.manifest.json.README.md.
TODOs¶
Bloque A — derivar bytes movidos simbólicos¶
Implementa funciones Python que devuelvan recuentos de bytes en forma cerrada:
-
bytes_naive(N, d, dtype_bytes=2)→ devuelve(12*N*d + 16*N**2) * (dtype_bytes / 4)(la fórmula de theory 02; escalada por dtype). Nota: son bytes HBM; ignora la escritura deO(igual en ambos). -
bytes_flash(N, d, B_r, B_c, dtype_bytes=2)→ devuelve8 * N * d * (1 + N / B_r) * (dtype_bytes / 4). - Imprime ambos para varias combinaciones
(N, d):(64, 64)(la longitud de secuencia del corpus de verbos — la ganancia de flash debería ser ~cero aquí),(1024, 64), (2048, 64), (4096, 64), (8192, 128), (32768, 128). - Anota en la salida del script: en
N=64, la matriz naiveN²=4096materializada cabe en 16 KiB a fp32 — dentro de L1. Flash no aporta nada. El objetivo de ejecutar esto para la secuencia de verbos es mostrar que flash sólo es una ganancia cuando N · N · dtype > L2.
Bloque B — intensidad simbólica¶
- FLOPs =
4 * N * N * d(el coste dominanteQ@K^T+P@V). - Intensidad para ambos. Imprime como tabla.
Bloque C — medir empíricamente (CPU)¶
Como el hardware local de Borja es CPU, no intentes medir en GPU aquí; ese es el trabajo del lab 02. Usa una implementación CPU:
- Implementa
attn_naive_cpu(Q, K, V)en PyTorch usando matmuls y softmax explícitos. Usatorch.profilerpara contar bytes movidos. - Implementa
attn_flash_reference_cpu(Q, K, V, B_r, B_c)en PyTorch puro (sin Triton — sólo el bucle de teselado). Perfila bytes. - Compara los bytes medidos con las predicciones simbólicas. Deberían concordar dentro de ~30% (sobrecarga de Python, layout, etc.).
Bloque D — plotear¶
- Plot log-log, eje x N de 256 a 16384 doblando, eje y bytes movidos. Dos líneas (naive, flash con
B_r=B_c=64). - Anota dónde flash se vuelve "mucho mejor" que naive (típicamente N ≥ 1024).
Bloque E — overlay del roofline (preview del lab 02)¶
- Calcula la intensidad para ambos en
N=2048, d=64en la máquina de Borja. Usa los techos del roofline medidos en la Fase 1. - Predice el ratio de speedup. Guárdalo como nota en
README.mdpara comparar con el speedup medido en GPU del lab 02.
Bloque F — interpretar en README.md¶
Tres preguntas:
- ¿A qué N empieza flash a ganar en bytes HBM movidos? Por debajo de cierto N, la sobrecarga del teselado puede dominar. El plot debería mostrar un cruce.
- Según la fórmula simbólica, ¿cuál es el ratio de intensidad flash/naive para
N=8192, d=128, B_r=B_c=64? Muestra el cálculo. - El paper de flash reclama 3× speedup en A100 a N=2048. Según tu ratio de intensidad, el speedup teórico limitado por roofline podría ser 5–10×. ¿Por qué el speedup real es menor? (Pista: saturación del ancho de banda SRAM, FLOPs no-matmul, sobrecarga del lanzamiento del kernel.)
Condiciones de parada¶
- Los seis archivos commiteados.
- Bytes simbólicos y medidos concuerdan dentro de 30%.
- Plots commiteados y etiquetados.
- README responde las tres preguntas del Bloque F.
Errores típicos¶
- La contabilidad de bytes incluye la escritura de
O. Ambos esquemas escribenOuna vez; es lo mismo. Réstalo para una comparación limpia. - Confusión fp16 vs fp32. El softmax naive de PyTorch hace internamente cast a fp32 y vuelta. La medida de bytes movidos debería reflejarlo (el intermedio fp32 domina).
- Flash de referencia en CPU es lento. No pasa nada; estás midiendo bytes, no segundos. Usa N pequeño (≤ 512) si va dolorosamente lento.
Conexión con el lab 02¶
Este lab te da la predicción. El lab 02 (kernel Triton en GPU) te da la medida. Guarda aquí tu predicción del ratio de intensidad; compruébala allí.
Cuándo consultar solutions/¶
Tras commitear los seis archivos y verificar que las predicciones concuerdan con el ratio de intensidad medido dentro de 20%.
Siguiente lab: lab/02-flash-triton.md.