Skip to content

English · Español

Lab 01 — El delta de bytes movidos entre flash attention y naive

Objetivo: derivar simbólicamente y medir empíricamente los bytes movidos por la attention naive vs flash attention. Plotear bytes vs N. Calcular el ratio de intensidad del roofline.

Tiempo estimado: 3–4 horas.

Prerrequisito: theory 02 leída; lab de softmax online commiteado.


Lo que produces

Un directorio experiments/27-flash-vs-naive-bytes/ que contenga:

  • derive_bytes.py — fórmulas cerradas de bytes movidos (sólo impresas; es un script de derivación, no un trabajo pesado de cómputo).
  • measure_bytes.py — bytes movidos empíricamente, usando un profiler o una implementación de referencia instrumentada.
  • bytes_vs_n.png — plot log-log de bytes movidos vs N para ambos esquemas.
  • intensity_ratio.png — ratio de intensidad derivado flash/naive vs N.
  • manifest.json.
  • README.md.

TODOs

Bloque A — derivar bytes movidos simbólicos

Implementa funciones Python que devuelvan recuentos de bytes en forma cerrada:

  • bytes_naive(N, d, dtype_bytes=2) → devuelve (12*N*d + 16*N**2) * (dtype_bytes / 4) (la fórmula de theory 02; escalada por dtype). Nota: son bytes HBM; ignora la escritura de O (igual en ambos).
  • bytes_flash(N, d, B_r, B_c, dtype_bytes=2) → devuelve 8 * N * d * (1 + N / B_r) * (dtype_bytes / 4).
  • Imprime ambos para varias combinaciones (N, d): (64, 64) (la longitud de secuencia del corpus de verbos — la ganancia de flash debería ser ~cero aquí), (1024, 64), (2048, 64), (4096, 64), (8192, 128), (32768, 128).
  • Anota en la salida del script: en N=64, la matriz naive N²=4096 materializada cabe en 16 KiB a fp32 — dentro de L1. Flash no aporta nada. El objetivo de ejecutar esto para la secuencia de verbos es mostrar que flash sólo es una ganancia cuando N · N · dtype > L2.

Bloque B — intensidad simbólica

  • FLOPs = 4 * N * N * d (el coste dominante Q@K^T + P@V).
  • Intensidad para ambos. Imprime como tabla.

Bloque C — medir empíricamente (CPU)

Como el hardware local de Borja es CPU, no intentes medir en GPU aquí; ese es el trabajo del lab 02. Usa una implementación CPU:

  • Implementa attn_naive_cpu(Q, K, V) en PyTorch usando matmuls y softmax explícitos. Usa torch.profiler para contar bytes movidos.
  • Implementa attn_flash_reference_cpu(Q, K, V, B_r, B_c) en PyTorch puro (sin Triton — sólo el bucle de teselado). Perfila bytes.
  • Compara los bytes medidos con las predicciones simbólicas. Deberían concordar dentro de ~30% (sobrecarga de Python, layout, etc.).

Bloque D — plotear

  • Plot log-log, eje x N de 256 a 16384 doblando, eje y bytes movidos. Dos líneas (naive, flash con B_r=B_c=64).
  • Anota dónde flash se vuelve "mucho mejor" que naive (típicamente N ≥ 1024).

Bloque E — overlay del roofline (preview del lab 02)

  • Calcula la intensidad para ambos en N=2048, d=64 en la máquina de Borja. Usa los techos del roofline medidos en la Fase 1.
  • Predice el ratio de speedup. Guárdalo como nota en README.md para comparar con el speedup medido en GPU del lab 02.

Bloque F — interpretar en README.md

Tres preguntas:

  1. ¿A qué N empieza flash a ganar en bytes HBM movidos? Por debajo de cierto N, la sobrecarga del teselado puede dominar. El plot debería mostrar un cruce.
  2. Según la fórmula simbólica, ¿cuál es el ratio de intensidad flash/naive para N=8192, d=128, B_r=B_c=64? Muestra el cálculo.
  3. El paper de flash reclama 3× speedup en A100 a N=2048. Según tu ratio de intensidad, el speedup teórico limitado por roofline podría ser 5–10×. ¿Por qué el speedup real es menor? (Pista: saturación del ancho de banda SRAM, FLOPs no-matmul, sobrecarga del lanzamiento del kernel.)

Condiciones de parada

  • Los seis archivos commiteados.
  • Bytes simbólicos y medidos concuerdan dentro de 30%.
  • Plots commiteados y etiquetados.
  • README responde las tres preguntas del Bloque F.

Errores típicos

  • La contabilidad de bytes incluye la escritura de O. Ambos esquemas escriben O una vez; es lo mismo. Réstalo para una comparación limpia.
  • Confusión fp16 vs fp32. El softmax naive de PyTorch hace internamente cast a fp32 y vuelta. La medida de bytes movidos debería reflejarlo (el intermedio fp32 domina).
  • Flash de referencia en CPU es lento. No pasa nada; estás midiendo bytes, no segundos. Usa N pequeño (≤ 512) si va dolorosamente lento.

Conexión con el lab 02

Este lab te da la predicción. El lab 02 (kernel Triton en GPU) te da la medida. Guarda aquí tu predicción del ratio de intensidad; compruébala allí.

Cuándo consultar solutions/

Tras commitear los seis archivos y verificar que las predicciones concuerdan con el ratio de intensidad medido dentro de 20%.


Siguiente lab: lab/02-flash-triton.md.