Skip to content

English · Español

01 — Paralelismo de datos (data parallel), ZeRO y FSDP

🇪🇸 La familia "cada GPU tiene una réplica (o un trozo) del modelo y un batch distinto". DDP es la versión ingenua: réplica completa, all-reduce de gradientes. ZeRO-½/3 va recortando estado por GPU (optimizador → gradientes → parámetros). FSDP es la encarnación moderna de ZeRO-3 en PyTorch. La consigna: cuanto menos por GPU, más comunicación.

En 00-motivation.md dijimos: distribuye o bien porque el modelo no cabe, o bien porque el cómputo tarda demasiado. La familia paralelismo de datos (data parallel) ataca principalmente el segundo problema — muchas GPUs comparten el trabajo de un batch — pero ZeRO-⅔ y FSDP incorporan el alivio de presión de memoria sharding el estado entre los mismos workers.

Esta página son las mecánicas, los patrones de comunicación y los tradeoffs de coste.


DDP — el baseline

En DistributedDataParallel (DDP), cada GPU mantiene una copia idéntica de los parámetros del modelo \(\theta\), el estado del optimizador y el buffer de gradientes. Cada step:

  1. El batch global de tamaño \(B\) se divide en shards por worker de tamaño \(B/N\) (donde \(N\) = número de workers).
  2. Cada worker hace forward + backward sobre su shard, produciendo gradientes locales \(g_i\).
  3. Los workers ejecutan un all-reduce sobre los \(g_i\), de modo que cada worker termina con \(\bar{g} = (1/N) \sum_i g_i\).
  4. Cada worker aplica la misma actualización del optimizador con \(\bar{g}\). Como empezaron idénticos y vieron \(\bar{g}\) idéntico, siguen idénticos. Determinismo preservado.

El step de comunicación (3) es la única comunicación inter-worker por step de entrenamiento. Detalle de implementación: el framework ejecuta el all-reduce bucketed — dispara el reduce sobre cada grupo de parámetros tan pronto como su gradiente esté listo, solapando comunicación con el backward restante.

Volumen de comunicación por step

Un ring all-reduce de \(|\theta|\) parámetros en fp32 envía \(2 \cdot (N-1)/N \cdot 4 \cdot |\theta|\) bytes por worker. Para \(N\) grande, esto se aproxima a \(8 \cdot |\theta|\) bytes por worker, independientemente de \(N\). Esa independencia es lo que hace que DDP escale con elegancia — hasta que saturas la red.

Para MiniGPT-grammar (el modelo grammar-tutor de la Fase 17 — ~500k params), \(8 \cdot 500\text{k} = 4\) MB por step. En un enlace de 10 Gbps eso son ~3 ms de comunicación. A 50 ms/step de cómputo, eso es un 6% de overhead. Bien.

Para un modelo de 7B parámetros, eso son \(8 \cdot 7\text{B} = 56\) GB por step. Incluso en InfiniBand de 100 Gbps (~12 GB/s efectivos), eso son 4,7 segundos de comunicación por step. A esa escala, DDP solo ya no es suficiente. Necesitas o bien reducir el volumen de comunicación por worker (ZeRO-½/3 reducen el estado por worker) o cambiar la topología (all-reduce jerárquico entre nodos vs intra-nodo).

Salvedad sobre determinismo

DDP es bit-exactamente determinista entre workers (misma semilla → mismo output), pero no bit-exactamente idéntico al entrenamiento single-GPU del mismo batch global. La suma en punto flotante no es asociativa; el orden en que los elementos del gradiente se suman durante el all-reduce difiere del orden en que una sola GPU los sumaría. El criterio de aceptación para "DDP equivalente a single-GPU" es "≤ 1e-5 de deriva en logits", no "equivalente a nivel de byte".


ZeRO — shardear la redundancia

DDP desperdicia memoria: cada GPU tiene una copia completa de todo. La familia ZeRO (Zero Redundancy Optimizer, de Microsoft DeepSpeed) observa que parte de esa redundancia es innecesaria si estás dispuesto a hacer más comunicación.

Qué hay en la memoria de cada GPU bajo DDP

Para un modelo con \(|\theta|\) parámetros usando Adam con precisión mixta:

  • Pesos (fp16): \(2 |\theta|\) bytes — necesarios para forward + backward.
  • Gradientes (fp16): \(2 |\theta|\) bytes — producidos por el backward.
  • Pesos maestros (fp32): \(4 |\theta|\) bytes — necesarios para las actualizaciones de Adam.
  • Momentum de Adam (fp32): \(4 |\theta|\) bytes.
  • Varianza de Adam (fp32): \(4 |\theta|\) bytes.

Total: \(16 |\theta|\) bytes por GPU. Para 7B params: 112 GB. No cabe en una A100-80GB.

Qué hace ZeRO

Cada stage de ZeRO shardea una categoría más entre los \(N\) workers:

  • ZeRO-1: estado del optimizador (pesos maestros + momentum + varianza) — \(12 |\theta| / N\). Cada worker mantiene \(4 |\theta| / N\) de pesos maestros y junta lo que necesita para aplicar la actualización.
  • ZeRO-2: + gradientes shardeados. Cada worker mantiene \(2 |\theta| / N\) de gradientes. Tras el backward, un reduce-scatter distribuye un shard entero de gradientes totalmente reducidos a cada worker.
  • ZeRO-3: + parámetros shardeados. Cada worker mantiene \(2 |\theta| / N\) de pesos. Antes de cada capa de forward, un all-gather trae los pesos completos de esa capa, calcula, libera. Mismo patrón para backward.

