
01 — Data parallel, ZeRO, and FSDP

The "each GPU holds a replica (or a shard) of the model and a different batch" family. DDP is the naive version: full replica, all-reduce of gradients. ZeRO-1/2/3 progressively trims per-GPU state (optimizer → gradients → parameters). FSDP is the modern embodiment of ZeRO-3 in PyTorch. The rule of thumb: the less state per GPU, the more communication.

In 00-motivation.md we said: distribute either because the model doesn't fit, or because compute takes too long. The data-parallel family primarily attacks the second problem (many GPUs share the work of one batch), but ZeRO and FSDP also relieve memory pressure by sharding state across the same workers.

This page covers the mechanics, the comm patterns, and the cost tradeoffs.


DDP — the baseline

In DistributedDataParallel (DDP), every GPU holds an identical copy of the model parameters \(\theta\), the optimizer state, and the gradient buffers. Each step:

  1. The global batch of size \(B\) is split into per-worker shards of size \(B/N\) (where \(N\) = number of workers).
  2. Each worker does forward + backward on its shard, producing local gradients \(g_i\).
  3. Workers run an all-reduce over the \(g_i\)'s, so every worker ends with \(\bar{g} = (1/N) \sum_i g_i\).
  4. Each worker applies the same optimizer update with \(\bar{g}\). Because they started identical and saw the same \(\bar{g}\), they remain identical: the replicas stay in sync without ever exchanging weights.
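The four steps can be traced in a single process. The sketch below is a plain-Python stand-in, not torch.distributed: a hypothetical scalar model \(y \approx w x\) with squared-error loss, list-based "workers", SGD in place of a real optimizer, and the all-reduce modeled as sum-then-divide. It assumes \(B\) is divisible by \(N\), which is what makes the average of the shard means equal the full-batch mean.

```python
# Plain-Python simulation of one DDP step; no real processes or collectives.
# Hypothetical model: y ~ w * x with mean-squared-error loss.

def local_grad(w, xs, ys):
    """d(loss)/dw on one worker's shard: mean of 2 * (w*x - y) * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def ddp_step(w, xs, ys, n_workers, lr=0.01):
    """Return the N replicas after one DDP step (assumes len(xs) % n_workers == 0)."""
    shard = len(xs) // n_workers
    replicas = [w] * n_workers                      # identical copies of theta
    # Step 2: forward + backward on each B/N shard.
    grads = [
        local_grad(replicas[i], xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        for i in range(n_workers)
    ]
    # Step 3: all-reduce, modeled as sum then divide by N; every worker gets g_bar.
    g_bar = sum(grads) / n_workers
    # Step 4: the same SGD update applied on every replica.
    return [r - lr * g_bar for r in replicas]


xs = [float(i) for i in range(8)]
ys = [3.0 * x for x in xs]
replicas = ddp_step(0.5, xs, ys, n_workers=4)
single_gpu = 0.5 - 0.01 * local_grad(0.5, xs, ys)  # one worker, full batch
print(replicas, single_gpu)
```

SGD stands in for Adam only to keep the sketch short; the "replicas stay in sync" argument holds for any deterministic optimizer fed the same \(\bar{g}\).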

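Step 3 is the only inter-worker gradient communication in a training step. (DDP also broadcasts the initial parameters once, when the wrapper is constructed, and by default re-broadcasts module buffers such as BatchNorm statistics from rank 0 at each forward.) Implementation detail: PyTorch's DDP runs the all-reduce in buckets. It groups gradients into buckets of about bucket_cap_mb (25 MB by default) and fires the all-reduce for a bucket as soon as every gradient in it is ready, overlapping comm with the rest of the backward pass.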
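A toy model of the bucketing idea, not PyTorch's actual reducer: gradients become ready in reverse layer order during backward, parameters are packed into buckets in that order up to a size cap, and a bucket's all-reduce is "launched" once its last gradient arrives. The helper names, parameter sizes, and cap (all in elements) are made up for illustration.

```python
def make_buckets(param_sizes, cap):
    """Pack parameter indices, last layer first, into buckets of about `cap` elements.

    A single parameter larger than `cap` simply gets its own bucket.
    """
    buckets, current, size = [], [], 0
    for idx in reversed(range(len(param_sizes))):
        if current and size + param_sizes[idx] > cap:
            buckets.append(current)
            current, size = [], 0
        current.append(idx)
        size += param_sizes[idx]
    if current:
        buckets.append(current)
    return buckets


def simulate_backward(param_sizes, cap):
    """Replay backward: record when each grad is ready and when each bucket can fire."""
    buckets = make_buckets(param_sizes, cap)
    waiting = [set(b) for b in buckets]
    events = []
    for idx in reversed(range(len(param_sizes))):     # backward visits the last layer first
        events.append(f"grad {idx} ready")
        for b, pending in enumerate(waiting):
            if idx in pending:
                pending.discard(idx)
                if not pending:
                    events.append(f"all-reduce bucket {b} launched")
    return buckets, events


buckets, events = simulate_backward([4, 4, 4, 4, 4, 4], cap=8)
print(buckets)            # [[5, 4], [3, 2], [1, 0]]
print("\n".join(events))
```

Bucket 0's all-reduce is in flight while gradients 3 down to 0 are still being computed; that is the comm/compute overlap. Smaller buckets start comm earlier but pay more per-message overhead.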

Comm volume per step

A ring all-reduce of \(|\theta|\) parameters in fp32 sends \(2 \cdot (N-1)/N \cdot 4 \cdot |\theta|\) bytes per worker. For large \(N\), this approaches \(8 \cdot |\theta|\) bytes per worker, independent of \(N\). That independence is what makes DDP scale gracefully — until you hit network saturation.
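To see where \(2 (N-1)/N\) comes from, the sketch below simulates a ring all-reduce (sum) in plain Python: a reduce-scatter phase of \(N-1\) steps, after which each worker owns one fully reduced chunk of size \(|\theta|/N\), then an all-gather phase of \(N-1\) steps that circulates those chunks. It counts elements sent per worker. The function name is hypothetical, and it assumes \(|\theta|\) is divisible by \(N\).

```python
def ring_all_reduce(worker_data):
    """Simulated ring all-reduce (sum). Returns (per-worker results, elements sent per worker)."""
    n = len(worker_data)
    size = len(worker_data[0])
    assert size % n == 0, "sketch assumes |theta| divisible by N"
    c = size // n
    data = [list(d) for d in worker_data]
    sent = [0] * n

    def chunk(k):
        return range(k * c, (k + 1) * c)

    # Phase 1, reduce-scatter: at step s, worker i sends chunk (i - s) % n to worker i+1,
    # which adds it in. After n-1 steps worker i holds the full sum of chunk (i + 1) % n.
    for s in range(n - 1):
        msgs = [(i, (i - s) % n) for i in range(n)]
        payloads = [[data[i][j] for j in chunk(k)] for i, k in msgs]  # all sends happen at once
        for (i, k), payload in zip(msgs, payloads):
            for j, v in zip(chunk(k), payload):
                data[(i + 1) % n][j] += v
            sent[i] += c

    # Phase 2, all-gather: at step s, worker i forwards the complete chunk (i + 1 - s) % n,
    # which the receiver copies. After n-1 steps every worker has every reduced chunk.
    for s in range(n - 1):
        msgs = [(i, (i + 1 - s) % n) for i in range(n)]
        payloads = [[data[i][j] for j in chunk(k)] for i, k in msgs]
        for (i, k), payload in zip(msgs, payloads):
            for j, v in zip(chunk(k), payload):
                data[(i + 1) % n][j] = v
            sent[i] += c
    return data, sent


n, size = 4, 8
workers = [[float(10 * w + j) for j in range(size)] for w in range(n)]
result, sent = ring_all_reduce(workers)
expected = [sum(col) for col in zip(*workers)]
print(sent)                          # [12, 12, 12, 12] = 2 * (N-1)/N * |theta|
print(4 * sent[0], "bytes in fp32")  # 48 bytes in fp32
```

Each worker sends \(2(N-1)\) chunks of \(|\theta|/N\) elements, whatever \(N\) is; that is the \(N\)-independence the paragraph relies on.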

For MiniGPT-grammar (the Phase 17 grammar-tutor model — ~500k params), \(8 \cdot 500\text{k} = 4\) MB per step. On a 10 Gbps link that's ~3 ms of comm. At 50 ms/step compute, that's 6% overhead. Fine.

For a 7B-parameter model, that's \(8 \cdot 7\text{B} = 56\) GB per step. Even on 100 Gbps InfiniBand (~12 GB/s effective), that's 4.7 seconds of comm per step. At that scale, DDP alone is no longer enough. ZeRO-1/2/3 (below) shrink per-worker state, not comm volume, so the comm has to be attacked separately: send fewer bytes (e.g., reduce gradients in fp16 rather than fp32) or change topology (hierarchical all-reduce: reduce within each node first, then across nodes).
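Both worked examples are just bytes divided by bandwidth. A small helper (hypothetical name) reproduces them, using the large-\(N\) limit of \(8|\theta|\) bytes and assuming the comm is not overlapped with compute:

```python
def ddp_comm_seconds(n_params, bandwidth_bytes_per_s, bytes_per_elem=4, n_workers=None):
    """Ring all-reduce time per step: 2*(N-1)/N * bytes_per_elem * |theta| / bandwidth.

    With n_workers=None, use the large-N limit 2 * bytes_per_elem * |theta|.
    Assumes no overlap with compute and a single bottleneck link.
    """
    factor = 2.0 if n_workers is None else 2.0 * (n_workers - 1) / n_workers
    return factor * bytes_per_elem * n_params / bandwidth_bytes_per_s


small = ddp_comm_seconds(500e3, 10e9 / 8)   # ~500k params over a 10 Gbps link
big = ddp_comm_seconds(7e9, 12e9)           # 7B params over ~12 GB/s effective InfiniBand
print(f"{small * 1e3:.1f} ms, {small / 0.050:.0%} of a 50 ms step")  # 3.2 ms, 6% of a 50 ms step
print(f"{big:.2f} s per step")                                       # 4.67 s per step
```

Passing bytes_per_elem=2 models an fp16 gradient all-reduce, which halves both numbers.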

Determinism caveat

Under DDP the replicas are bit-identical to each other (every worker applies the same \(\bar{g}\)), but the run is not bit-identical to single-GPU training on the same global batch. Floating-point addition is not associative, and the order in which gradient elements are summed during the all-reduce differs from the order a single GPU would use. So the acceptance criterion for "DDP equivalent to single-GPU" is "≤ 1e-5 logit drift", not "byte-equivalent."
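The non-associativity is easy to reproduce in plain Python. The first lines show the textbook case; the second part mimics "one GPU sums everything in one order" versus "4 workers sum their shards, then the partial sums are reduced", using made-up random gradient values.

```python
import random

a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a, b, a == b)   # 0.6000000000000001 0.6 False

random.seed(0)
grads = [random.uniform(-1.0, 1.0) for _ in range(10_000)]  # made-up gradient entries
single_gpu = sum(grads)                         # one left-to-right order
shards = [grads[i::4] for i in range(4)]        # 4 hypothetical workers
ddp = sum(sum(s) for s in shards)               # per-worker partial sums, then reduce
drift = abs(single_gpu - ddp)
print(drift)                                    # tiny, and often nonzero
```

This is why the acceptance test compares against a tolerance rather than bytes.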


ZeRO — sharding the redundancy

DDP wastes memory: every GPU holds a full copy of everything. The ZeRO family (Zero Redundancy Optimizer, from Microsoft DeepSpeed) observes that most of that redundancy can be sharded away. Stages 1 and 2 do it at DDP's comm volume; stage 3 pays extra comm.

What's in each GPU's memory under DDP

For a model with \(|\theta|\) parameters using mixed-precision Adam:

  • Weights (fp16): \(2 |\theta|\) bytes — needed for forward + backward.
  • Gradients (fp16): \(2 |\theta|\) bytes — produced by backward.
  • Master weights (fp32): \(4 |\theta|\) bytes — needed for Adam updates.
  • Adam momentum (fp32): \(4 |\theta|\) bytes.
  • Adam variance (fp32): \(4 |\theta|\) bytes.

Total: \(16 |\theta|\) bytes per GPU, before activations. For 7B params: 112 GB. Doesn't fit on a single A100-80GB.

What ZeRO does

Each ZeRO stage shards one more category across the \(N\) workers:

  • ZeRO-1: optimizer state (master weights + momentum + variance) sharded, so each worker holds \(12 |\theta| / N\) bytes of it. Each worker updates only its own \(1/N\) slice of the parameters, then an all-gather redistributes the updated fp16 weights to every worker.
  • ZeRO-2: + gradients sharded. Each worker keeps \(2 |\theta| / N\) bytes of gradients. A reduce-scatter (in place of DDP's all-reduce) leaves each worker with the fully reduced gradients for its own shard only.
  • ZeRO-3: + parameters sharded. Each worker keeps \(2 |\theta| / N\) bytes of weights. Before each layer's forward, an all-gather fetches that layer's full weights; the worker computes, then frees them. Backward repeats the pattern.
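Stages 1 and 2 replace DDP's all-reduce with a reduce-scatter, a sharded update, and an all-gather. The sketch below checks, in one process, that this produces exactly the DDP update. Assumptions: plain SGD stands in for Adam (so each worker's "optimizer state" is just its parameter slice), collectives are modeled by hypothetical list helpers, and \(|\theta|\) is divisible by \(N\).

```python
def reduce_scatter_sum(vectors):
    """Worker i receives the element-wise sum of everyone's vector, restricted to chunk i."""
    n = len(vectors)
    c = len(vectors[0]) // n
    summed = [sum(col) for col in zip(*vectors)]
    return [summed[i * c:(i + 1) * c] for i in range(n)]


def all_gather(shards):
    """Every worker receives the concatenation of all shards."""
    return [x for shard in shards for x in shard]


def zero2_step(params, local_grads, lr):
    n = len(local_grads)
    c = len(params) // n
    # Each worker ends up with only its 1/N slice of the averaged gradient.
    grad_shards = [[g / n for g in shard] for shard in reduce_scatter_sum(local_grads)]
    # Each worker updates only the parameter slice whose optimizer state it owns.
    new_shards = [
        [p - lr * g for p, g in zip(params[i * c:(i + 1) * c], grad_shards[i])]
        for i in range(n)
    ]
    # All-gather so every worker again holds the full weights (fp16 in real ZeRO).
    return all_gather(new_shards)


def ddp_step(params, local_grads, lr):
    n = len(local_grads)
    g_bar = [sum(col) / n for col in zip(*local_grads)]
    return [p - lr * g for p, g in zip(params, g_bar)]


params = [1.0] * 8
local_grads = [[0.1 * (w + 1) * (j + 1) for j in range(8)] for w in range(4)]
print(zero2_step(params, local_grads, lr=0.5) == ddp_step(params, local_grads, lr=0.5))  # True
```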

Model-state memory under ZeRO-3 for 7B params on \(N=8\): \(16 \cdot 7\text{B} / 8 = 14\) GB per GPU. That fits on a 24 GB consumer GPU, with the remainder left for activations and the temporarily gathered layer.
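The per-stage arithmetic, as a sketch using the 2 + 2 + 12 bytes-per-parameter breakdown from the mixed-precision Adam list (activations, buffers, and the temporarily gathered ZeRO-3 layer are not counted; the function name is hypothetical):

```python
# Bytes per parameter under mixed-precision Adam (from the list above).
WEIGHTS, GRADS, OPTIMIZER = 2, 2, 12   # fp16 weights, fp16 grads, fp32 master + m + v


def model_state_bytes(n_params, n_workers, stage):
    """Per-GPU model-state bytes; stage 0 is plain DDP, 1-3 are the ZeRO stages."""
    w, g, o = float(WEIGHTS), float(GRADS), float(OPTIMIZER)
    if stage >= 1:
        o /= n_workers          # shard optimizer state
    if stage >= 2:
        g /= n_workers          # shard gradients
    if stage >= 3:
        w /= n_workers          # shard parameters
    return (w + g + o) * n_params


for stage in range(4):
    gb = model_state_bytes(7e9, 8, stage) / 1e9
    print(f"stage {stage}: {gb:.2f} GB per GPU")
# stage 0: 112.00, stage 1: 38.50, stage 2: 26.25, stage 3: 14.00
```

Most of the saving comes from stage 1, because the fp32 optimizer state is 12 of the 16 bytes per parameter.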

Comm cost grows

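ZeRO-1 and ZeRO-2 keep DDP's comm volume: a ring all-reduce is already a reduce-scatter plus an all-gather, and those stages only move the optimizer update between the two halves. ZeRO-3 / FSDP raises the comm by about 1.5× vs DDP: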

  • DDP: 1 all-reduce per step over \(|\theta|\).
  • ZeRO-3 forward: \(L\) all-gathers over \(|\theta_\ell|\) — one per layer.
  • ZeRO-3 backward: \(L\) all-gathers over \(|\theta_\ell|\) + 1 reduce-scatter over \(|\theta|\).

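Net comm volume ≈ \(3 |\theta|\) elements per worker (forward gather + backward gather + reduce-scatter), vs DDP's ≈ \(2 |\theta|\) (the all-reduce is a reduce-scatter plus an all-gather). Multiply by bytes per element to get bytes: in fp32 that is \(12 |\theta|\) vs the \(8 |\theta|\) computed above. The factor-1.5 overhead buys you \(1/N\) memory. Worth it when memory was the bottleneck.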
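The 3-vs-2 comparison can be checked by counting elements sent per worker with ring collectives, where each all-gather or reduce-scatter over \(k\) elements sends \((N-1)/N \cdot k\) of them. Layer sizes are made up and the helper name is hypothetical.

```python
def comm_elements_per_worker(layer_sizes, n_workers, scheme):
    """Elements one worker sends per step with ring collectives.

    A ring all-gather or reduce-scatter over k elements sends (N-1)/N * k per worker;
    a ring all-reduce is one of each.
    """
    total = sum(layer_sizes)
    f = (n_workers - 1) / n_workers
    if scheme == "ddp":
        return 2 * f * total                                # all-reduce of all grads
    if scheme == "zero3":
        forward_gathers = sum(f * s for s in layer_sizes)   # gather each layer before forward
        backward_gathers = sum(f * s for s in layer_sizes)  # re-gather: weights were freed
        grad_reduce_scatter = f * total                     # each worker keeps its grad shard
        return forward_gathers + backward_gathers + grad_reduce_scatter
    raise ValueError(f"unknown scheme: {scheme}")


layers = [1000] * 12          # made-up: 12 equal layers
ddp = comm_elements_per_worker(layers, 8, "ddp")
zero3 = comm_elements_per_worker(layers, 8, "zero3")
print(ddp, zero3, zero3 / ddp)   # 21000.0 31500.0 1.5
```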


FSDP — PyTorch's implementation of ZeRO-3

PyTorch's Fully Sharded Data Parallel (FSDP) is essentially ZeRO-3 with a cleaner API. The conceptual model is identical; the implementation adds a few quality-of-life features:

  1. Flat parameter buffers. FSDP groups parameters into "flat" 1D tensors per FSDP unit, which makes the all-gather a single contiguous comm instead of one per parameter.
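  2. Prefetch. While computing layer \(\ell\)'s forward, FSDP can issue the all-gather for layer \(\ell + 1\)'s parameters asynchronously, so comm overlaps compute. In the FullyShardedDataParallel constructor, backward prefetching is on by default (backward_prefetch=BackwardPrefetch.BACKWARD_PRE), while explicit forward prefetching is opt-in (forward_prefetch=True).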
  3. CPU offload. Optional: master weights and optimizer state can live in CPU RAM, paged to GPU only when needed. Trades bandwidth for memory.
  4. auto_wrap_policy. Specifies which submodules become FSDP units. A wrap policy of "every transformer block" is usually correct.
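A toy version of the flat-parameter idea: flatten one FSDP unit's parameters into a single 1D buffer, pad it to a multiple of \(N\) so every worker holds an equal slice, and recover per-parameter views after one all-gather. This is plain Python with hypothetical helpers, not FSDP's FlatParameter; the padding step is meant to mirror how FSDP keeps shards equal-sized.

```python
from math import prod


def flatten_and_shard(param_shapes, n_workers):
    """Flatten a unit's params into one buffer, pad to a multiple of N, split evenly."""
    numels = [prod(shape) for shape in param_shapes]
    total = sum(numels)
    padded = -(-total // n_workers) * n_workers        # round up to a multiple of N
    flat = [float(i) for i in range(total)] + [0.0] * (padded - total)  # stand-in values
    c = padded // n_workers
    return numels, total, [flat[i * c:(i + 1) * c] for i in range(n_workers)]


def unshard(numels, total, shards):
    """One all-gather of the flat buffer, drop padding, split back into per-param views."""
    flat = [x for shard in shards for x in shard][:total]
    views, offset = [], 0
    for n in numels:
        views.append(flat[offset:offset + n])
        offset += n
    return views


# Hypothetical unit: a 4x3 weight, a bias of 3, a 3x2 weight (21 elements) on N = 4.
numels, total, shards = flatten_and_shard([(4, 3), (3,), (3, 2)], n_workers=4)
print([len(s) for s in shards])          # [6, 6, 6, 6]: 21 elements padded to 24
views = unshard(numels, total, shards)
print([len(v) for v in views])           # [12, 3, 6]
```

Because each worker's slice is one contiguous range, reassembling the unit takes a single collective rather than one per parameter.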

The "annotated reading" lab for Phase 35 (lab/03-megatron-fsdp-reading.md) walks through the flat-param prefetch logic in torch/distributed/fsdp/_runtime_utils.py, which is the least obvious piece.


When is the data-parallel family right?

| Situation | Recommendation |
|---|---|
| Model fits on one GPU; you want faster training | DDP |
| Model fits on one GPU, but its optimizer state doesn't | ZeRO-1 |
| Model doesn't fit even in fp16 | ZeRO-3 / FSDP |
| Inference (no gradients, no optimizer state) | DDP-style replication is fine until the model exceeds one GPU; then TP |
| Training across N nodes with limited inter-node bandwidth | DDP (optionally with ZeRO-1) across nodes; restrict ZeRO-3 sharding to within a node |
| Debugging; throughput doesn't matter | Single-GPU. Always start single-GPU. |

For the grammar tutor in this curriculum:

  • The model is microscopic — fits on a calculator. Distributing it is educational, not necessary.
  • Lab 01 uses DDP across 2 CPU processes to teach the wire protocol: how init_process_group works, what an all-reduce looks like in torch.distributed, and what the backend does under the hood (gloo on CPU processes; NCCL is the GPU backend).
  • We do not run ZeRO-3 / FSDP on the grammar tutor — there's nothing to shard.

The forward-looking exercise: if the grammar tutor's vocabulary grew from 600 forms to 600k forms (English + Spanish + French + German + Italian + Portuguese), when does the embedding table itself exceed one GPU? Answer: at \(d_{\text{model}} = 4096\), 600k entries × 4096 × 4 bytes ≈ 10 GB for the fp32 embedding weights alone, and under the 16-bytes-per-parameter mixed-precision Adam accounting above, ≈ 39 GB of training state: half an A100-80GB before any other layer. That's when sharding the embedding table (a specific TP variant, see theory/02-parallelism-flavors.md) becomes the first thing you reach for.
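The exercise's arithmetic, extended with the 16-bytes-per-parameter training accounting from the ZeRO section (the vocabulary size and \(d_{\text{model}}\) come from the exercise; everything else is multiplication):

```python
vocab, d_model = 600_000, 4096
n_params = vocab * d_model                 # one d_model-wide row per vocabulary entry
fp32_weights_gb = n_params * 4 / 1e9       # weights alone, 4 bytes each
train_state_gb = n_params * 16 / 1e9       # fp16 w + fp16 g + fp32 master/m/v
print(f"{n_params / 1e9:.2f}B params, {fp32_weights_gb:.1f} GB fp32 weights, "
      f"{train_state_gb:.1f} GB training state")
# 2.46B params, 9.8 GB fp32 weights, 39.3 GB training state
```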

What this phase does NOT cover

  • Implementing ZeRO-3 / FSDP from scratch. PyTorch's FSDP is ~3000 LOC of careful concurrency. We read it; we do not rewrite it. Phase 35's lab 01 implements DDP only, and even that is mostly a thin wrapper over torch.distributed.
  • CPU offload tuning. An FSDP-specific tuning game. Out of scope; mentioned for vocabulary.
  • 3D parallelism. DDP × TP × PP. Mentioned in 00-motivation.md; mechanics deferred to a hypothetical Phase 41+.
  • Heterogeneous workers. All workers identical in this phase. Asymmetric workloads (GPU + CPU mixed, or GPU + TPU) are an active research area, irrelevant at this budget.

Next: theory/02-parallelism-flavors.md.