English · Español
01 — Paralelismo de datos (data parallel), ZeRO y FSDP¶
🇪🇸 La familia "cada GPU tiene una réplica (o un trozo) del modelo y un batch distinto". DDP es la versión ingenua: réplica completa, all-reduce de gradientes. ZeRO-½/3 va recortando estado por GPU (optimizador → gradientes → parámetros). FSDP es la encarnación moderna de ZeRO-3 en PyTorch. La consigna: cuanto menos por GPU, más comunicación.
En 00-motivation.md dijimos: distribuye o bien porque el modelo no cabe, o bien porque el cómputo tarda demasiado. La familia paralelismo de datos (data parallel) ataca principalmente el segundo problema — muchas GPUs comparten el trabajo de un batch — pero ZeRO-⅔ y FSDP incorporan el alivio de presión de memoria sharding el estado entre los mismos workers.
Esta página son las mecánicas, los patrones de comunicación y los tradeoffs de coste.
DDP — el baseline¶
En DistributedDataParallel (DDP), cada GPU mantiene una copia idéntica de los parámetros del modelo \(\theta\), el estado del optimizador y el buffer de gradientes. Cada step:
- El batch global de tamaño \(B\) se divide en shards por worker de tamaño \(B/N\) (donde \(N\) = número de workers).
- Cada worker hace forward + backward sobre su shard, produciendo gradientes locales \(g_i\).
- Los workers ejecutan un all-reduce sobre los \(g_i\), de modo que cada worker termina con \(\bar{g} = (1/N) \sum_i g_i\).
- Cada worker aplica la misma actualización del optimizador con \(\bar{g}\). Como empezaron idénticos y vieron \(\bar{g}\) idéntico, siguen idénticos. Determinismo preservado.
El step de comunicación (3) es la única comunicación inter-worker por step de entrenamiento. Detalle de implementación: el framework ejecuta el all-reduce bucketed — dispara el reduce sobre cada grupo de parámetros tan pronto como su gradiente esté listo, solapando comunicación con el backward restante.
Volumen de comunicación por step¶
Un ring all-reduce de \(|\theta|\) parámetros en fp32 envía \(2 \cdot (N-1)/N \cdot 4 \cdot |\theta|\) bytes por worker. Para \(N\) grande, esto se aproxima a \(8 \cdot |\theta|\) bytes por worker, independientemente de \(N\). Esa independencia es lo que hace que DDP escale con elegancia — hasta que saturas la red.
Para MiniGPT-grammar (el modelo grammar-tutor de la Fase 17 — ~500k params), \(8 \cdot 500\text{k} = 4\) MB por step. En un enlace de 10 Gbps eso son ~3 ms de comunicación. A 50 ms/step de cómputo, eso es un 6% de overhead. Bien.
Para un modelo de 7B parámetros, eso son \(8 \cdot 7\text{B} = 56\) GB por step. Incluso en InfiniBand de 100 Gbps (~12 GB/s efectivos), eso son 4,7 segundos de comunicación por step. A esa escala, DDP solo ya no es suficiente. Necesitas o bien reducir el volumen de comunicación por worker (ZeRO-½/3 reducen el estado por worker) o cambiar la topología (all-reduce jerárquico entre nodos vs intra-nodo).
Salvedad sobre determinismo¶
DDP es bit-exactamente determinista entre workers (misma semilla → mismo output), pero no bit-exactamente idéntico al entrenamiento single-GPU del mismo batch global. La suma en punto flotante no es asociativa; el orden en que los elementos del gradiente se suman durante el all-reduce difiere del orden en que una sola GPU los sumaría. El criterio de aceptación para "DDP equivalente a single-GPU" es "≤ 1e-5 de deriva en logits", no "equivalente a nivel de byte".
ZeRO — shardear la redundancia¶
DDP desperdicia memoria: cada GPU tiene una copia completa de todo. La familia ZeRO (Zero Redundancy Optimizer, de Microsoft DeepSpeed) observa que parte de esa redundancia es innecesaria si estás dispuesto a hacer más comunicación.
Qué hay en la memoria de cada GPU bajo DDP¶
Para un modelo con \(|\theta|\) parámetros usando Adam con precisión mixta:
- Pesos (fp16): \(2 |\theta|\) bytes — necesarios para forward + backward.
- Gradientes (fp16): \(2 |\theta|\) bytes — producidos por el backward.
- Pesos maestros (fp32): \(4 |\theta|\) bytes — necesarios para las actualizaciones de Adam.
- Momentum de Adam (fp32): \(4 |\theta|\) bytes.
- Varianza de Adam (fp32): \(4 |\theta|\) bytes.
Total: \(16 |\theta|\) bytes por GPU. Para 7B params: 112 GB. No cabe en una A100-80GB.
Qué hace ZeRO¶
Cada stage de ZeRO shardea una categoría más entre los \(N\) workers:
- ZeRO-1: estado del optimizador (pesos maestros + momentum + varianza) — \(12 |\theta| / N\). Cada worker mantiene \(4 |\theta| / N\) de pesos maestros y junta lo que necesita para aplicar la actualización.
- ZeRO-2: + gradientes shardeados. Cada worker mantiene \(2 |\theta| / N\) de gradientes. Tras el backward, un reduce-scatter distribuye un shard entero de gradientes totalmente reducidos a cada worker.
- ZeRO-3: + parámetros shardeados. Cada worker mantiene \(2 |\theta| / N\) de pesos. Antes de cada capa de forward, un all-gather trae los pesos completos de esa capa, calcula, libera. Mismo patrón para backward.
Memoria bajo ZeRO-3 para 7B params en \(N=8\): \(16 \cdot 7\text{B} / 8 \approx 14\) GB por GPU. Cabe en una GPU de consumo de 24 GB.
El coste de comunicación crece¶
ZeRO-3 / FSDP duplica la comunicación frente a DDP:
- DDP: 1 all-reduce por step sobre \(|\theta|\).
- ZeRO-3 forward: \(L\) all-gathers sobre \(|\theta_\ell|\) — uno por capa.
- ZeRO-3 backward: \(L\) all-gathers sobre \(|\theta_\ell|\) + 1 reduce-scatter sobre \(|\theta|\).
Volumen neto de comunicación ≈ \(3 \cdot |\theta|\) bytes por worker (gather forward + gather backward + reduce-scatter), frente a \(\approx 2 \cdot |\theta|\) de DDP. El overhead de factor 1,5 te compra 1/N de memoria. Merece la pena cuando la memoria era el cuello de botella.
FSDP — la implementación de ZeRO-3 en PyTorch¶
Fully Sharded Data Parallel en PyTorch es esencialmente ZeRO-3 con una API más limpia. El modelo conceptual es idéntico; la implementación tiene unas pocas características de calidad de vida:
- Buffers de parámetros planos (flat). FSDP agrupa parámetros en tensores 1D "planos" por unidad
FSDP, lo que hace del all-gather una única comunicación contigua en lugar de una por parámetro. - Prefetch. Mientras computa el forward de la capa \(\ell\), FSDP emite el all-gather para los parámetros de la capa \(\ell + 1\) de forma asíncrona. La comunicación se solapa con el cómputo.
- CPU offload. Opcional: los pesos maestros y el estado del optimizador pueden vivir en RAM de CPU, paginados a GPU solo cuando se necesitan. Cambia ancho de banda por memoria.
auto_wrap_policy. Especifica qué submódulos se convierten en unidades FSDP. Una wrap policy de "cada bloque transformer" suele ser correcta.
El lab de "lectura anotada" para la Fase 35 (lab/03-megatron-fsdp-reading.md) recorre la lógica de prefetch de flat-param en torch/distributed/fsdp/_runtime_utils.py, que es la pieza menos obvia.
¿Cuándo es correcta la familia data-parallel?¶
| Situación | Recomendación |
|---|---|
| El modelo cabe en una GPU, quiero entrenar más rápido | DDP |
| El modelo cabe en una GPU pero el estado del optimizador no | ZeRO-1 |
| El modelo no cabe ni en fp16 | ZeRO-3 / FSDP |
| Inferencia (sin gradientes, sin estado del optimizador) | Replicación estilo DDP vale hasta que el modelo supere una GPU; luego TP |
| Entrenamiento distribuido entre N nodos con ancho de banda limitado | DDP + ZeRO-1 dentro de cada nodo; restringe ZeRO-3 a intra-nodo |
| Quiero depurar, no me importa el rendimiento | Single-GPU. Empieza siempre single-GPU. |
Para el grammar tutor de este currículo:
- El modelo es microscópico — cabe en una calculadora. Distribuirlo es educativo, no necesario.
- El lab 01 usa DDP entre 2 procesos CPU para enseñar el protocolo de cable: cómo funciona
init_process_group, qué pinta tiene un all-reduce entorch.distributed, qué hace NCCL/gloo por debajo. - No corremos ZeRO-3 / FSDP sobre el grammar tutor — no hay nada que shardear.
El ejercicio a futuro: "si el vocabulario del grammar tutor creciera de 600 formas a 600k formas (inglés + español + francés + alemán + italiano + portugués)," ¿cuándo la propia tabla de embeddings supera una GPU? Respuesta: con \(d_{\text{model}} = 4096\), 600k tokens × 4096 × 4 bytes = 10 GB solo para la tabla de embeddings. Ahí es cuando shardear la tabla de embeddings (una variante específica de TP, ver theory/02-parallelism-flavors.md) se vuelve la primera cosa a la que recurrir.
Lo que esta fase NO cubre¶
- Implementar ZeRO-3 / FSDP desde cero. FSDP de PyTorch son ~3000 LOC de concurrencia cuidadosa. Lo leemos; no lo reescribimos. El lab 01 de la Fase 35 implementa solo DDP, e incluso eso es mayormente un wrapper fino sobre
torch.distributed. - Tuning de CPU offload. Un juego de tuning específico de FSDP. Fuera de alcance; mencionado para vocabulario.
- Paralelismo 3D. DDP × TP × PP. Mencionado en
00-motivation.md; mecánica diferida a una hipotética Fase 41+. - Workers heterogéneos. Todos los workers son idénticos en esta fase. Cargas asimétricas (GPU + CPU mezclados, o GPU + TPU) son un área de investigación activa, irrelevante a este presupuesto.
Siguiente: theory/02-parallelism-flavors.md.