01 — Supervised Fine-Tuning and Catastrophic Forgetting

SFT is just ordinary language modeling with labels. The non-trivial part is why updating all the weights can erase skills the model already had. The explanation is geometric: each gradient step moves the weights in some direction, and many directions that are useful to the original model are not aligned with the fine-tuning objective.


SFT: the easy half

Supervised fine-tuning is the simplest possible setup. Given:

  • A pretrained model f_θ(x) mapping inputs to next-token logits.
  • A dataset D = {(x_i, y_i)} of (prompt, completion) pairs.

The objective is standard cross-entropy:

\[ \mathcal{L}_{\text{SFT}}(\theta) = - \mathbb{E}_{(x, y) \sim D} \sum_t \log p_{f_\theta}(y_t \mid x, y_{<t}) \]
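
As a minimal sketch of the objective above: the sum over t covers the completion tokens y_t only, with the prompt x acting as context. The function name `sft_loss`, the mask argument, and the normalisation by token count (a mean rather than a raw sum) are illustrative assumptions, not part of the course code.

```python
import math

def sft_loss(token_logprobs, is_completion):
    """Mean negative log-likelihood over completion tokens only.

    token_logprobs[t] is log p(y_t | x, y_<t) for the token at position t.
    is_completion[t] is True for completion tokens, False for prompt tokens.
    Assumption: we average over completion tokens instead of summing.
    """
    if len(token_logprobs) != len(is_completion):
        raise ValueError("token_logprobs and is_completion must have equal length")
    picked = [lp for lp, keep in zip(token_logprobs, is_completion) if keep]
    if not picked:
        raise ValueError("no completion tokens to score")
    return -sum(picked) / len(picked)

# Toy sequence: two prompt tokens followed by two completion tokens.
logprobs = [math.log(0.9), math.log(0.5), math.log(0.25), math.log(0.5)]
mask = [False, False, True, True]
print(sft_loss(logprobs, mask))  # only the last two positions contribute
```

Masking the prompt positions keeps the model from being trained to reproduce the prompt itself; whether to average or sum is a convention that only rescales the gradient.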

Standard mini-batch gradient descent (AdamW), standard LR schedule (warmup + cosine decay), standard loss tracking. Mechanically identical to pretraining; only the data is different.
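
The "warmup + cosine decay" schedule mentioned above can be sketched as follows. The function name `lr_at`, the linear warmup shape, and the `min_lr` floor are assumptions chosen for illustration; the course code may use a different variant.

```python
import math

def lr_at(step, peak_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at a given step: linear warmup, then cosine decay.

    Assumption: warmup rises linearly from 0 to peak_lr, and the cosine
    phase decays from peak_lr down to min_lr at total_steps.
    """
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for s in (0, 50, 100, 550, 1000):
    print(s, lr_at(s, peak_lr=1e-4, warmup_steps=100, total_steps=1000))
```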

The hard part isn't the optimization. It's preventing the optimization from breaking other things.

Catastrophic forgetting: the geometry

When you fine-tune, every gradient step is:

\[ \theta \leftarrow \theta - \eta \cdot \hat{g} \]

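where \hat{g} is the mini-batch gradient of the loss on the fine-tune data. For plain SGD with a constant learning rate η, the total displacement across K steps is the sum below (with AdamW each step is rescaled per parameter, but the argument is unchanged):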

\[ \Delta \theta = - \eta \sum_{k=1}^K \hat{g}_k \]

The fine-tune data tells you about its own distribution. It says nothing about other distributions. Yet Δθ happens in every coordinate of θ that has a non-zero gradient on the fine-tune data — including coordinates that were crucial for some unrelated task.

Catastrophic forgetting is the empirical fact that, for sufficiently large Δθ, the model's performance on the original (pretraining) distribution drops.
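
A toy illustration of the argument, under strong simplifying assumptions: both "tasks" are quadratic losses in two parameters, the pretrained θ sits exactly at the pretraining optimum, and the optimizer is plain SGD with constant η. All names and numbers here are made up for the sketch.

```python
def grad_quadratic(theta, target):
    # Gradient of 0.5 * ||theta - target||^2 with respect to theta.
    return [t - c for t, c in zip(theta, target)]

def loss_quadratic(theta, target):
    return 0.5 * sum((t - c) ** 2 for t, c in zip(theta, target))

def fine_tune(theta0, target, lr, steps):
    """Run plain SGD on the fine-tune loss; also return the summed gradients."""
    theta = list(theta0)
    grad_sum = [0.0] * len(theta)
    for _ in range(steps):
        g = grad_quadratic(theta, target)
        grad_sum = [s + gi for s, gi in zip(grad_sum, g)]
        theta = [t - lr * gi for t, gi in zip(theta, g)]
    return theta, grad_sum

pretrain_opt = [1.0, -1.0]   # where the "pretrained" model sits
finetune_opt = [2.0, 1.0]    # where the fine-tune data pulls it
lr, steps = 0.1, 20

theta, grad_sum = fine_tune(pretrain_opt, finetune_opt, lr, steps)
delta = [a - b for a, b in zip(theta, pretrain_opt)]

print("delta theta:", delta)
print("-lr * sum of gradients:", [-lr * s for s in grad_sum])
print("pretraining loss before:", loss_quadratic(pretrain_opt, pretrain_opt))
print("pretraining loss after: ", loss_quadratic(theta, pretrain_opt))
```

The fine-tune gradient never consults the pretraining loss, yet both coordinates move and the pretraining loss rises from zero.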

The standard demonstration: take a chat-tuned LLaMA-7B and fine-tune it for 1000 steps on medical Q/A. Measure its accuracy on HellaSwag (a common-sense reasoning benchmark) before and after. Afterwards, HellaSwag accuracy is typically 5–15 percentage points lower.

Why LoRA helps geometrically

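LoRA freezes all of the pretrained θ and trains only the added low-rank matrices A and B. The effective weight change Δθ is therefore confined to a low-dimensional set inside ℝ^{|θ|}.

For each adapted layer, that set consists of weight updates of the form BA with rank ≤ r. For r ≪ min(in, out), this is a tiny part of the full weight space. (Strictly, the rank-≤-r matrices form a low-dimensional manifold rather than a linear subspace; we keep the word "subspace" as shorthand.) The effective weights cannot move along any direction outside it.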
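
A small pure-Python sketch of the idea, with toy shapes (in=3, out=2, r=1). The helper names and the `scale` argument are assumptions for illustration; B starts at zero, as in the usual LoRA initialisation, so the adapted layer initially reproduces the frozen one.

```python
def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def matmul(P, Q):
    cols = list(zip(*Q))
    return [[sum(p * q for p, q in zip(row, col)) for col in cols] for row in P]

def lora_forward(W, A, B, x, scale=1.0):
    """y = W x + scale * B (A x). W stays frozen; only A and B would train."""
    base = matvec(W, x)
    low_rank = matvec(B, matvec(A, x))
    return [b + scale * l for b, l in zip(base, low_rank)]

W = [[1.0, 2.0, 0.0],
     [0.0, 1.0, 3.0]]          # frozen weight, shape out x in = 2 x 3
A = [[0.5, -1.0, 2.0]]         # shape r x in = 1 x 3
B_init = [[0.0], [0.0]]        # shape out x r = 2 x 1, zero at initialisation
x = [1.0, 1.0, 1.0]

print(lora_forward(W, A, B_init, x))   # same as W x while B is zero

B_trained = [[2.0], [-1.0]]            # made-up values standing in for training
delta = matmul(B_trained, A)           # the only weight change LoRA can express
print(delta)                           # rows are multiples of A's single row
```

Whatever values A and B take, the product BA has rank at most r, so every row of the update here is a multiple of the same vector.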

If the pretraining knowledge for unrelated tasks lives in directions outside the LoRA subspace, it's preserved automatically. This is the geometric reason LoRA reduces forgetting.

The empirical claim, which lab 02 measures: for typical fine-tuning, most pretraining knowledge sits in directions LoRA doesn't touch.

Why LoRA doesn't eliminate forgetting

LoRA's update BA is in some subspace of weight space. If part of that subspace happens to coincide with directions important for unrelated tasks, those tasks still degrade. The reduction is significant but not perfect.

Practical mitigations:

  1. Use a conservative learning rate. A typical value is 1e-4 for LoRA fine-tuning (vs 1e-5 for full FT). Although the LR applied to the LoRA parameters is higher, the total ||Δθ|| stays small because only the low-rank factors move.
  2. Use replay. Mix a small fraction (~5%) of pretraining-distribution data into each batch. The gradient direction averages over both tasks, reducing drift on the original distribution.
  3. Use LoRA with a regularizer. Add λ ||BA||_F² to the loss to discourage large updates. Rarely needed in practice.
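
Mitigations 2 and 3 can be sketched roughly as below. The function names, the sampling-with-replacement choice, and the "at least one replay example" rule are assumptions made for the sketch, not the course's implementation.

```python
import random

def mix_batch(task_examples, replay_examples, batch_size,
              replay_fraction=0.05, rng=None):
    """Build one batch with roughly replay_fraction pretraining-style examples.

    Assumption: sample with replacement, and keep at least one replay
    example per batch whenever replay_fraction > 0.
    """
    rng = rng or random.Random()
    n_replay = max(1, round(batch_size * replay_fraction)) if replay_fraction > 0 else 0
    n_task = batch_size - n_replay
    batch = rng.choices(task_examples, k=n_task) + rng.choices(replay_examples, k=n_replay)
    rng.shuffle(batch)
    return batch

def frobenius_penalty(delta, lam):
    """lam * ||delta||_F^2 for a matrix given as a list of rows (delta = BA)."""
    return lam * sum(v * v for row in delta for v in row)

task = [("task", i) for i in range(100)]
replay = [("replay", i) for i in range(100)]
batch = mix_batch(task, replay, batch_size=40, rng=random.Random(0))
print(sum(1 for kind, _ in batch if kind == "replay"), "replay examples of", len(batch))
print(frobenius_penalty([[1.0, -2.0], [0.5, 0.0]], lam=0.1))
```

The penalty term would simply be added to the SFT loss before backpropagation.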

For Phase 28 we use the default LR (no replay, no regularizer) and rely on the structural advantage of LoRA. Lab 02 measures whether that's enough.

How to measure forgetting

You need two splits:

  • Task probe. Sequences that exercise the target capability. For our case: irregular-verb conjugation tests, e.g., "He __ to school", where the model is expected to score "went" above "goed". Accuracy should go up with fine-tuning.
  • Control probe. Sequences that exercise unrelated capabilities of the same model. For our case: regular-verb conjugation tests, e.g., "She __ at home yesterday", expecting "worked", not "worken". PPL should stay flat. The control split covers the 12 regular verbs (work, play, walk, talk, listen, watch, study, finish, start, look, want, like) that the model already handles correctly; we want them to stay that way.
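
One way to score the task probe is a pairwise comparison: fill the blank with the correct and the incorrect form, and count how often the model scores the correct one higher. In this sketch `score_fn` is a hypothetical stand-in for the model's sequence log-probability, and the toy scores (including the second probe sentence) are invented for illustration.

```python
def probe_accuracy(probes, score_fn):
    """Fraction of probes where the correct form outscores the incorrect one.

    probes: list of (template, correct, incorrect); "__" marks the blank.
    score_fn(text) -> float, higher means more likely under the model.
    """
    hits = 0
    for template, correct, incorrect in probes:
        good = score_fn(template.replace("__", correct))
        bad = score_fn(template.replace("__", incorrect))
        if good > bad:
            hits += 1
    return hits / len(probes)

# Invented scores standing in for a real model's log-probabilities.
TOY_SCORES = {
    "He went to school": -4.0,
    "He goed to school": -9.0,
    "She ate an apple": -6.0,
    "She eated an apple": -5.0,
}
task_probes = [
    ("He __ to school", "went", "goed"),
    ("She __ an apple", "ate", "eated"),
]
print(probe_accuracy(task_probes, TOY_SCORES.__getitem__))
```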

The metric is PPL drift on the control:

\[ \text{drift} = \frac{\text{PPL}_{\text{after}} - \text{PPL}_{\text{before}}}{\text{PPL}_{\text{before}}} \]

Phase 28's definition of done (DoD) requires this drift to stay at or below 5%.
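
A minimal sketch of the drift computation, assuming perplexity is the exponential of the mean per-token negative log-likelihood on the control split and that the 5% figure is an upper bound on drift. The function names and the toy log-probabilities are illustrative.

```python
import math

def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood over the given tokens."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

def ppl_drift(ppl_before, ppl_after):
    """Relative change in perplexity, matching the drift formula above."""
    return (ppl_after - ppl_before) / ppl_before

# Toy numbers: every control token had p = 0.5 before and p = 0.4 after.
before = perplexity([math.log(0.5)] * 4)
after = perplexity([math.log(0.4)] * 4)
drift = ppl_drift(before, after)

print(f"PPL before={before:.3f} after={after:.3f} drift={drift:.1%}")
print("within 5% bound:", drift <= 0.05)
```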

Catastrophic forgetting vs catastrophic adoption

A related-but-different concern: the model adopts the fine-tune data's style or biases even where they shouldn't apply. E.g., a model fine-tuned to always flag irregular-verb patterns might start "correcting" regular verbs into irregular forms ("she walken yesterday" — nonsense, but plausible if the LoRA over-applies the irregular-verb prior).

This isn't forgetting — the pretrained capability for "regular -ed conjugation" might still be there. It's over-applying the fine-tune behaviour. Mitigation is the same (smaller LR, replay) plus prompt engineering at inference time.

We don't formally probe this in Phase 28 (would need an additional eval set), but it's worth being aware of.

SFT data preparation

For our case:

  • Source: the Phase 12 verb corpus, filtered to entries that exercise the 8 irregular verbs (be, have, do, go, come, see, eat, write) in past simple and past participle.
  • Augmentation: include both correct and deliberately incorrect forms as input-output pairs: input = He goed to the store → output = He went to the store. (correction: past simple of "go" is "went"). Include Spanish-pair translations as part of the output (per §A2 / §A13), e.g., went → fue / fui for "to go".
  • Format: standard prompt → completion pairs separated by a delimiter token (defined in Phase 13).
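
For orientation only, a correction pair in this format might be assembled as below. The delimiter string `<|sep|>` is a hypothetical placeholder (the real token is defined in Phase 13), and `make_example` is an illustrative helper, not the Lab 02 script.

```python
DELIM = "<|sep|>"  # hypothetical placeholder; the real delimiter comes from Phase 13

def make_example(wrong_sentence, right_sentence, note, delim=DELIM):
    """Build one prompt -> completion record for a correction pair."""
    prompt = wrong_sentence
    completion = f"{right_sentence} ({note})"
    return {
        "prompt": prompt,
        "completion": completion,
        "text": f"{prompt}{delim}{completion}",
    }

example = make_example(
    "He goed to the store",
    "He went to the store.",
    'correction: past simple of "go" is "went"',
)
print(example["text"])
```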

Lab 02 has the data preparation script.

Optimizer choice

AdamW is the default for LLM fine-tuning. Reasons:

  • Adaptive per-parameter learning rates handle gradient scale variance well.
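  • The decoupled weight decay (the W in AdamW) shrinks the trainable parameters toward zero. With LoRA the trainable parameters are A and B, so this pulls the update BA toward zero and keeps the adapted model close to the pretrained θ.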

LoRA fine-tuning works with SGD too, but AdamW is faster to converge and more robust. We use AdamW.
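
To make the two bullet points concrete, here is a scalar sketch of a single AdamW update. The default hyperparameters are common choices used here as assumptions, and real implementations apply the same arithmetic element-wise to every trainable tensor.

```python
import math

def adamw_step(theta, grad, m, v, t, lr=1e-4, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter (t starts at 1)."""
    # Moving averages of the gradient and of its square.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialised averages.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay: shrink the parameter directly, not via the gradient.
    theta = theta * (1 - lr * weight_decay)
    # Adaptive step: the gradient is rescaled by its own running magnitude.
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = 2.0, 0.0, 0.0
theta, m, v = adamw_step(theta, grad=0.5, m=m, v=v, t=1)
print(theta)
```

On the first step the adaptive term is close to lr times the sign of the gradient regardless of the gradient's magnitude, which is the sense in which Adam handles gradient scale variance.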

A note on RLHF

Once SFT has aligned the model's behaviour to the task format, RLHF (or DPO/ORPO etc.) can be used to align the model's preferences — i.e., when there are multiple correct completions, RLHF favours the one humans prefer. For Phase 28 we stop at SFT; the alignment-tuning theory is surveyed in theory 04 but not implemented.

Drill problems

Solutions are available at phase open in solutions/01-sft-forgetting-ref.md.

  1. State the cross-entropy loss for next-token prediction in terms of logits and token indices. Then state the gradient with respect to one logit z_t.
  2. Argue why a learning rate that's too high causes catastrophic forgetting more than a low LR. (Hint: think ||Δθ||.)
  3. LoRA's subspace has dimension r × (in + out) per Linear. For a Linear (in=768, out=768) with r=8, what fraction of the full weight space dimension is this?
  4. If the same task could be solved by prompting (giving the model a few-shot example at inference time), why do we fine-tune at all? Give two reasons.

One-paragraph recap

Supervised fine-tuning is mechanically identical to pretraining — same loss, same optimizer, different data. The risk is catastrophic forgetting: gradient steps along the fine-tune task can degrade unrelated capabilities. The geometric explanation: full fine-tuning moves θ in arbitrary directions, including directions important for the original distribution. LoRA constrains updates to a low-dimensional subspace, preserving most of θ and thus most of the pretrained capability. Mitigations like lower LR, replay, and regularization are available but often unnecessary. The next theory file derives the parameter-count savings precisely.

Next: theory/02-parameter-count.md.