Skip to content

English · Español

05 — Recorrido del forward de Flash + la matemática del KV cache de GQA

🇪🇸 Dos derivaciones pegadas: (a) recorremos el forward de FlashAttention paso a paso con pseudocódigo tile-loop ejecutable mentalmente; (b) deducimos exactamente cuántos bytes de KV-cache ahorra GQA frente a MHA en función de n_heads, kv_heads, secuencia y dtype. Sin pseudocódigo abreviado y sin proverbios.

Anclajes: theory/01-online-softmax.md, theory/02-flash-attention.md, theory/04-gqa-mqa-mla.md. Este archivo es el único que deberías tener abierto al implementar el lab.


Parte A — Forward de Flash, completamente desplegado

Longitud de secuencia N, dimensión de cabeza d, tamaños de tile B_r filas de Q y B_c filas de K, V. Restricción: el working set de un tile cabe en SRAM.

Layout de memoria del working set

El estado residente en SRAM durante el bucle interno:

Q_i  : (B_r, d) fp32       — tile de query actual
K_j  : (B_c, d) fp32       — tile de key actual
V_j  : (B_c, d) fp32       — tile de value actual
S_ij : (B_r, B_c) fp32     — logits parciales, nunca escritos a HBM
m_i  : (B_r,) fp32         — máximo corriente por fila
l_i  : (B_r,) fp32         — suma corriente por fila (tras reescala)
O_i  : (B_r, d) fp32       — acumulador de salida corriente

Para d=64, B_r=B_c=64 en fp32: 4 × (64·64 + 3·64·64 + 64 + 64 + 64·64) = 4 × (4096 + 12288 + 128 + 4096) = 4 × 20608 ≈ 82 KiB. Justo en L1 clase Haswell; cómodo en SRAM de Hopper.

El bucle bloque-a-bloque (modelo mental ejecutable)

# Bucle externo: cada tile de Q se procesa una vez; K, V se streamean.
for i in range(ceil(N / B_r)):
    Q_i  = HBM_load(Q[i*B_r : (i+1)*B_r, :])           # (B_r, d), R: B_r * d
    O_i  = zeros((B_r, d))                              # en SRAM
    m_i  = full((B_r,), -inf)                           # en SRAM, máximo corriente
    l_i  = zeros((B_r,))                                # en SRAM, suma corriente

    # Bucle interno: streamea K_j, V_j una vez por (i, j).
    for j in range(ceil(N / B_c)):
        K_j  = HBM_load(K[j*B_c : (j+1)*B_c, :])       # R: B_c * d
        V_j  = HBM_load(V[j*B_c : (j+1)*B_c, :])       # R: B_c * d

        # Calcula los logits parciales en SRAM (nunca toca HBM).
        S_ij = (Q_i @ K_j.T) / sqrt(d)                  # (B_r, B_c)

        # Actualización del online softmax — derivada en teoría 01.
        m_new = max(m_i, row_max(S_ij))                 # (B_r,)
        alpha = exp(m_i - m_new)                        # (B_r,) — reescala O_i, l_i viejos
        P_ij  = exp(S_ij - m_new[:, None])              # (B_r, B_c)
        l_i   = alpha * l_i + row_sum(P_ij)             # (B_r,)
        O_i   = alpha[:, None] * O_i + P_ij @ V_j       # (B_r, d)
        m_i   = m_new

    # Fin del bucle interno: normaliza y escribe de vuelta.
    O[i*B_r : (i+1)*B_r, :] = O_i / l_i[:, None]        # W: B_r * d

Por qué O_i se reescala correctamente (el paso load-bearing)

