
03 — torch.compile and Distributed (Survey)

This page covers two topics that would each deserve a full phase but get an honest survey in Phase 25: how torch.compile works (Dynamo → AOTAutograd → Inductor) and the four canonical distributed patterns (DDP, FSDP, tensor-parallel, pipeline-parallel). Compile is hands-on; distributed is reading only.

The Phase 25 lab will use torch.compile (run it, dump the Inductor output, identify the fusion); it won't build a compile pipeline. For distributed, the lab writes a 1-page README distinguishing the four patterns, with no torch.distributed code beyond an init_process_group hello-world.

The full hands-on treatment of compile comes in Phase 33 (serving), and of distributed in Phase 35.


Part A: torch.compile

What it is

torch.compile(model) returns an optimized wrapper around the model. Compilation is lazy. The first call traces the forward (and, when training, the backward) into a graph and optimizes it (fusion, layout choice, kernel selection). It then emits Triton (for GPU) or C++ (for CPU) kernels and runs them in place of the per-op dispatch sequence. Later calls with compatible inputs reuse the compiled kernels.

Schematically:

model(x)                                 # eager: one dispatcher call per op, on every forward
model_c = torch.compile(model)
model_c(x)
[TorchDynamo]      Python bytecode → FX graph
[AOTAutograd]      FX graph → joint forward+backward FX graph
[Inductor]         FX graph → Triton/C++ kernel files
[runtime]          Loads the compiled kernels; runs them in place of eager

After the first call (compile time: seconds to minutes), subsequent calls run the compiled kernels. Speedups vary widely with model and hardware. As a rough guide, 1.5–3× over eager is common for inference and 1.2–2× for training.

Stage 1: TorchDynamo (Python → FX)

Dynamo is a Python-bytecode tracer. It hooks CPython's frame evaluation (PEP 523) and symbolically executes your model's forward() bytecode. The inputs are fake tensors that carry only shape/dtype/device metadata. Every torch operation is recorded into an FX graph (PyTorch's intermediate representation).
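Stage 1 can be seen in isolation by passing torch.compile a custom backend. This is a minimal sketch, assuming PyTorch 2.x; `recording_backend` and `mlp_block` are hypothetical names. Dynamo calls a backend once per captured FX graph, passing the graph and example inputs. Returning `gm.forward` runs the graph unoptimized, so the only effect is that we can inspect what Dynamo recorded and when.

```python
import torch

captured = []  # hypothetical list: one entry per FX graph Dynamo hands to the backend

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical backend: keep the captured graph, then run it as-is (no optimization).
    captured.append(gm)
    return gm.forward

def mlp_block(x, w):  # hypothetical toy forward
    return torch.relu(x @ w) * 2.0

compiled = torch.compile(mlp_block, backend=recording_backend)

x, w = torch.randn(4, 8), torch.randn(8, 8)
out1 = compiled(x, w)  # first call: Dynamo traces the bytecode and invokes the backend
out2 = compiled(x, w)  # same shapes/dtypes: guards pass, the cached graph is reused

print(len(captured))     # number of graphs Dynamo captured
print(captured[0].code)  # Python source generated for the FX graph
```

The second call does not reach the backend because Dynamo's guards (on shapes, dtypes and similar properties) still hold. An input with a different shape would typically trigger a recompile.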

If Dynamo can't trace a part of the code (e.g., a data-dependent Python branch on a tensor value), it inserts a graph break: emits one graph for the prefix, runs the offending Python in eager mode, then traces the suffix. Graph breaks reduce optimization opportunities.

Common graph-break causes:

  • if x.sum() > 0: (tensor → Python control flow).
  • Calls into non-traceable libraries.
  • print(...) with a tensor argument.
  • Mutating Python data structures.

Verbose graph-break diagnosis: TORCHDYNAMO_VERBOSE=1 python script.py, or TORCH_LOGS=graph_breaks python script.py to log each break with its reason.

Stage 2: AOTAutograd (joint forward+backward FX)

The forward FX graph from Dynamo is fed to AOTAutograd, which:

  • Traces the backward pass symbolically (just like the autograd engine would at runtime, but ahead-of-time).
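  • Produces a joint FX graph with both forward and backward nodes. It then partitions that graph into separate forward and backward graphs, deciding which intermediate activations to save for backward and which to recompute.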
  • Decomposes high-level ops (like linear) into their primitive ops (matmul, add) for finer-grained Inductor optimization.

For training, this means the backward pass is also compiled ahead of time. Inductor receives the backward graph and can fuse its kernels, instead of the eager autograd engine executing the backward op by op.
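AOTAutograd can be exercised directly through `functorch.compile.aot_function`, which takes separate forward and backward "compiler" callbacks. This sketch records the two partitioned graphs and returns them unchanged, wrapped with `make_boxed_func`, so they run as plain PyTorch. The callback factory `make_compiler` and the toy function `f` are hypothetical. Printing each graph's `.code` shows the forward graph and the separate backward graph AOTAutograd derived.

```python
import torch
from functorch.compile import aot_function, make_boxed_func

compiled_graphs = {}  # hypothetical store: "forward"/"backward" -> FX GraphModule

def make_compiler(name):
    # Hypothetical compiler callback: record the graph, run it without optimization.
    def compiler(fx_module: torch.fx.GraphModule, example_inputs):
        compiled_graphs[name] = fx_module
        return make_boxed_func(fx_module.forward)
    return compiler

def f(x):  # hypothetical toy forward
    return torch.sin(x) * 2

aot_f = aot_function(f, fw_compiler=make_compiler("forward"),
                     bw_compiler=make_compiler("backward"))

x = torch.randn(5, requires_grad=True)
aot_f(x).sum().backward()  # both graphs have been compiled by the end of backward

print(compiled_graphs["forward"].code)
print(compiled_graphs["backward"].code)
```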

Stage 3: Inductor (FX → Triton/C++)

Inductor is PyTorch's compiler backend. It takes the FX graph and emits actual kernel code:

  • CUDA path: emits Triton kernels (the same Triton language from Phase 24).
  • CPU path: emits C++ with OpenMP / vectorized intrinsics.

Inductor does:

  • Fusion: combines adjacent elementwise ops + reductions into one kernel.
  • Layout selection: chooses memory layout (contiguous vs strided) per op.
  • Loop tiling: picks tile sizes for SMEM / cache.
  • Autotuning: optionally sweeps tile sizes (mode="max-autotune").

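The generated code is cached under /tmp/torchinductor_<user>/ by default, in hash-named subdirectories. With TORCH_LOGS=output_code, Inductor logs the generated module(s) to stderr. Reading this is enlightening. On GPU it is a Python file containing @triton.jit kernels plus a small wrapper that allocates buffers and launches them, comparable to the kernel Borja wrote by hand in Phase 24.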

Compile modes

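import torch
model = torch.nn.Linear(64, 600)                                  # stand-in model
model_default = torch.compile(model, mode="default")              # balanced compile time vs speed
model_low_lat = torch.compile(model, mode="reduce-overhead")      # CUDA Graphs cut launch overhead; lowest latency
model_tuned = torch.compile(model, mode="max-autotune")           # benchmarks many kernel configs; slow compile, usually fastest run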

For grammar MiniGPT inference: reduce-overhead is usually right. Lab 03 tries each.

What Inductor generates for nn.Linear + softmax

In eager mode you'd write:

lm_head = nn.Linear(64, 600)
y = lm_head(x)         # eager: one GEMM (cuBLAS on GPU) plus bias
p = F.softmax(y, -1)   # eager: PyTorch's built-in softmax kernel

After torch.compile, Inductor might emit:

  • One GEMM for the matmul, usually dispatched to cuBLAS (Inductor typically doesn't fuse the matmul with the elementwise ops that follow).
  • One fused Triton kernel for the softmax (max + exp + sum + normalize all in one launch).

For the lab, dump this with TORCH_LOGS=output_code and identify the softmax fusion in the Inductor source. The kernel is ~30 lines of Triton — directly comparable to Borja's hand-written one from Phase 24.

When compile helps and when it doesn't

| Case | Typical speedup (rough) |
|---|---|
| Many small elementwise ops between matmuls | 2–5× (fusion eliminates intermediate tensors) |
| One big matmul dominates time | ~1× (cuBLAS is already near-optimal) |
| Graph breaks every few ops | Marginal (each small graph has little to fuse) |
| Model with torch.jit.script already applied | Possibly negative (compile re-traces, may regress) |
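The first row of the table can be estimated with back-of-the-envelope arithmetic. The hypothetical helper `hbm_bytes` assumes each op is unary, reads its whole input from memory and writes its whole output, and ignores caches and launch overhead. Under those assumptions an eager chain moves `n_ops` times more bytes than one fused kernel.

```python
def hbm_bytes(numel: int, n_ops: int, dtype_bytes: int = 4, fused: bool = False) -> int:
    """Hypothetical estimate of memory traffic for a chain of unary elementwise ops.

    Eager: every op reads its input and writes a fresh intermediate.
    Fused: one kernel reads the input once and writes the final output once.
    """
    tensor_bytes = numel * dtype_bytes
    if fused:
        return 2 * tensor_bytes
    return 2 * tensor_bytes * n_ops

numel = 4096 * 4096  # one fp32 activation tensor
for n in (1, 4, 8):
    eager = hbm_bytes(numel, n)
    fused = hbm_bytes(numel, n, fused=True)
    print(f"{n} ops: eager {eager / 2**20:.0f} MiB, fused {fused / 2**20:.0f} MiB")
```

For memory-bound elementwise work, runtime roughly tracks bytes moved, which is why fusion helps most in that regime.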

Phase 33's serving lab will measure compile gains on the full grammar MiniGPT.

Part B: Distributed (Survey)

This is a concepts survey. Phase 35 builds these for real. Here we name them, describe them, and place them along two axes: what gets split, and how communication scales.

DDP: Data-Parallel

GPU 0:  model (full copy) + batch slice 0
GPU 1:  model (full copy) + batch slice 1
GPU 2:  model (full copy) + batch slice 2
GPU 3:  model (full copy) + batch slice 3

Each GPU holds a full copy of the model. Different GPUs see different batch slices. Forward and backward are independent per GPU; after backward, gradients are all-reduced across GPUs (averaged) so all copies stay in sync.

PyTorch: torch.nn.parallel.DistributedDataParallel(model).

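Communication: one all-reduce of the gradients per step, i.e. O(model size) per GPU. With a ring all-reduce, each GPU sends and receives about 2(N−1)/N times the gradient size, which is nearly independent of the GPU count N. In practice DDP's ceiling is memory rather than GPU count, because every GPU must hold the full model, its gradients and its optimizer state.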
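The core DDP invariant can be checked without any GPUs. With a mean-reduced loss and equal-sized batch slices, averaging the per-replica gradients (what the all-reduce does) reproduces the full-batch gradient. This single-process sketch simulates 4 ranks with copied replicas. It does not use torch.distributed or DistributedDataParallel, which additionally bucket gradients and overlap the all-reduce with the backward pass.

```python
import copy
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
x, y = torch.randn(8, 10), torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()  # mean over the batch

# Simulated ranks: identical replicas, created before any gradients exist.
world_size = 4
replicas = [copy.deepcopy(model) for _ in range(world_size)]

# Reference: a single device sees the whole batch.
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Each "rank" runs forward/backward on its own equal-sized slice.
for rank, (xs, ys) in enumerate(zip(x.chunk(world_size), y.chunk(world_size))):
    loss_fn(replicas[rank](xs), ys).backward()

# The all-reduce step: average the gradients across ranks.
avg_grad = torch.stack([r.weight.grad for r in replicas]).mean(dim=0)
print(torch.allclose(avg_grad, full_grad, atol=1e-6))
```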

FSDP: Fully-Sharded Data-Parallel

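Same as DDP, but each GPU permanently holds only a shard of each layer's parameters, and of their gradients and optimizer state. Before a layer runs, its full parameters are reassembled from every GPU's shard (all-gather). After the layer, the gathered copy is freed and only the local shard remains. In backward, the parameters are gathered again and the gradients are reduce-scattered, so each GPU keeps only the gradient for its own shard.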

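  • Reduces per-GPU memory for parameters, gradients and optimizer state by roughly N× (where N is the world size); activations are not sharded.
  • Increases communication (all-gathers plus a reduce-scatter per layer vs one all-reduce per step).
  • Suited to models whose full training state doesn't fit on one GPU but whose individual layers do.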

PyTorch: torch.distributed.fsdp.FullyShardedDataParallel.
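The N× memory claim can be made concrete with arithmetic. The hypothetical helper below assumes 16 bytes of training state per parameter: fp32 weights (4) and gradients (4) plus Adam's two fp32 moment buffers (8). The 7B-parameter model size is chosen only for illustration. Activations and FSDP's temporarily gathered layer are left out, so real usage is higher.

```python
def per_gpu_state_gb(n_params: float, world_size: int, sharded: bool,
                     bytes_per_param: int = 16) -> float:
    """Hypothetical estimate: params + grads + optimizer state per GPU, in GB.

    bytes_per_param=16 assumes fp32 weights, fp32 grads, and two fp32 Adam moments.
    Activations and FSDP's transiently all-gathered layer are NOT counted.
    """
    total_gb = n_params * bytes_per_param / 1e9
    return total_gb / world_size if sharded else total_gb

for n in (1, 8, 64):
    ddp = per_gpu_state_gb(7e9, n, sharded=False)   # every GPU holds everything
    fsdp = per_gpu_state_gb(7e9, n, sharded=True)   # every GPU holds 1/n of it
    print(f"7B params, {n:>2} GPUs: DDP {ddp:.0f} GB/GPU, FSDP {fsdp:.1f} GB/GPU")
```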

Tensor-Parallel (TP)

A single matmul is split across GPUs. For \(Y = X W\) with \(W\) a 4096×4096 matrix on 2 GPUs:

GPU 0:  W_left  (4096 × 2048)
GPU 1:  W_right (4096 × 2048)

X is broadcast to both GPUs.
Y_left  = X @ W_left   on GPU 0   (output shape (B, 2048))
Y_right = X @ W_right  on GPU 1   (output shape (B, 2048))
Y = concat(Y_left, Y_right) along the last dim   (output shape (B, 4096))

For attention: split heads across GPUs. For FFN: split the hidden dim.

Communication: per-layer (each forward through a TP-split layer requires an allreduce or allgather of the partial outputs).
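The split can be simulated on one device to check the algebra. This sketch uses a small two-matmul FFN; the sizes are arbitrary stand-ins for 4096. It follows the Megatron-style pairing. The first weight is split by columns, so each rank computes part of the hidden dimension and applies the elementwise ReLU shard-locally. The second weight is split by rows, so the partial outputs need only one sum (an all-reduce) at the end. The last check shows the column-only case from the diagram above, where concatenation plays the role of an all-gather.

```python
import torch

torch.manual_seed(0)
B, d, h = 4, 8, 16     # batch, model dim, FFN hidden dim (toy sizes)
X = torch.randn(B, d)
W1 = torch.randn(d, h)  # first FFN matmul
W2 = torch.randn(h, d)  # second FFN matmul
reference = torch.relu(X @ W1) @ W2

tp = 2  # simulated GPUs
W1_shards = W1.chunk(tp, dim=1)  # column-parallel: each rank owns h/tp hidden units
W2_shards = W2.chunk(tp, dim=0)  # row-parallel: matching rows of the second weight

partials = []
for rank in range(tp):
    H_local = torch.relu(X @ W1_shards[rank])   # no communication needed yet
    partials.append(H_local @ W2_shards[rank])  # partial sum of the final output

Y = torch.stack(partials).sum(dim=0)  # the single all-reduce
print(torch.allclose(Y, reference, atol=1e-5))

# Column-parallel alone: concatenating shard outputs (an all-gather) rebuilds X @ W1.
print(torch.allclose(torch.cat([X @ s for s in W1_shards], dim=1), X @ W1, atol=1e-5))
```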

Library: Megatron-LM, vLLM, or hand-written. PyTorch's own torch.distributed.tensor.parallel API is comparatively young and still maturing.

Pipeline-Parallel (PP)

Split the model depth. Layers 1–8 on GPU 0; layers 9–16 on GPU 1; layers 17–24 on GPU 2; layers 25–32 on GPU 3.

For one forward, GPU 0 runs layers 1–8 and passes the activations to GPU 1, which runs layers 9–16, and so on. Naive PP underutilizes GPUs, since only one is active at a time. The fix is micro-batching: split a batch into K micro-batches so GPU 0 can process micro-batch 2 while GPU 1 processes micro-batch 1, and so on. The order in which stages process micro-batches is the pipeline schedule. The idle time that remains while the pipeline fills and drains is the pipeline "bubble".
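The bubble can be quantified for the simplest schedule. Assume a GPipe-style, forward-only schedule where every stage takes one time step per micro-batch; backward passes and 1F1B ordering are ignored. Then stage s works on micro-batch m at step s + m. Each stage sits idle for (stages - 1) of the (micro_batches + stages - 1) steps. The helper names are hypothetical.

```python
def gpipe_schedule(stages: int, micro_batches: int):
    """Hypothetical forward-only schedule: stage s runs micro-batch m at step s + m.

    Returns grid[stage][step] holding the micro-batch index, or None when idle.
    """
    steps = micro_batches + stages - 1
    grid = [[None] * steps for _ in range(stages)]
    for s in range(stages):
        for m in range(micro_batches):
            grid[s][s + m] = m
    return grid

def bubble_fraction(stages: int, micro_batches: int) -> float:
    # Each stage is idle for (stages - 1) of the (micro_batches + stages - 1) steps.
    return (stages - 1) / (micro_batches + stages - 1)

for row in gpipe_schedule(4, 4):
    print(" ".join("." if m is None else str(m) for m in row))
for k in (1, 4, 32):
    print(f"4 stages, {k} micro-batches: bubble {bubble_fraction(4, k):.2f}")
```

With 4 stages and one micro-batch, each GPU is idle 75% of the time. More micro-batches shrink the bubble, at the cost of smaller, less efficient per-micro-batch matmuls.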

Libraries: PyTorch's PiPPy (since upstreamed as torch.distributed.pipelining), DeepSpeed pipeline.

Choosing among the four

| Model fits on 1 GPU? | Pattern |
|---|---|
| Yes | DDP (simplest; scales as long as each GPU can hold a full replica) |
| No, but each layer fits | FSDP (shard params, grads, optimizer state) or PP (split the depth) |
| No, and a single layer doesn't fit | TP (split the matmuls), usually combined with FSDP and/or PP (3D parallelism) |
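The selection rule can be written as a tiny decision function that reasons from what each pattern splits. DDP needs a full replica per GPU. FSDP and PP still need each individual layer to fit: FSDP gathers one full layer at a time, and a PP stage holds whole layers. Only TP splits inside a layer. The helper and its inputs are hypothetical and ignore activations, so treat its answers as a first guess.

```python
def suggest_pattern(train_state_gb: float, largest_layer_gb: float, gpu_mem_gb: float) -> str:
    """Hypothetical rule of thumb for picking a parallelism pattern.

    train_state_gb:   params + grads + optimizer state for the whole model.
    largest_layer_gb: the same quantity for the single largest layer.
    Activations and fragmentation are ignored, so real limits are lower.
    """
    if train_state_gb <= gpu_mem_gb:
        return "DDP"                       # a full replica fits on every GPU
    if largest_layer_gb <= gpu_mem_gb:
        return "FSDP or PP"                # whole layers fit; shard state or split depth
    return "TP (usually combined with FSDP/PP)"  # must split inside a layer

print(suggest_pattern(10, 1, 80))
print(suggest_pattern(300, 5, 80))
print(suggest_pattern(300, 120, 80))
```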

For Phase 35's lab, the grammar MiniGPT fits on Borja's laptop, so distributed training is not motivated by need. Phase 35 uses a slightly larger model (or simulates multi-GPU with the gloo backend on CPU) to demonstrate the patterns.
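An init_process_group hello-world can run on a CPU-only machine with the gloo backend. This sketch uses a single process (world_size=1) and a file-based rendezvous in a temporary directory, so no TCP port needs to be free. With one rank, all_reduce leaves the tensor unchanged, but the calls are the same ones each process would make in a multi-process run.

```python
import os
import tempfile
import torch
import torch.distributed as dist

# File-based rendezvous in a fresh temp dir; the file must not be reused across runs.
init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
dist.init_process_group(backend="gloo", init_method=f"file://{init_file}",
                        rank=0, world_size=1)

t = torch.arange(4, dtype=torch.float32)
dist.all_reduce(t, op=dist.ReduceOp.SUM)  # sums across ranks; a no-op with one rank
print(dist.get_rank(), dist.get_world_size(), t)

dist.destroy_process_group()
```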

What you should now be able to do

  1. Run torch.compile(model) and dump the Inductor output.
  2. Read an Inductor-generated Triton kernel and identify the fusion.
  3. Distinguish DDP, FSDP, TP, PP — what's split, what's communicated.
  4. Predict which distributed pattern applies to a given model size and GPU count.
  5. Recognize the limitations: compile graph breaks, FSDP communication cost, TP latency overhead, PP bubble.

What this page does NOT cover

  • Compile failures and their fixes. Phase 33 dedicates time to debugging compile.
  • CUDA Graphs. Phase 33.
  • Pipeline scheduling algorithms (1F1B, interleaved 1F1B). Phase 35.
  • NCCL collective primitives in detail. Phase 35.
  • torch.func functional transforms. Phase 38 maybe.

Next: lab/00-dispatcher-trace.md — instrument a linear(x, W, b) call and read the dispatcher log.