
Lab 01 — DPO on the grammar-tutor

Implementation of DPO in ~50 lines on the Phase 28 LoRA tutor. Measures the win rate against the SFT baseline on a held-out test set.

Goal

Implement Direct Preference Optimization in ~50 lines, fine-tune the Phase 28 LoRA grammar-tutor on the same 200-pair dataset from Lab 00, and measure win rate against the SFT baseline on a held-out 50-pair test set. Target: > 55% win rate.

Why this is short

DPO has no reward model, no rollouts, no value function, and no advantage estimation (theory chapter 04). The training loop has the same structure as SFT; only the loss differs, plus an extra forward pass through the frozen reference model. Once you have the data loader from Lab 00, the new code is the ~50-line loss module plus short training and eval scripts.

The loss

From theory chapter 04:

\[ \mathcal{L}_{\text{DPO}}(\theta) = -\,\mathbb{E}_{(x, y_w, y_l)}\!\left[ \log \sigma\!\left( \beta\,\Delta_w - \beta\,\Delta_l \right) \right] \]

where the per-pair log-ratio differences are

\[ \Delta_w = \log \pi_\theta(y_w \mid x) - \log \pi_{\text{ref}}(y_w \mid x), \qquad \Delta_l = \log \pi_\theta(y_l \mid x) - \log \pi_{\text{ref}}(y_l \mid x). \]

The hyperparameter \(\beta = 0.1\) for this lab.
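A worked example may help pin down the signs. The pure-Python sketch below evaluates the loss for a single pair from four sequence log-probs; dpo_pair_loss and log_sigmoid are hypothetical helpers, and the log-prob values are made up for illustration.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigma(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    # For negative z: log(e^z / (1 + e^z)) = z - log(1 + e^z).
    return z - math.log1p(math.exp(z))


def dpo_pair_loss(pi_w: float, pi_l: float, ref_w: float, ref_l: float,
                  beta: float = 0.1) -> float:
    """DPO loss for one (x, y_w, y_l) triple, given four sequence log-probs."""
    delta_w = pi_w - ref_w              # Δ_w
    delta_l = pi_l - ref_l              # Δ_l
    return -log_sigmoid(beta * (delta_w - delta_l))


# Policy == reference (start of training): Δ_w = Δ_l = 0, so the loss is log 2.
print(f"{dpo_pair_loss(-42.0, -40.0, -42.0, -40.0):.4f}")  # 0.6931
# Policy raised y_w by 3 nats and lowered y_l by 2 nats vs. the reference:
# beta * (3 - (-2)) = 0.5, so the loss is -log sigma(0.5).
print(f"{dpo_pair_loss(-39.0, -42.0, -42.0, -40.0):.4f}")  # 0.4741
```

In the second call the rejected response is still more likely than the chosen one under both models; the loss depends only on how each log-prob moved relative to the reference.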

What log π(y | x) means in code

For a sequence \(y = (y_1, \dots, y_T)\) and a prompt \(x\):

\[ \log \pi_\theta(y \mid x) = \sum_{t=1}^{T} \log \pi_\theta(y_t \mid x, y_{<t}) \]

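i.e., the sum of the per-token log-probs of the response tokens only (not the prompt tokens). The implementation sums log-probs under a response mask; because the logits at position \(t\) predict token \(t+1\), the mask is shifted by one to line up with the targets.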
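A small pure-Python sketch of that shift, using made-up per-token log-probs for a 3-token prompt and a 2-token response. It mirrors the [:, 1:] slicing in the implementation below and shows what the common off-by-one bug computes; response_mask and response_logprob are hypothetical helpers.

```python
from typing import List


def response_mask(prompt_len: int, total_len: int) -> List[int]:
    """1 for response positions, 0 for prompt positions (length T)."""
    return [0] * prompt_len + [1] * (total_len - prompt_len)


def response_logprob(token_logp: List[float], mask: List[int]) -> float:
    """Sum the response log-probs.

    token_logp[t] = log p(token t+1 | tokens 0..t), so it has length T-1 and
    its target is token t+1: it must be paired with mask[t+1], i.e. mask[1:].
    """
    assert len(token_logp) == len(mask) - 1
    return sum(lp * m for lp, m in zip(token_logp, mask[1:]))


# Made-up values: 3 prompt tokens + 2 response tokens, so T = 5.
token_logp = [-1.0, -2.0, -0.5, -0.25]
mask = response_mask(prompt_len=3, total_len=5)      # [0, 0, 0, 1, 1]
print(response_logprob(token_logp, mask))            # -0.75: targets 3 and 4
# The off-by-one bug pairs token_logp[t] with mask[t] instead of mask[t+1]:
print(sum(lp * m for lp, m in zip(token_logp, mask[:-1])))  # -0.25: drops target 3
```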

Implementation

# src/x3_rlhf/dpo.py
# DPO loss + sequence log-prob utility (~50 lines). The training loop lives in scripts/train_dpo.py.

from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F


def sequence_logprob(
    model: nn.Module,
    input_ids: torch.Tensor,        # (B, T)
    response_mask: torch.Tensor,    # (B, T), 1 for response tokens, 0 elsewhere
) -> torch.Tensor:
    """Sum of log-probs over response tokens. Shape: (B,)."""
    logits = model(input_ids).logits[:, :-1, :]            # (B, T-1, V)
    targets = input_ids[:, 1:]                              # (B, T-1)
    logp = F.log_softmax(logits, dim=-1)
    token_logp = logp.gather(2, targets.unsqueeze(-1)).squeeze(-1)  # (B, T-1)
    mask = response_mask[:, 1:].float()
    return (token_logp * mask).sum(dim=-1)                  # (B,)


def dpo_loss(
    policy: nn.Module,
    reference: nn.Module,
    batch: dict[str, torch.Tensor],
    beta: float = 0.1,
) -> tuple[torch.Tensor, dict[str, float]]:
    """DPO loss on (chosen, rejected) pairs. Reference is frozen."""
    pi_w  = sequence_logprob(policy,    batch["chosen_ids"],   batch["chosen_mask"])
    pi_l  = sequence_logprob(policy,    batch["rejected_ids"], batch["rejected_mask"])
    with torch.no_grad():
        ref_w = sequence_logprob(reference, batch["chosen_ids"],   batch["chosen_mask"])
        ref_l = sequence_logprob(reference, batch["rejected_ids"], batch["rejected_mask"])

    delta_w = pi_w - ref_w     # implicit reward for chosen / β
    delta_l = pi_l - ref_l     # implicit reward for rejected / β
    logits  = beta * (delta_w - delta_l)
    loss    = -F.logsigmoid(logits).mean()

    # Diagnostics
    metrics = {
        "loss":           loss.item(),
        "implicit_r_w":   (beta * delta_w).mean().item(),
        "implicit_r_l":   (beta * delta_l).mean().item(),
        "pair_accuracy":  (logits > 0).float().mean().item(),
        # Mean log-ratios Δ_w and Δ_l. Logged as "kl_*", but they are only a
        # KL proxy: a true KL would average over samples from the policy.
        "kl_chosen":      delta_w.mean().item(),
        "kl_rejected":    delta_l.mean().item(),
    }
    return loss, metrics

Training script (~30 more lines, in scripts/train_dpo.py):

import copy, torch
from x3_rlhf.dpo import dpo_loss
from lynx_cortex.utils import seed_everything, save_manifest
from lynx_cortex.phase17 import load_sft_model
from lynx_cortex.phase28 import attach_lora
from lynx_cortex.data import load_pref_yaml, make_loader

seed_everything(42)

# Reference = frozen SFT model. Policy = SFT + trainable LoRA.
reference = load_sft_model("checkpoints/phase28-lora.pt").eval()
for p in reference.parameters():
    p.requires_grad = False

policy = attach_lora(load_sft_model("checkpoints/phase28-lora.pt"),
                     r=8, alpha=16, dropout=0.05)
trainable = [p for p in policy.parameters() if p.requires_grad]

train_pairs, eval_pairs = load_pref_yaml(
    "data/preferences/grammar_v0.yaml", split=(150, 50)
)
train_loader = make_loader(train_pairs, batch_size=8, shuffle=True)

opt = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=0.0)

policy.train()  # LoRA dropout on for the policy; the reference stays in eval mode
for epoch in range(3):
    for batch in train_loader:
        loss, m = dpo_loss(policy, reference, batch, beta=0.1)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
        opt.step()
        print(f"loss={m['loss']:.4f} acc={m['pair_accuracy']:.3f} "
              f"kl_w={m['kl_chosen']:+.3f} kl_l={m['kl_rejected']:+.3f}")

torch.save({k: v for k, v in policy.state_dict().items() if "lora" in k},
           "checkpoints/dpo-lora-v0.pt")
save_manifest("experiments/2026-05-23-dpo-v0/manifest.json",
              {"seed": 42, "beta": 0.1, "epochs": 3, "lr": 1e-4})

That is the entire DPO training loop. Compare to what PPO-for-language requires (theory chapter 03).
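The loop prints every step, while the DoD asks for diagnostics saved per epoch. One way to get there, sketched in plain Python under the assumption that dpo_loss returns batch-mean metrics as above, is to accumulate them with batch-size weights and write one record per epoch to JSON. EpochMeter, write_history, and the history.json file name are hypothetical, not part of lynx_cortex or x3_rlhf.

```python
import json
from collections import defaultdict
from typing import Dict, List


class EpochMeter:
    """Batch-size-weighted means of the metrics dict returned by dpo_loss (hypothetical helper)."""

    def __init__(self) -> None:
        self.sums: Dict[str, float] = defaultdict(float)
        self.n = 0

    def update(self, metrics: Dict[str, float], batch_size: int) -> None:
        # dpo_loss reports batch means, so weight each one by its batch size;
        # otherwise a short final batch (150 pairs at batch size 8 may leave 6)
        # would count as much as a full one.
        for key, value in metrics.items():
            self.sums[key] += value * batch_size
        self.n += batch_size

    def summary(self) -> Dict[str, float]:
        return {key: total / self.n for key, total in self.sums.items()}


def write_history(path: str, config: Dict, history: List[Dict[str, float]]) -> None:
    """Write the run config plus one record per epoch as JSON (hypothetical layout)."""
    with open(path, "w") as f:
        json.dump({"config": config, "epochs": history}, f, indent=2)


# Intended use inside train_dpo.py (comments only; needs the lab's objects):
# history = []
# for epoch in range(3):
#     meter = EpochMeter()
#     for batch in train_loader:
#         loss, m = dpo_loss(policy, reference, batch, beta=0.1)
#         # ... zero_grad, backward, clip, step as before ...
#         meter.update(m, batch_size=batch["chosen_ids"].shape[0])
#     history.append({"epoch": epoch, **meter.summary()})
# write_history("experiments/2026-05-23-dpo-v0/history.json",
#               {"seed": 42, "beta": 0.1, "lr": 1e-4}, history)
```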

Evaluation: win rate vs SFT baseline

For each of the 50 held-out prompts, generate one greedy response from the SFT baseline and one from the DPO model. A judge (here, the Lab 00 reward model) scores both; report the fraction of prompts where the DPO response scores strictly higher (ties count as losses).

# scripts/eval_winrate.py
import torch
from x3_rlhf.reward_model import RewardModel
from lynx_cortex.phase17 import load_sft_model, generate
from lynx_cortex.phase28 import attach_lora, load_lora
from lynx_cortex.data import load_pref_yaml

sft     = load_sft_model("checkpoints/phase28-lora.pt").eval()
dpo     = attach_lora(load_sft_model("checkpoints/phase28-lora.pt"),
                      r=8, alpha=16, dropout=0.05)  # same LoRA config as training
load_lora(dpo, "checkpoints/dpo-lora-v0.pt")
dpo.eval()

rm = RewardModel(load_sft_model("checkpoints/phase28-lora.pt"))
rm.head.load_state_dict(torch.load("checkpoints/rm-v0.pt"))
rm.eval()

_, eval_pairs = load_pref_yaml("data/preferences/grammar_v0.yaml", split=(150, 50))

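wins, ties = 0, 0
with torch.no_grad():  # inference only: no autograd graphs during generation or scoring
    for pair in eval_pairs:
        x = pair["prompt"]
        y_sft = generate(sft, x, max_new=32, temperature=0.0)  # greedy
        y_dpo = generate(dpo, x, max_new=32, temperature=0.0)
        r_sft = rm.score(x, y_sft)
        r_dpo = rm.score(x, y_dpo)
        wins += int(r_dpo > r_sft)   # strict win; a tie counts as a loss
        ties += int(r_dpo == r_sft)

print(f"DPO win rate: {wins / len(eval_pairs):.3f} (ties: {ties})")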
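The script above counts a tie as a loss. With greedy decoding the two models can produce identical responses, which the judge then scores identically, so it can be worth reporting ties separately. The hypothetical win_rate helper below, in plain Python, reports the strict rate, a ties-as-half variant, and a normal-approximation standard error; it assumes the rm.score outputs have been converted to floats.

```python
import math
from typing import Dict, List, Tuple


def win_rate(scores: List[Tuple[float, float]]) -> Dict[str, float]:
    """Summarize judge scores: one (r_dpo, r_sft) pair per held-out prompt."""
    n = len(scores)
    wins = sum(1 for r_dpo, r_sft in scores if r_dpo > r_sft)
    ties = sum(1 for r_dpo, r_sft in scores if r_dpo == r_sft)
    strict = wins / n                   # ties count as losses, as in eval_winrate.py
    tie_half = (wins + 0.5 * ties) / n  # ties count as half a win
    # Normal-approximation standard error of the strict rate.
    se = math.sqrt(strict * (1.0 - strict) / n)
    return {"n": n, "wins": wins, "ties": ties,
            "strict": strict, "tie_half": tie_half, "se": se}


# Made-up scores for four prompts: two wins, one loss, one tie.
print(win_rate([(1.2, 0.8), (0.3, 0.9), (0.5, 0.5), (2.0, 1.0)]))
# wins=2, ties=1, strict=0.5, tie_half=0.625, se=0.25
```

At n = 50 the standard error near a 0.5 rate is about \(\sqrt{0.25/50} \approx 0.07\), so a result close to the 0.55 target is weak evidence on its own.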

Expected results

| Metric | Target | Notes |
|---|---|---|
| DPO train loss after 3 epochs | < 0.50 | Starts at \(\log 2 \approx 0.693\) |
| Train pair accuracy (final epoch) | > 0.85 | |
| Mean log-ratio on chosen, \(\Delta_w\) (logged as kl_chosen) | -0.5 to +1.5 | Ideally positive: the policy raises the chosen response's likelihood relative to the reference |
| Mean log-ratio on rejected, \(\Delta_l\) (logged as kl_rejected) | -2.0 to -0.5 | Should be negative: the policy lowers the rejected response's likelihood |
| DPO win rate vs SFT (held-out 50 pairs) | > 0.55 | The headline metric |
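As a consistency check, the loss target can be turned into an implied margin. Write \(m = \beta(\Delta_w - \Delta_l)\) and \(f(m) = -\log\sigma(m) = \log(1 + e^{-m})\). Since \(f\) is convex and decreasing, Jensen's inequality gives \(f(\mathbb{E}[m]) \le \mathbb{E}[f(m)]\), so a mean loss below 0.50 forces a minimum mean margin:

\[ \mathbb{E}[f(m)] < 0.5 \;\Rightarrow\; \mathbb{E}[m] > f^{-1}(0.5) = -\log\!\left(e^{0.5} - 1\right) \approx 0.433 \;\Rightarrow\; \mathbb{E}[\Delta_w] - \mathbb{E}[\Delta_l] > \frac{0.433}{\beta} \approx 4.33 \text{ nats at } \beta = 0.1. \]

The \(\Delta_w\) and \(\Delta_l\) ranges in the table allow a gap of at most \(1.5 - (-2.0) = 3.5\) nats, so if all three rows are measured as means over the same final-epoch batches, they cannot all be met together; check which ones your run satisfies before reading the others as failures.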

If DPO win rate is ≤ 0.50, debug:

  1. Did the loss actually decrease? (Common bug: the response mask is off by one relative to the shifted targets.)
  2. Are \(\Delta_w\) and \(\Delta_l\) moving in the right directions (\(\Delta_w\) up or flat, \(\Delta_l\) down)?
  3. Is the reference model actually frozen and in eval mode?

What to look at

  1. Loss curve. Should drop from \(\log 2 \approx 0.69\) (the policy starts equal to the reference, so every \(\Delta\) is zero) to roughly 0.3 to 0.4 by epoch 3. Flat or rising means a bug.
  2. Implicit reward gap \(\beta(\Delta_w - \Delta_l)\). This is the Bradley–Terry (BT) logit inside the loss. It should rise; if it saturates fast, lower \(\beta\) or check for label noise.
  3. KL trajectory. Track \(\mathbb{E}[\Delta_w]\) and \(\mathbb{E}[\Delta_l]\) jointly. If both drift far negative, the policy is moving probability mass away from both responses, a sign it is forgetting language; raise \(\beta\) to keep it closer to the reference.
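The checks above can be run over per-epoch means in plain Python, as sketched below. The sketch assumes a history of dicts keyed by the dpo_loss metric names (for example, each epoch's metrics averaged over its batches); diagnose, min_drop, and drift_floor are hypothetical, and the thresholds are placeholders because the lab does not define "flat" or "far negative" numerically. Whether the gap saturates too fast is left to the plotted curve.

```python
from typing import Dict, List


def diagnose(history: List[Dict[str, float]],
             min_drop: float = 0.01,
             drift_floor: float = -2.0) -> List[str]:
    """Flag the failure patterns from 'What to look at' (needs >= 2 epochs).

    min_drop and drift_floor are placeholder thresholds, not lab values.
    """
    if len(history) < 2:
        raise ValueError("need at least two epochs of history")
    warnings = []
    first, last = history[0], history[-1]
    # 1. Loss curve: should fall from about log 2; flat or rising suggests a bug.
    if last["loss"] > first["loss"] - min_drop:
        warnings.append("loss flat or rising: check mask alignment and the frozen reference")
    # 2. Implicit reward gap beta*(Δ_w - Δ_l) = implicit_r_w - implicit_r_l: should rise.
    gap_first = first["implicit_r_w"] - first["implicit_r_l"]
    gap_last = last["implicit_r_w"] - last["implicit_r_l"]
    if gap_last <= gap_first:
        warnings.append("implicit reward gap not rising")
    # 3. Both mean log-ratios far negative: mass is leaving both responses.
    if last["kl_chosen"] < drift_floor and last["kl_rejected"] < drift_floor:
        warnings.append("both log-ratios far negative: consider raising beta")
    return warnings
```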

Things to break

  1. Set \(\beta = 0\). The loss becomes \(-\log\sigma(0) = \log 2\) regardless of the policy; gradient is zero. Verify training stalls.
  2. Set \(\beta = 10\). Sigmoid saturates almost immediately; gradient vanishes for easy pairs; learning stalls in a different way.
  3. Swap chosen/rejected in 10% of training pairs (label noise). Watch train accuracy plateau lower; DPO is more sensitive to label noise than SFT because it relies on the ranking.
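  4. Use the policy itself as the reference (\(\pi_{\text{ref}} \leftarrow \pi_\theta\), a detached copy refreshed at every step). Both \(\Delta\) terms become zero, so the loss value pins at \(\log 2\); the gradient, however, does not vanish. It reduces to \(-\tfrac{\beta}{2}\nabla_\theta\big(\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x)\big)\), a fixed-weight likelihood-margin objective with no anchor to the SFT model. Comparing this run with the normal one isolates what the frozen reference contributes.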
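Items 1, 2, and 4 all follow from the per-pair gradient. Since the reference log-probs carry no gradient, for a single pair

\[ \nabla_\theta \mathcal{L} = -\beta\,\sigma\!\big(-\beta(\Delta_w - \Delta_l)\big)\,\big(\nabla_\theta \log \pi_\theta(y_w \mid x) - \nabla_\theta \log \pi_\theta(y_l \mid x)\big), \]

so everything hinges on the scalar weight in front. The pure-Python sketch below evaluates that weight for each scenario; pair_weight is a hypothetical helper and the margins are made up.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def pair_weight(delta_w: float, delta_l: float, beta: float) -> float:
    """Scalar weight beta * sigma(-beta * (Δ_w - Δ_l)) on the per-pair gradient.

    The pair's gradient is -weight * (grad log pi(y_w|x) - grad log pi(y_l|x)).
    """
    return beta * sigmoid(-beta * (delta_w - delta_l))


# Item 1, beta = 0: the weight is zero for every pair, so nothing trains.
print(f"{pair_weight(1.0, -1.0, beta=0.0):.4f}")   # 0.0000
# Item 2, beta = 10: an easy pair (margin +1 nat) gets almost no weight,
# while a mis-ranked pair (margin -1 nat) gets nearly the full beta.
print(f"{pair_weight(1.0, 0.0, beta=10.0):.2e}")   # 4.54e-04
print(f"{pair_weight(0.0, 1.0, beta=10.0):.4f}")   # 9.9995
# Item 4, self-reference: Δ_w = Δ_l = 0, yet the weight is beta / 2, not zero.
print(f"{pair_weight(0.0, 0.0, beta=0.1):.4f}")    # 0.0500
```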

DoD

  • train_dpo.py runs end-to-end on CPU in < 15 minutes.
  • Mean loss per epoch decreases monotonically (per-step minibatch losses are noisy).
  • DPO win rate vs SFT > 0.55.
  • All four diagnostics (\(\Delta_w\), \(\Delta_l\), pair accuracy, loss) saved per epoch.
  • Manifest persisted with seed, \(\beta\), lr, win rate.
  • Two "break it" exercises attempted with one-paragraph writeups.