En el paso j-1, O_i = sum_{j' < j} exp(s_{j'} - m_{j-1}) @ V_{j'}. Tras m_new:

\[ O_i^{\text{new}} = \sum_{j' \le j} \exp(s_{j'} - m_{\text{new}}) \, V_{j'} = \alpha \cdot O_i^{\text{old}} + \exp(S_{ij} - m_{\text{new}}) \, V_j \]

El factor alpha = exp(m_i - m_new) es exactamente lo que necesitas para "re-basar" el acumulador viejo desde m_i a m_new sin recomputar los tiles previos. Álgebra idéntica al update() de la teoría 01 para la suma corriente.

Contabilidad de tráfico HBM

Por iteración externa i:

  • Cargar Q_i: B_r · d elementos.
  • Cargar K_j y V_j para todo j: 2 · N · d elementos (cada tile leído una vez).
  • Escribir O_i: B_r · d.

Total sobre todo i:

\[ \text{elementos HBM} = \underbrace{N \cdot d}_{\text{Q, una vez total}} + \underbrace{(N/B_r) \cdot 2 \cdot N \cdot d}_{\text{K, V re-leídos por tile externo}} + \underbrace{N \cdot d}_{\text{O, una vez total}} = N d \cdot \left(2 + \frac{2N}{B_r}\right) \]

Para N=2048, d=64, B_r=64: 2048 · 64 · (2 + 64) = 8.6 M elementos = 34.4 MiB fp32. La attention naive cruza HBM en ≈ 4 N² + 3 N d = 16.8 M + 0.4 M = 17.2 M elementos = 68.8 MiB. Flash mueve la mitad de los bytes HBM — y la matriz S nunca toca HBM en absoluto, que es la ganancia arquitectónica mayor (jerarquía de caché + amplificación de lectura DRAM).

Nota para B_r → N (una sola iteración externa, todo Q en SRAM): HBM = 2Nd + 2Nd = 4Nd. Flash colapsa a un algoritmo de una sola pasada. Para N muy pequeño (≤ 256) este es el régimen; el overhead del tiling desaparece.

Qué cambia para el backward pass (preview)

El backward necesita S y P. Dos estrategias: (a) recomputar S desde Q, K (FLOPs baratos, memoria barata — elegido por Flash); (b) almacenar S en HBM (memoria cara, FLOPs gratis — elegido por naive). El backward de Flash 2 elige (a) con un par de factores de reescala extra. Fuera de alcance para esta fase; la derivación del forward por sí sola es la load-bearing.


Parte B — La matemática del KV cache de GQA, con números

La baseline MHA

Para multi-head attention con n_heads cabezas, dimensión de cabeza d_h = d_model / n_heads, y longitud de contexto N:

  • Forma del KV cache: (n_layers, 2, n_heads, N, d_h).
  • Bytes por token en el KV cache: 2 · n_layers · n_heads · d_h · sizeof(dtype) = 2 · n_layers · d_model · sizeof(dtype).

El "2" es un slot para K, uno para V. Independiente de n_heads una vez expresado como n_heads · d_h = d_model.

Para LLaMA-7B (n_layers=32, d_model=4096) en fp16:

\[ \text{bytes/token} = 2 \cdot 32 \cdot 4096 \cdot 2 = 524{,}288 \text{ bytes} = 512 \text{ KiB / token} \]

Un contexto N=4096: 4096 · 512 KiB = 2 GiB sólo para el KV cache. Por esto la inferencia de contexto largo es difícil.

GQA: agrupa queries, comparte K y V

GQA particiona las n_heads cabezas en n_groups = n_heads / kv_heads grupos, donde cada grupo de cabezas de query comparte una cabeza K, V.

Forma del KV cache: (n_layers, 2, kv_heads, N, d_h).

Bytes por token:

\[ \text{bytes/token}_{\text{GQA}} = 2 \cdot n_{\text{layers}} \cdot \underbrace{k_{\text{KV}} \cdot d_h}_{\text{anchura KV}} \cdot \text{sizeof(dtype)} \]

La ratio de ahorro:

\[ \frac{\text{bytes}_{\text{GQA}}}{\text{bytes}_{\text{MHA}}} = \frac{k_{\text{KV}}}{n_{\text{heads}}} \]

Para LLaMA-2-7B (n_heads=32, kv_heads=32) → MHA, sin ahorros. Para LLaMA-2-70B (n_heads=64, kv_heads=8) → GQA-8, 8× de reducción. KV/token cae de 1280 KiB (equivalente-MHA) a 160 KiB. Para Mistral-7B (n_heads=32, kv_heads=8) → GQA-8, 4× de reducción.

Por qué GQA no es gratis en calidad

Las cabezas de query en un grupo comparten un K, V — así que sólo pueden atender a las mismas ubicaciones, sólo con diferentes "pesos" vía la proyección Q. Esto es una restricción real de expresividad. El hallazgo empírico (Ainslie et al., 2023): en un modelo bien ajustado, kv_heads = n_heads / 8 es esencialmente indistinguible en calidad de MHA completo. Empujado más allá (kv_heads = 1 = MQA), el impacto de calidad es real — la mayoría de modelos de producción se detienen en n_heads / 8.

La frontera de la memoria de caché

Para un modelo sirviendo B usuarios concurrentes en contexto N:

\[ \text{KV cache total} = B \cdot N \cdot 2 \cdot n_{\text{layers}} \cdot k_{\text{KV}} \cdot d_h \cdot \text{sizeof(dtype)} \]

Esto es lo que GQA te compra: a (B, N, n_layers, d_h) fijos, ir de kv_heads = n_heads a kv_heads = n_heads / 8 te permite servir 8× más usuarios en el mismo presupuesto de KV cache. O servir a los mismos usuarios con 8× más contexto. O usar hardware 8× más barato.

El KV cache es el cuello de botella práctico a escala de inferencia de producción. Los FLOPs de attention escalan con ; la memoria del KV cache escala con N. Conforme los modelos van a contextos de 100K tokens, la memoria del KV cache domina todo lo demás. GQA es la mayor palanca arquitectónica para gestionar eso, salvo descartar el cacheo por completo (LoRA, trucos de contexto infinito, etc.).

Números a escala de Mini-GPT (para que la derivación no sea abstracta)

Para Mini-GPT (n_layers=2, n_heads=4, d_h=16, d_model=64, fp32) en N=64:

Variante kv_heads Bytes KV/token Bytes KV totales en N=64
MHA 4 2·2·4·16·4 = 1024 65 536 (64 KiB)
GQA-2 2 2·2·2·16·4 = 512 32 768 (32 KiB)
MQA 1 2·2·1·16·4 = 256 16 384 (16 KiB)

Una reducción 4× de MHA a MQA — pequeña en absoluto, pero es la misma ratio que escala a un gigabyte para LLaMA-2-70B.

Citas

  • Dao, Fu, Ermon, Rudra, Ré. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. arXiv:2205.14135.
  • Dao. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. 2023. arXiv:2307.08691.
  • Ainslie, Lee-Thorp, de Jong, Zemlyanskiy, Lebrón, Sanghai. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. arXiv:2305.13245.

Recap de un párrafo

El forward de Flash es un doble bucle: tile externo de Q, tile interno de K, V, con la recurrencia de online softmax en el paso interno que mantiene la matriz de logits parciales S residente en SRAM y evita ida y vuelta a HBM. El tráfico HBM cae de O(N²) a O(N²d / B_r) — la mitad a un tercio en la práctica — y S nunca cruza HBM en absoluto. GQA, en paralelo, ataca la memoria del KV cache: compartiendo K, V entre grupos de cabezas de query, recorta los bytes KV/token exactamente por kv_heads / n_heads. Las dos técnicas componen: Flash acelera la attention; GQA encoge el cache. Juntas hacen económicamente factible la inferencia de contexto largo (10K-100K tokens).

Siguiente: lab/04-mqa-gqa.md para la medición empírica.