# Lab 03: Annotated reading of Megatron-LM tensor-parallel + PyTorch FSDP
Goal: read two production-grade distributed-training files. Write annotated notes calling out design choices. Zero cloud cost.
Estimated time: 3–5 hours (reading-heavy).
Prereq: theory 01–04 done; labs 00–02 done. Borja can now read these files with the right vocabulary loaded.
## What you produce
A directory `experiments/35-reading-notes/` with two annotated reading notes (alongside `megatron-sha.txt` from Block A and `synthesis.md` from Block D):
- `megatron-tp-layers.md`: annotated reading of `megatron-lm/megatron/core/tensor_parallel/layers.py` (specifically `ColumnParallelLinear` and `RowParallelLinear`).
- `fsdp-prefetch.md`: annotated reading of the flat-param prefetch logic in `torch/distributed/fsdp/_runtime_utils.py`.
Each note: ≥5 design choices called out with line citations, ≤300 words each (so it's a summary, not a transcription).
## TODOs
### Block A: clone Megatron-LM at a fixed commit
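```bash
git clone https://github.com/NVIDIA/Megatron-LM.git /tmp/megatron-lm
cd /tmp/megatron-lm
# the redirect below fails if the notes directory does not exist yet
mkdir -p /home/overdrive/claude/lynx-cortex/experiments/35-reading-notes
git rev-parse HEAD > /home/overdrive/claude/lynx-cortex/experiments/35-reading-notes/megatron-sha.txt
```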
Pin the SHA in the notes — Megatron's source changes; future Borja needs to know which version this lab read.
### Block B: read Megatron's TP linear layers
Target file: `megatron/core/tensor_parallel/layers.py`. Focus classes: `ColumnParallelLinear`, `RowParallelLinear`. Suggested side-file for context: `megatron/core/tensor_parallel/mappings.py` (the `_copy_to_tensor_model_parallel_region`, `_reduce_from_tensor_model_parallel_region`, `_gather_from_tensor_model_parallel_region` primitives).
Write megatron-tp-layers.md calling out ≥5 design choices. Suggested candidates (you may pick others if you spot them):
- Why split the weight matrix into N column shards via `init.partial`-style helpers? What does Megatron do that a naive `torch.nn.Parameter(W[:, rank * o//N:(rank+1) * o//N])` gets wrong?
- Async tensor parallel + `sequence_parallel` flag: what is `sequence_parallel=True` doing differently from plain TP? Does it add communication volume or only change the shape of the collectives, and what does it buy?
- `gradient_accumulation_fusion`: what computation is fused with the gradient buffer write? Why does it help?
- The autograd `Function` for the all-reduce (look for `_ReduceFromModelParallelRegion`): why is this an explicit autograd `Function`, not an `nn.Module`?
- CPU initialization vs GPU initialization (`use_cpu_initialization` flag): when do you want to initialize on CPU first? (Hint: memory.)
- `async_tensor_model_parallel_allreduce`: what does the async path overlap with?
- Bias handling under TP: `ColumnParallelLinear` shards its bias along the output dimension, while `RowParallelLinear` keeps a full bias on every rank and adds it after the all-reduce. Why the asymmetry?
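Once your own answers are written, the arithmetic behind two of these candidates (the conjugate all-reduce and the bias asymmetry) can be sanity-checked without a GPU. The sketch below is pure Python, not Megatron code: each list entry stands in for a rank, an all-reduce is simulated as an elementwise sum, and every helper name is hypothetical. It assumes the \(Y = XA + b\) convention with \(A\) of shape (in, out), so a column shard of \(A\) owns a slice of output features.

```python
# Pure-Python sketch, not Megatron code. "Ranks" are list entries and an
# all-reduce is simulated as an elementwise sum. Helper names are hypothetical.

def matmul(a, b):
    """(m x k) @ (k x n) for lists of lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def all_reduce(parts):
    """Every rank ends up with the elementwise sum of all ranks' matrices."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]

def add_bias(y, b):
    return [[v + bj for v, bj in zip(row, b)] for row in y]

def split_cols(a, n):
    k = len(a[0]) // n
    return [[row[r * k:(r + 1) * k] for row in a] for r in range(n)]

def split_rows(a, n):
    k = len(a) // n
    return [a[r * k:(r + 1) * k] for r in range(n)]

def concat_cols(parts):
    return [[v for p in parts for v in p[i]] for i in range(len(parts[0]))]

N = 2                                   # tensor-parallel degree
X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]                      # (tokens=2) x (in=4)
A = [[1, 0, 2, -1],
     [2, 1, 0, 3],
     [-2, 1, 1, 0],
     [0, 2, -1, 1]]                     # (in=4) x (out=4), Y = X @ A + b
b = [1, -1, 2, 3]
dY = [[1, -1, 1, 2],
      [0, 3, -2, 1]]                    # upstream gradient w.r.t. Y

# Column-parallel: rank r owns the output columns A_r.
A_cols = split_cols(A, N)
# Forward needs no communication: per-rank outputs are column slices of Y.
assert concat_cols([matmul(X, A_cols[r]) for r in range(N)]) == matmul(X, A)
# Backward: rank r can only form dY_r @ A_r^T, which is a partial sum of dX.
dY_cols = split_cols(dY, N)
dX_partial = [matmul(dY_cols[r], transpose(A_cols[r])) for r in range(N)]
assert all_reduce(dX_partial) == matmul(dY, transpose(A))

# Row-parallel: rank r owns rows A_r and the matching input columns X_r.
A_rows, X_cols = split_rows(A, N), split_cols(X, N)
partials = [matmul(X_cols[r], A_rows[r]) for r in range(N)]
Y = add_bias(matmul(X, A), b)
# Correct: all-reduce the partial products, then add the replicated bias once.
assert add_bias(all_reduce(partials), b) == Y
# Wrong: adding the bias on every rank before the reduce counts it N times.
wrong = all_reduce([add_bias(p, b) for p in partials])
assert wrong != Y
```

The backward assertion is why a column-parallel layer needs an identity-forward, all-reduce-backward step on its input; the last two are why a row-parallel bias is added once, after the reduction.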
For each choice: 50–80 words. Cite the file path + line range (e.g., megatron/core/tensor_parallel/layers.py:L142-L168).
### Block C: read PyTorch FSDP's flat-param prefetch
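Target file: `torch/distributed/fsdp/_runtime_utils.py`. Focus on the `_pre_forward` and `_post_forward` paths and the `_prefetch_handle` logic. Side-file: `torch/distributed/fsdp/_flat_param.py` for the flat-param structure. Record the PyTorch version you read (`python -c "import torch; print(torch.__version__)"`) next to the Megatron SHA: FSDP's private modules are refactored between releases, so line citations only make sense against a known version.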
Write fsdp-prefetch.md calling out ≥5 design choices. Suggested candidates:
- What is a "flat parameter"? Why does FSDP flatten and not keep individual parameters?
- Prefetch timing (see the `forward_prefetch` and `backward_prefetch` constructor arguments): at what moment does FSDP issue the all-gather for layer \(\ell + 1\)? What is overlapped?
- `use_orig_params=True`: FSDP exposes the original parameter objects (as views into the flat parameter) instead of replacing them with the `FlatParameter`. What does this enable? What does it break or cost?
- `cpu_offload=CPUOffload(offload_params=True)` and the move to/from CPU: when does FSDP page master weights between GPU and CPU? What's the latency budget for this?
- Backward all-gather + reduce-scatter: what's the dependency? Why is the reduce-scatter not a "regular all-reduce"?
- The `_handles_prefetched` bookkeeping (find its equivalent if your pinned version names it differently): what state machine prevents double-prefetch or stale-prefetch?
- Mixed precision in FSDP: `MixedPrecision(param_dtype=..., reduce_dtype=..., buffer_dtype=...)`. What does each control?
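A toy model can make the flat-parameter and reduce-scatter questions concrete before you open `_flat_param.py`. The pure-Python sketch below is illustrative only: the helper names are hypothetical, lists stand in for tensors, and dtypes, devices and real communication are ignored. It concatenates parameters into one padded flat buffer, gives each rank an equal shard, and simulates the collectives as list operations.

```python
# Toy model of FSDP's flat parameter. Illustrative only: helper names are
# hypothetical, lists stand in for tensors, and collectives are simulated.

def flatten_and_shard(params, world_size):
    """Concatenate params into one flat buffer, pad it to a multiple of
    world_size, and split it into equal per-rank shards."""
    flat, offsets = [], {}
    for name, values in params.items():
        offsets[name] = (len(flat), len(values))      # where this param lives in flat
        flat.extend(values)
    flat.extend([0.0] * ((-len(flat)) % world_size))  # padding so shards are equal
    shard = len(flat) // world_size
    shards = [flat[r * shard:(r + 1) * shard] for r in range(world_size)]
    return shards, offsets

def all_gather(shards):
    """Every rank receives the concatenation of all shards."""
    return [v for s in shards for v in s]

def reduce_scatter(per_rank_grads, world_size):
    """Rank r receives only shard r of the elementwise sum over ranks."""
    n = len(per_rank_grads[0]) // world_size
    return [[sum(g[r * n + i] for g in per_rank_grads) for i in range(n)]
            for r in range(world_size)]

def unflatten(flat, offsets):
    """Per-parameter slices of the flat buffer (real FSDP uses views; these are copies)."""
    return {name: flat[start:start + numel] for name, (start, numel) in offsets.items()}

world_size = 2
params = {"weight": [1.0, 2.0, 3.0], "bias": [4.0, 5.0]}   # 5 elements -> 1 pad
shards, offsets = flatten_and_shard(params, world_size)
# Forward: one all-gather over the flat buffer restores every parameter at once.
assert unflatten(all_gather(shards), offsets) == params

# Backward: each rank computed a full-size gradient for the flat buffer.
grads = [[1, 2, 3, 4, 5, 0], [10, 20, 30, 40, 50, 0]]
owned = reduce_scatter(grads, world_size)
# all-reduce == reduce-scatter followed by all-gather
assert all_gather(owned) == [a + b for a, b in zip(*grads)]
```

A reduce-scatter leaves each rank only the summed gradient for the shard it owns, which is all a sharded optimizer step needs; appending an all-gather would turn it into a full all-reduce.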
For each: 50–80 words + line citation.
### Block D: synthesize the diagrams
In experiments/35-reading-notes/synthesis.md, draw two mermaid diagrams (~10 lines each):
- TP block layout under Megatron: a transformer block with the `column → GELU → row` MLP and `column-QKV → attn → row-out` attention pattern. Mark the two forward all-reduces. Annotate "intra-node NVLink, ~600 GB/s" on the comm edges.
- FSDP forward timeline: layer compute on the bottom track, prefetch all-gathers on the top track, showing the overlap.
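The MLP half of the first diagram can be checked numerically. This pure-Python sketch (hypothetical helpers, no torch, simulated ranks) shards the first weight by columns and the second by rows, applies GELU locally on each rank, and sums the partial outputs once, which corresponds to the single forward all-reduce the MLP contributes. The last assertion shows why the order matters: GELU is nonlinear, so it cannot be applied to partial sums.

```python
import math

# Simulated tensor-parallel MLP, pure Python. Helper names are hypothetical.

def matmul(a, b):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def gelu(x):
    """Exact (erf-based) GELU."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_mat(m):
    return [[gelu(v) for v in row] for row in m]

def split_cols(a, n):
    k = len(a[0]) // n
    return [[row[r * k:(r + 1) * k] for row in a] for r in range(n)]

def split_rows(a, n):
    k = len(a) // n
    return [a[r * k:(r + 1) * k] for r in range(n)]

def all_reduce(parts):
    """Simulated all-reduce: elementwise sum of equally shaped matrices."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]

def mlp_dense(x, a, b):
    return matmul(gelu_mat(matmul(x, a)), b)

def mlp_tensor_parallel(x, a, b, n):
    """Column-parallel A -> local GELU -> row-parallel B -> one all-reduce."""
    # rank r owns the same block of hidden units in both shards
    a_shards, b_shards = split_cols(a, n), split_rows(b, n)
    partials = [matmul(gelu_mat(matmul(x, a_shards[r])), b_shards[r]) for r in range(n)]
    return all_reduce(partials)

x = [[0.5, -1.0, 2.0]]                              # 1 token, d_model = 3
a = [[1.0, -0.5, 0.25, 2.0],
     [0.0, 1.5, -1.0, 0.5],
     [-1.0, 0.5, 1.0, -0.25]]                       # 3 x 4 (hidden = 4)
b = [[1.0, 0.0, 0.5],
     [0.5, -1.0, 0.0],
     [-2.0, 1.0, 1.0],
     [0.25, 0.75, -0.5]]                            # 4 x 3

dense = mlp_dense(x, a, b)
tp = mlp_tensor_parallel(x, a, b, n=2)
assert all(math.isclose(p, q, rel_tol=1e-12, abs_tol=1e-12)
           for row_d, row_t in zip(dense, tp) for p, q in zip(row_d, row_t))
# GELU(h0 + h1) != GELU(h0) + GELU(h1): a row-parallel first layer would need
# an extra all-reduce before the GELU.
assert not math.isclose(gelu(1.0) + gelu(-1.0), gelu(0.0))
```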
The diagrams are mermaid (text-editable, version-controllable). Commit them.
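For the FSDP timeline, a back-of-envelope schedule helps decide what the two tracks should look like. The sketch below is an idealized model, not FSDP's actual scheduler: it assumes one all-gather in flight at a time on its own stream, that the gather for layer \(\ell + 1\) is issued when layer \(\ell\) starts computing, and made-up per-layer times in arbitrary units.

```python
# Idealized FSDP-style forward schedule (not FSDP's actual scheduler).
# Assumptions: one all-gather in flight at a time on its own stream; the
# gather for layer i+1 is issued when layer i starts computing; times are
# made-up numbers in arbitrary units.

def forward_time_no_prefetch(compute, gather):
    """Each layer all-gathers its parameters, then computes: nothing overlaps."""
    return sum(g + c for g, c in zip(gather, compute))

def forward_time_prefetch(compute, gather):
    """Return (total_time, [(layer, start, end), ...]) for the compute track."""
    t = gather[0]                       # layer 0's gather cannot be hidden
    track = []
    for i, c in enumerate(compute):
        start = t
        track.append((i, start, start + c))
        if i + 1 < len(compute):
            # layer i+1 starts once layer i's compute AND its own gather are done
            t = start + max(c, gather[i + 1])
        else:
            t = start + c
    return t, track

compute, gather = [4, 4, 4, 4], [3, 3, 3, 3]
print(forward_time_no_prefetch(compute, gather))   # 28
total, track = forward_time_prefetch(compute, gather)
print(total)                                       # 19
```

With compute 4 and gather 3 per layer, prefetching cuts a 4-layer forward from 28 to 19 units. Once gathers outgrow compute (say 5 vs 4), the comm track sets the pace and the overlap only hides the compute.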
### Block E: connect-the-dots paragraph
End each note with a "connect-the-dots" paragraph:
- `megatron-tp-layers.md`: how does what you just read explain the lab 02 slowdown on the grammar tutor? (Hint: the grammar tutor's \(d_{\text{model}}\) is small. Per-token GEMM compute scales as \(d_{\text{model}}^2\) while all-reduce volume scales as \(d_{\text{model}}\), so at small \(d_{\text{model}}\) there is little compute to hide the all-reduces and each collective's fixed latency behind, and comm dominates.)
- `fsdp-prefetch.md`: when would FSDP be the right choice for the grammar tutor's training? (Hint: never at current vocab size; sometime around \(|\theta| \approx 1\text{B}\).)
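The two hints reduce to short arithmetic. The sketch below rests on rough, labelled assumptions (MLP hidden size \(4 d_{\text{model}}\), attention-score FLOPs ignored, two forward all-reduces per block over a ring, 2-byte activations; 16 bytes of model state per parameter for mixed-precision Adam, activations excluded). It is meant for intuition, not as a profile of the grammar tutor.

```python
def tp_flops_per_byte(d_model, tp, bytes_per_elem=2):
    """Forward FLOPs per byte of all-reduce traffic, per token, per block.

    Rough assumptions: MLP hidden = 4 * d_model; attention score/softmax
    FLOPs ignored; two all-reduces per block in forward; a ring all-reduce
    moves 2 * (tp - 1) / tp of the tensor per rank.
    """
    flops = 24 * d_model ** 2 / tp          # QKV + out-proj + two MLP GEMMs, sharded
    comm_bytes = 2 * (2 * (tp - 1) / tp) * d_model * bytes_per_elem
    return flops / comm_bytes               # = 3 * d_model / (tp - 1) for 2-byte elements

def model_state_gb(n_params, shards=1, bytes_per_param=16):
    """Params + grads + Adam states per GPU, in GB.

    16 bytes/param assumes mixed-precision Adam: 16-bit params and grads
    (2 + 2) plus fp32 master params, momentum and variance (4 + 4 + 4).
    """
    return n_params * bytes_per_param / shards / 1e9

for d in (256, 512, 1024, 4096):
    print(d, tp_flops_per_byte(d, tp=2))
print(model_state_gb(1_000_000_000))            # 16.0 (DDP: every GPU holds it all)
print(model_state_gb(1_000_000_000, shards=8))  # 2.0  (FSDP full shard over 8 GPUs)
```

Under these assumptions the ratio is \(3 d_{\text{model}}/(t-1)\) for TP degree \(t\), so halving \(d_{\text{model}}\) halves the compute available to hide each all-reduce. At \(|\theta| \approx 1\text{B}\), unsharded model state alone is about 16 GB per GPU before activations, which is roughly where the hint places FSDP.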
## Constraints
- No code rewrites. Don't try to "improve" Megatron or FSDP. The exercise is reading, not refactoring.
- Cite line ranges, not full snippets. The repo's SHA + line range is enough to reconstruct what you read. Don't paste 200 lines of source into your notes.
- ≤300 words per note. If it's getting longer, you're transcribing, not summarizing.
- Mermaid diagrams only. No PNGs from drawing tools — keep the diagrams diff-able.
- Zero cloud cost. This is reading. Local laptop only.
## Stop conditions
You're done when:
- `experiments/35-reading-notes/{megatron-tp-layers.md, fsdp-prefetch.md, synthesis.md, megatron-sha.txt}` exist.
- Each reading note (`megatron-tp-layers.md`, `fsdp-prefetch.md`) has ≥5 design-choice bullets with line citations and ≤300 words.
- `synthesis.md` has two mermaid diagrams (TP block layout + FSDP timeline).
- Each note has a closing "connect-the-dots" paragraph tying back to lab 02 or to the grammar tutor's training profile.
- You can answer, from memory: "what does `gradient_accumulation_fusion` do" and "when does FSDP prefetch layer \(\ell+1\)".
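The mechanical stop conditions can be checked with a small script. This is a hypothetical helper (call it `check_notes.py`), not part of the lab or its solutions: it counts whitespace-separated words, counts citations matching the `path.py:Lstart-Lend` shape from Block B's example, and looks for the literal phrase "connect-the-dots". Adjust the regex if you cite differently.

```python
import re
import sys

# Assumed citation shape, following Block B's example: <path>.py:L<start>-L<end>
CITATION = re.compile(r"[\w./-]+\.py:L\d+-L\d+")

def check_note(text, max_words=300, min_citations=5):
    """Return a list of problems; an empty list means the note passes."""
    problems = []
    words = len(text.split())
    if words > max_words:
        problems.append(f"{words} words > {max_words}")
    cites = CITATION.findall(text)
    if len(cites) < min_citations:
        problems.append(f"{len(cites)} line citations < {min_citations}")
    if "connect-the-dots" not in text.lower():
        problems.append("no connect-the-dots paragraph")
    return problems

def main(paths):
    for path in paths:
        with open(path, encoding="utf-8") as f:
            issues = check_note(f.read())
        print(f"{path}: {'OK' if not issues else '; '.join(issues)}")

if __name__ == "__main__":
    main(sys.argv[1:])
```

Usage: `python check_notes.py experiments/35-reading-notes/megatron-tp-layers.md experiments/35-reading-notes/fsdp-prefetch.md`. It checks only form; whether the picks are good is what the solutions comparison is for.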
## Hint of last resort
If Megatron's source seems impenetrable, start with the class docstrings of `ColumnParallelLinear` and `RowParallelLinear` in `layers.py`. They state the \(Y = XA + b\) formulation and which dimension of \(A\) each class partitions. Then re-read the code with the docstrings as a map.
For FSDP: PyTorch's docs/source/fsdp.rst (in the PyTorch repo) is the official narrative. Read it first, then read _runtime_utils.py. The narrative tells you what to look for in the code.
## When to consult `solutions/`
After committing the notes. Solution lives in solutions/03-reading-ref.md — written at phase open. The solution is a reference set of design-choice picks with citations, not "the right" answer. Borja's picks may legitimately differ; the comparison should be "did I miss any of these?" not "did I match exactly?".
Next phase: docs/phase-36-frontier-architectures/.