Memoria bajo ZeRO-3 para 7B params en \(N=8\): \(16 \cdot 7\text{B} / 8 \approx 14\) GB por GPU. Cabe en una GPU de consumo de 24 GB.

El coste de comunicación crece

ZeRO-3 / FSDP duplica la comunicación frente a DDP:

  • DDP: 1 all-reduce por step sobre \(|\theta|\).
  • ZeRO-3 forward: \(L\) all-gathers sobre \(|\theta_\ell|\) — uno por capa.
  • ZeRO-3 backward: \(L\) all-gathers sobre \(|\theta_\ell|\) + 1 reduce-scatter sobre \(|\theta|\).

Volumen neto de comunicación ≈ \(3 \cdot |\theta|\) bytes por worker (gather forward + gather backward + reduce-scatter), frente a \(\approx 2 \cdot |\theta|\) de DDP. El overhead de factor 1,5 te compra 1/N de memoria. Merece la pena cuando la memoria era el cuello de botella.


FSDP — la implementación de ZeRO-3 en PyTorch

Fully Sharded Data Parallel en PyTorch es esencialmente ZeRO-3 con una API más limpia. El modelo conceptual es idéntico; la implementación tiene unas pocas características de calidad de vida:

  1. Buffers de parámetros planos (flat). FSDP agrupa parámetros en tensores 1D "planos" por unidad FSDP, lo que hace del all-gather una única comunicación contigua en lugar de una por parámetro.
  2. Prefetch. Mientras computa el forward de la capa \(\ell\), FSDP emite el all-gather para los parámetros de la capa \(\ell + 1\) de forma asíncrona. La comunicación se solapa con el cómputo.
  3. CPU offload. Opcional: los pesos maestros y el estado del optimizador pueden vivir en RAM de CPU, paginados a GPU solo cuando se necesitan. Cambia ancho de banda por memoria.
  4. auto_wrap_policy. Especifica qué submódulos se convierten en unidades FSDP. Una wrap policy de "cada bloque transformer" suele ser correcta.

El lab de "lectura anotada" para la Fase 35 (lab/03-megatron-fsdp-reading.md) recorre la lógica de prefetch de flat-param en torch/distributed/fsdp/_runtime_utils.py, que es la pieza menos obvia.


¿Cuándo es correcta la familia data-parallel?

Situación Recomendación
El modelo cabe en una GPU, quiero entrenar más rápido DDP
El modelo cabe en una GPU pero el estado del optimizador no ZeRO-1
El modelo no cabe ni en fp16 ZeRO-3 / FSDP
Inferencia (sin gradientes, sin estado del optimizador) Replicación estilo DDP vale hasta que el modelo supere una GPU; luego TP
Entrenamiento distribuido entre N nodos con ancho de banda limitado DDP + ZeRO-1 dentro de cada nodo; restringe ZeRO-3 a intra-nodo
Quiero depurar, no me importa el rendimiento Single-GPU. Empieza siempre single-GPU.

Para el grammar tutor de este currículo:

  • El modelo es microscópico — cabe en una calculadora. Distribuirlo es educativo, no necesario.
  • El lab 01 usa DDP entre 2 procesos CPU para enseñar el protocolo de cable: cómo funciona init_process_group, qué pinta tiene un all-reduce en torch.distributed, qué hace NCCL/gloo por debajo.
  • No corremos ZeRO-3 / FSDP sobre el grammar tutor — no hay nada que shardear.

El ejercicio a futuro: "si el vocabulario del grammar tutor creciera de 600 formas a 600k formas (inglés + español + francés + alemán + italiano + portugués)," ¿cuándo la propia tabla de embeddings supera una GPU? Respuesta: con \(d_{\text{model}} = 4096\), 600k tokens × 4096 × 4 bytes = 10 GB solo para la tabla de embeddings. Ahí es cuando shardear la tabla de embeddings (una variante específica de TP, ver theory/02-parallelism-flavors.md) se vuelve la primera cosa a la que recurrir.

Lo que esta fase NO cubre

  • Implementar ZeRO-3 / FSDP desde cero. FSDP de PyTorch son ~3000 LOC de concurrencia cuidadosa. Lo leemos; no lo reescribimos. El lab 01 de la Fase 35 implementa solo DDP, e incluso eso es mayormente un wrapper fino sobre torch.distributed.
  • Tuning de CPU offload. Un juego de tuning específico de FSDP. Fuera de alcance; mencionado para vocabulario.
  • Paralelismo 3D. DDP × TP × PP. Mencionado en 00-motivation.md; mecánica diferida a una hipotética Fase 41+.
  • Workers heterogéneos. Todos los workers son idénticos en esta fase. Cargas asimétricas (GPU + CPU mezclados, o GPU + TPU) son un área de investigación activa, irrelevante a este presupuesto.

Siguiente: theory/02-parallelism-flavors.md.