Lab 03 — Annotated reading: Megatron-LM tensor-parallel + PyTorch FSDP

Goal: read two production-grade distributed-training files. Write annotated notes calling out design choices. Zero cloud cost.

Estimated time: 3–5 hours (reading-heavy).

Prereq: theory 01–04 done; labs 00–02 done. Borja can now read these files with the right vocabulary loaded.


What you produce

A directory experiments/35-reading-notes/ with two annotated reading notes:

  • megatron-tp-layers.md — annotated reading of megatron-lm/megatron/core/tensor_parallel/layers.py (specifically ColumnParallelLinear and RowParallelLinear).
  • fsdp-prefetch.md — annotated reading of torch/distributed/fsdp/_runtime_utils.py flat-param prefetch logic.

Each note: ≥5 design choices called out with line citations, ≤300 words per note (so it's a summary, not a transcription).

TODOs

Block A — clone Megatron-LM at a fixed commit

git clone https://github.com/NVIDIA/Megatron-LM.git /tmp/megatron-lm
cd /tmp/megatron-lm
mkdir -p /home/overdrive/claude/lynx-cortex/experiments/35-reading-notes
git rev-parse HEAD > /home/overdrive/claude/lynx-cortex/experiments/35-reading-notes/megatron-sha.txt

Pin the SHA in the notes — Megatron's source changes; future Borja needs to know which version this lab read.

Block B — read Megatron's TP linear

Target file: megatron/core/tensor_parallel/layers.py. Focus classes: ColumnParallelLinear, RowParallelLinear. Suggested side-file for context: megatron/core/tensor_parallel/mappings.py (the copy_to_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, and gather_from_tensor_model_parallel_region primitives, plus the autograd Functions behind them).

Write megatron-tp-layers.md calling out ≥5 design choices. Suggested candidates (you may pick others if you spot them):

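  1. Weight initialization for the N column shards goes through dedicated helpers (look for _initialize_affine_weight_gpu and _initialize_affine_weight_cpu; names may differ at your pinned SHA) — what do they handle that a naive torch.nn.Parameter(W[:, rank * o//N:(rank+1) * o//N]) gets wrong?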
  2. Async tensor parallel + sequence_parallel flag — what is sequence_parallel=True doing differently from plain TP? When is it worth the extra comm?
  3. gradient_accumulation_fusion — what computation is fused with the gradient buffer write? Why does it help?
  4. The autograd Function for the all-reduce (look for _ReduceFromModelParallelRegion in mappings.py) — why is this an explicit autograd Function, not an nn.Module?
  5. CPU initialization vs GPU initialization (use_cpu_initialization flag) — when do you want to initialize on CPU first? (Hint: memory.)
  6. async_tensor_model_parallel_allreduce — what does the async path overlap with?
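  7. Bias handling under TP — in RowParallelLinear the bias is full-size on every rank and added after the all-reduce, while in ColumnParallelLinear it is sharded along with the output dimension. Why the difference?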
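
As a warm-up for the column → GELU → row pattern, here is a minimal single-process sketch in plain Python. It is not Megatron's code: the helper names (matmul, split_cols, split_rows, tp_mlp) are hypothetical, the "ranks" are just loop iterations, and the all-reduce is simulated with an elementwise sum. It only illustrates why a single all-reduce after the row-parallel layer is enough, and why the row-parallel bias must be added once, after the reduction.

```python
import math

def matmul(x, w):
    """Multiply x [m x k] by w [k x n]; both are lists of rows."""
    return [[sum(row[t] * w[t][j] for t in range(len(w)))
             for j in range(len(w[0]))] for row in x]

def gelu(v):
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))

def split_cols(w, n):
    """Split w into n equal column shards (column-parallel layout)."""
    step = len(w[0]) // n
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(n)]

def split_rows(w, n):
    """Split w into n equal row shards (row-parallel layout)."""
    step = len(w) // n
    return [w[r * step:(r + 1) * step] for r in range(n)]

def dense_mlp(x, w1, w2, b2):
    h = [[gelu(v) for v in row] for row in matmul(x, w1)]
    return [[v + b for v, b in zip(row, b2)] for row in matmul(h, w2)]

def tp_mlp(x, w1, w2, b2, n):
    partials = []
    for w1_r, w2_r in zip(split_cols(w1, n), split_rows(w2, n)):
        # GELU is elementwise, so each "rank" applies it to its own columns
        # without any communication.
        h_r = [[gelu(v) for v in row] for row in matmul(x, w1_r)]
        partials.append(matmul(h_r, w2_r))  # partial sums of the full output
    # Simulated all-reduce: sum the partial outputs across "ranks".
    summed = [[sum(p[i][j] for p in partials) for j in range(len(b2))]
              for i in range(len(x))]
    # The bias is added once, after the reduction; adding it on every rank
    # before the sum would count it n times.
    return [[v + b for v, b in zip(row, b2)] for row in summed]

x = [[1.0, -2.0], [0.5, 3.0]]
w1 = [[0.1, -0.2, 0.3, 0.4], [0.5, 0.6, -0.7, 0.8]]       # [2 x 4]
w2 = [[0.1, 0.2], [0.3, -0.4], [0.5, 0.6], [-0.7, 0.8]]   # [4 x 2]
b2 = [0.01, -0.02]

dense = dense_mlp(x, w1, w2, b2)
sharded = tp_mlp(x, w1, w2, b2, n=2)
max_diff = max(abs(a - b) for ra, rb in zip(dense, sharded)
               for a, b in zip(ra, rb))
print(max_diff)  # tiny: only floating-point summation order differs
```

When reading layers.py, look for where the real code performs the steps marked in the comments: the elementwise activation with no comm, the reduction, and the late bias add.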

For each choice: 50–80 words. Cite the file path + line range (e.g., megatron/core/tensor_parallel/layers.py:L142-L168).

Block C — read PyTorch FSDP's flat-param prefetch

Target file: torch/distributed/fsdp/_runtime_utils.py. Focus on the _pre_forward and _post_forward paths, and the _prefetch_handle logic. Side-file: torch/distributed/fsdp/_flat_param.py for the flat-param structure.

Write fsdp-prefetch.md calling out ≥5 design choices. Suggested candidates:

  1. What is a "flat parameter"? Why does FSDP flatten and not keep individual parameters?
  2. Prefetch-after-shard-bound — at what moment does FSDP issue the all-gather for layer \(\ell + 1\)? What is overlapped?
  3. use_orig_params — constructor flag that exposes the module's original parameter objects instead of the flat parameter. What does this break? What does it enable?
  4. cpu_offload=CPUOffload(offload_params=True) and the move to/from CPU — when does FSDP move the sharded parameters (and their gradients) between GPU and CPU? What's the latency budget for this?
  5. Backward all-gather + reduce-scatter — what's the dependency? Why is the reduce-scatter not a "regular all-reduce"?
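  6. The _handles_prefetched bookkeeping (a dict on the FSDP state in older PyTorch releases; newer releases track a per-handle _prefetched flag instead) — what state machine prevents double-prefetch or stale-prefetch?
  7. Mixed precision via MixedPrecision(param_dtype=..., reduce_dtype=..., buffer_dtype=...) — what does each argument control?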
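
To have a mental model for candidate 1 before opening _flat_param.py, the sketch below flattens a toy model's parameters into one padded 1-D buffer and shards it evenly. It is plain Python, not FSDP's implementation: flatten_and_shard, all_gather, and view are hypothetical names, and the "all-gather" is a list concatenation.

```python
def flatten_and_shard(params, world_size):
    """params maps name -> list of floats (each parameter already flattened).

    Returns (shards, offsets, pad) where offsets[name] = (start, end) in the
    flat buffer and pad is the number of trailing zeros added so the buffer
    divides evenly by world_size.
    """
    flat, offsets = [], {}
    for name, values in params.items():
        offsets[name] = (len(flat), len(flat) + len(values))
        flat.extend(values)
    pad = (-len(flat)) % world_size
    flat.extend([0.0] * pad)
    shard_len = len(flat) // world_size
    shards = [flat[r * shard_len:(r + 1) * shard_len]
              for r in range(world_size)]
    return shards, offsets, pad

def all_gather(shards):
    """Simulated all-gather: every rank ends up with the full flat buffer."""
    return [v for shard in shards for v in shard]

def view(flat, offsets, name):
    """Recover one original parameter as a slice of the flat buffer."""
    start, end = offsets[name]
    return flat[start:end]

params = {"w": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], "b": [7.0]}
shards, offsets, pad = flatten_and_shard(params, world_size=4)
full = all_gather(shards)

print(shards)                    # 4 equal-length shards; the last holds b plus padding
print(view(full, offsets, "w"))  # original w, recovered from the gathered buffer
```

Note that shard boundaries ignore parameter boundaries: one collective over one buffer replaces one collective per parameter. Whether and how the real code does the same is what the note should cite.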

For each: 50–80 words + line citation.
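
For candidate 2, a toy timeline can help you recognize the overlap when you find it in the code. The sketch below models one comm stream and one compute stream with fixed per-layer costs. The trigger point used here (issue layer ℓ+1's all-gather when layer ℓ's compute starts) is an assumption made for illustration; finding the actual trigger in _runtime_utils.py is the exercise. forward_time and its parameters are hypothetical.

```python
def forward_time(n_layers, gather, compute, prefetch):
    """Total forward time for n_layers with per-layer all-gather and compute
    costs (same arbitrary time unit). Comm and compute each run serially on
    their own stream; a layer cannot compute before its all-gather is done."""
    comm_free = 0.0           # time at which the comm stream is next idle
    compute_free = 0.0        # time at which the compute stream is next idle
    prev_compute_start = 0.0  # when the previous layer started computing
    for layer in range(n_layers):
        if layer == 0:
            issue = 0.0                 # first all-gather has nothing to hide behind
        elif prefetch:
            issue = prev_compute_start  # assumed: issued as the previous layer starts
        else:
            issue = compute_free        # issued only after the previous layer finishes
        gather_start = max(issue, comm_free)
        comm_free = gather_start + gather
        compute_start = max(comm_free, compute_free)
        prev_compute_start = compute_start
        compute_free = compute_start + compute
    return compute_free

no_prefetch = forward_time(4, gather=1.0, compute=3.0, prefetch=False)
with_prefetch = forward_time(4, gather=1.0, compute=3.0, prefetch=True)
print(no_prefetch, with_prefetch)  # 16.0 13.0
```

With compute longer than the all-gather, only the first all-gather stays exposed. Swap the two costs (gather=3.0, compute=1.0) to see the comm-bound case, where prefetching hides compute behind comm instead.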

Block D — synthesize: the diagram

In experiments/35-reading-notes/synthesis.md, draw two mermaid diagrams (~10 lines each):

  • TP block layout under Megatron: a transformer block with the column → GELU → row MLP and column-QKV → attn → row-out attention pattern. Mark the two forward-pass all-reduces (one after each row-parallel layer). Annotate "intra-node NVLink, ~600 GB/s" on the comm edges.
  • FSDP forward timeline: layer compute on the bottom track, prefetch all-gathers on the top track, showing the overlap.

The diagrams are mermaid (text-editable, version-controllable). Commit them.

Block E — connect-the-dots paragraph

End each note with a "connect-the-dots" paragraph:

  • megatron-tp-layers.md: how does what you just read explain the lab 02 slowdown on the grammar tutor? (Hint: the grammar tutor's \(d_{\text{model}}\) is small, so each all-reduce moves very little data and is latency-bound; its fixed cost is large relative to the small per-layer matmul, so comm dominates.)
  • fsdp-prefetch.md: when would FSDP be the right choice for the grammar tutor's training? (Hint: never at the current model size; sometime around \(|\theta| \approx 1\text{B}\) parameters.)
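
The first hint can be checked with back-of-the-envelope arithmetic. The sketch below compares the MLP's matmul time with the time of its one forward all-reduce under a simple latency-plus-bandwidth model. Every constant other than the 600 GB/s from Block D is an illustrative assumption (10 µs fixed all-reduce cost, 100 TFLOP/s sustained, 2-byte activations, ring all-reduce), and tp_mlp_costs is a hypothetical helper, not something from Megatron.

```python
def tp_mlp_costs(d_model, tokens, tp, bandwidth_gb_s=600.0,
                 latency_us=10.0, sustained_tflops=100.0, bytes_per_elem=2):
    """Return (compute_us, comm_us) per rank for one TP MLP forward pass."""
    # Two matmuls ([tokens, d] x [d, 4d] and [tokens, 4d] x [4d, d]),
    # 2 FLOPs per multiply-add, work split evenly over tp ranks.
    flops = 2 * (2 * tokens * d_model * 4 * d_model) / tp
    compute_us = flops / (sustained_tflops * 1e12) * 1e6
    # One all-reduce of the [tokens, d_model] output. A ring all-reduce
    # sends about 2 * (tp - 1) / tp times the payload per rank.
    payload_bytes = tokens * d_model * bytes_per_elem
    wire_bytes = 2 * (tp - 1) / tp * payload_bytes
    comm_us = latency_us + wire_bytes / (bandwidth_gb_s * 1e9) * 1e6
    return compute_us, comm_us

small_compute, small_comm = tp_mlp_costs(d_model=256, tokens=2048, tp=2)
large_compute, large_comm = tp_mlp_costs(d_model=8192, tokens=2048, tp=2)
print(small_comm / small_compute)  # above 1: comm costs more than the matmuls
print(large_comm / large_compute)  # far below 1: comm is a rounding error
```

Compute grows with \(d_{\text{model}}^2\) while the all-reduce payload grows only with \(d_{\text{model}}\) and carries a fixed latency floor, so shrinking \(d_{\text{model}}\) pushes the ratio toward comm. Plug in the grammar tutor's actual dimensions from lab 02 rather than the placeholder values above.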

Constraints

  • No code rewrites. Don't try to "improve" Megatron or FSDP. The exercise is reading, not refactoring.
  • Cite line ranges, not full snippets. The repo's SHA + line range is enough to reconstruct what you read. Don't paste 200 lines of source into your notes.
  • ≤300 words per note. If it's getting longer, you're transcribing, not summarizing.
  • Mermaid diagrams only. No PNGs from drawing tools — keep the diagrams diff-able.
  • Zero cloud cost. This is reading. Local laptop only.

Stop conditions

You're done when:

  1. experiments/35-reading-notes/{megatron-tp-layers.md, fsdp-prefetch.md, synthesis.md, megatron-sha.txt} exist.
  2. Each .md has ≥5 design-choice bullets with line citations and ≤300 words.
  3. synthesis.md has two mermaid diagrams (TP block layout + FSDP timeline).
  4. Each note has a closing "connect-the-dots" paragraph tying back to lab 02 or to the grammar tutor's training profile.
  5. You can answer, from memory: "what does gradient_accumulation_fusion do" and "when does FSDP prefetch layer \(\ell+1\)".
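
Stop conditions 1 and 2 are mechanical enough to check with a few lines of Python. The sketch below is one possible checker, not part of the lab tooling: check_note and the CITATION pattern are hypothetical, and the pattern assumes citations follow the path.py:L142-L168 shape from Block B.

```python
import re

# Matches citations shaped like megatron/core/tensor_parallel/layers.py:L142-L168
CITATION = re.compile(r"[\w./-]+\.py:L\d+-L\d+")

def check_note(text, min_citations=5, max_words=300):
    """Count whitespace-separated words and line citations in a note."""
    words = len(text.split())
    citations = len(CITATION.findall(text))
    return {
        "words": words,
        "citations": citations,
        "ok": words <= max_words and citations >= min_citations,
    }

sample = ("Design choice 1 (megatron/core/tensor_parallel/layers.py:L142-L168): "
          "weights are initialized per shard.")
result = check_note(sample, min_citations=1)
print(result)
```

To run it on a real note, read the file with pathlib.Path(...).read_text() and pass the text to check_note with the default thresholds. The word count is a plain whitespace split, so markdown syntax and the citations themselves count toward the 300.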

Hint of last resort

If Megatron's source seems impenetrable: start with the class docstrings of ColumnParallelLinear and RowParallelLinear in layers.py. They state the layer as Y = XA + b and say along which dimension A is partitioned. Then re-read the code with the docstrings as a map.

For FSDP: PyTorch's docs/source/fsdp.rst (in the PyTorch repo) is the official narrative. Read it first, then read _runtime_utils.py. The narrative tells you what to look for in the code.

When to consult solutions/

After committing the notes. Solution lives in solutions/03-reading-ref.md — written at phase open. The solution is a reference set of design-choice picks with citations, not "the right" answer. Borja's picks may legitimately differ; the comparison should be "did I miss any of these?" not "did I match exactly?".


Next phase: docs/phase-36-frontier-architectures/.