

03 — Tied embeddings and the LM head

The language-model head reuses the same matrix as the input embedding. It is a tiny change in the code that saves \(|V| \cdot d_\text{model}\) parameters and puts input and output in one shared space. Here it saves little memory, because the vocabulary is small, but it teaches the principle.

The LM head

After the final LayerNorm, the residual stream sits in \(\mathbb{R}^{T \times d_\text{model}}\). To turn this into a probability distribution over vocabulary tokens at each position, we apply a linear projection to vocab size, then softmax:

\[\text{logits}_t = h_t \cdot W_\text{LM}^\top \in \mathbb{R}^{|V|}\]

\[p_t = \text{softmax}(\text{logits}_t)\]

where \(W_\text{LM} \in \mathbb{R}^{|V| \times d_\text{model}}\). Naively, this is a new learnable matrix with \(|V| \cdot d_\text{model}\) params.
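A minimal NumPy sketch of the naive, untied head, assuming toy Mini-GPT-sized shapes and a random stand-in for the final residual stream \(h\); the names `W_LM` and `softmax` are illustrative, not the course's code. It shows the shapes involved and that this separate head alone costs \(|V| \cdot d_\text{model}\) parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, vocab_size = 5, 64, 64       # toy shapes (assumed, Mini-GPT-sized)

h = rng.standard_normal((T, d_model))    # stand-in for the residual after the final LayerNorm
W_LM = rng.normal(0.0, 0.02, size=(vocab_size, d_model))  # separate, untied head

logits = h @ W_LM.T                      # (T, vocab_size)

def softmax(x, axis=-1):
    # Subtracting the row max leaves the result unchanged but keeps exp() from overflowing.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

p = softmax(logits)                      # each row is a distribution over the vocabulary
print(logits.shape, p.shape, W_LM.size)  # (5, 64) (5, 64) 4096
```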

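But notice: the input embedding \(E \in \mathbb{R}^{|V| \times d_\text{model}}\) has the same shape. Row \(E[w]\) "embeds" token id \(w\) into the residual space; multiplying by \(E^\top\) "unembeds" a residual vector into one score per token id. The two maps run in opposite directions (id to vector, vector to scores over ids). They are not literal inverses, but both are built from the same set of token vectors.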

Weight tying

Tied embeddings (Press & Wolf 2017, "Using the Output Embedding to Improve Language Models"; concurrent in Inan et al. 2017): set \(W_\text{LM} = E\). The input embedding matrix is the output projection matrix. The forward becomes:

\[\text{logits}_t = h_t \cdot E^\top\]
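Written out per entry, the tied head scores token \(w\) at position \(t\) by a dot product between the hidden state and that token's own embedding row, which is the sense in which \(E^\top\) "unembeds":

```latex
\[
\text{logits}_{t,w} = \left(h_t E^\top\right)_w = \sum_{i=1}^{d_\text{model}} h_{t,i}\, E_{w,i} = \langle h_t,\, E[w] \rangle
\]
```

So the logit for \(w\) is large when \(h_t\) points along \(E[w]\).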

That's it — same matrix, used twice. Pattern:

class MiniGPT:
    def __init__(self, vocab_size, d_model, ...):
        self.E = Parameter(np.random.randn(vocab_size, d_model) * 0.02)
        # ... blocks, LNs, etc.
        # NO self.W_LM. The LM head reuses self.E.

    def forward(self, tokens):
        h = self.E[tokens]                # (T, d_model)  — token embed
        for block in self.blocks:
            h = block(h)
        h = self.ln_final(h)
        logits = h @ self.E.T             # (T, vocab_size) — LM head, tied
        return logits
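The pattern above relies on the course's `Parameter` class and elides the blocks. Below is a self-contained NumPy sketch of the same tied forward, under stated assumptions: plain arrays stand in for `Parameter`, a LayerNorm without gain or bias stands in for `ln_final`, and the block list is left empty. `TiedMiniGPTSketch` is a hypothetical name, not the course's class.

```python
import numpy as np

class TiedMiniGPTSketch:
    """Hypothetical stand-in for MiniGPT: arrays instead of Parameter, no real blocks."""

    def __init__(self, vocab_size, d_model, seed=0):
        rng = np.random.default_rng(seed)
        # One matrix, used for both the token lookup and the LM head.
        self.E = rng.normal(0.0, 0.02, size=(vocab_size, d_model))
        self.blocks = []                 # real transformer blocks omitted in this sketch

    def ln_final(self, h, eps=1e-5):
        # LayerNorm without learnable gain/bias (simplification for the sketch)
        mu = h.mean(axis=-1, keepdims=True)
        var = h.var(axis=-1, keepdims=True)
        return (h - mu) / np.sqrt(var + eps)

    def forward(self, tokens):
        h = self.E[tokens]               # (T, d_model): token embedding lookup
        for block in self.blocks:
            h = block(h)
        h = self.ln_final(h)
        return h @ self.E.T              # (T, vocab_size): tied LM head, same E

model = TiedMiniGPTSketch(vocab_size=64, d_model=64)
logits = model.forward(np.array([3, 1, 4, 1, 5]))
print(logits.shape)                      # (5, 64)
```

The only learnable matrix here is `E`; there is no `W_LM` attribute to forget to tie.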

Two benefits

1. Parameter savings

Tying eliminates one \(|V| \cdot d_\text{model}\) matrix. For Mini-GPT (\(|V| = 64, d_\text{model} = 64\)), that's 4096 params, which is small. For GPT-2 small (\(|V| = 50257, d_\text{model} = 768\)), it's about 38.6M params, roughly 31% of the 124M-parameter (tied) model. At LLaMA-2-7B's shape (\(|V| = 32000, d_\text{model} = 4096\)), one such matrix is 131M params; LLaMA-2 itself keeps its head untied and pays that cost. The savings are substantial at real scale.
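The arithmetic behind those figures, as a quick check using only the shapes quoted above (the 124M total for GPT-2 is taken from the text, not recomputed):

```python
# (model, |V|, d_model), using the shapes quoted in the text
configs = [
    ("Mini-GPT", 64, 64),
    ("GPT-2 small", 50257, 768),
    ("LLaMA-2-7B shape", 32000, 4096),
]

# Parameters in one |V| x d_model matrix: what tying removes (or an untied head costs).
matrix_params = {name: V * d for name, V, d in configs}
for name, n in matrix_params.items():
    print(f"{name:>16}: {n:,}")

# Share of GPT-2's ~124M (tied) parameters taken by the embedding matrix.
share = matrix_params["GPT-2 small"] / 124e6
print(f"GPT-2 share: {share:.1%}")
```

This prints 4,096, 38,597,376 and 131,072,000 parameters, and a GPT-2 share of about 31.1%.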

For Mini-GPT, the savings are pedagogical, not practical. We tie anyway because:

  • It teaches the principle.
  • Every row of \(E\) is trained by both roles, so updates that improve the output projection also shape the input embedding, and vice versa.
  • It is the GPT-2 choice. Many modern transformers tie their embeddings, but not all: LLaMA-2, for example, uses an untied head.

2. Conceptual symmetry

Tying says: "the same notion of what a token is (the embedding) determines what a hidden state means about a token (the unembedding)." There is one vocabulary space, used twice. This is conceptually satisfying, and empirically it tends to help: Press & Wolf (2017) report lower perplexity with tied weights in their language-modeling experiments. Tying imposes the shared space as a prior rather than letting the two matrices learn separate token representations.

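Mechanistic interpretability uses this directly: the unembed direction for token \(w\) is the row \(E[w]\), so you can ask "which residual directions point at \(E[w]\)?" The "logit lens" technique (nostalgebraist 2020) applies the final LayerNorm and the unembedding to intermediate residual states to see which tokens each layer is already predicting. With tying, the lens reads residuals in the same coordinates the input embedding writes them in; with an untied head the lens still works, using the rows of \(W_\text{LM}\) instead.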
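A minimal logit-lens sketch under stated assumptions: the per-layer residual states are random placeholders (a real run would record them from the model), `layer_norm` omits the learnable gain and bias, and `logit_lens` is a hypothetical helper, not an API from any library. It decodes every layer's residual through the same tied matrix \(E\).

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Final LayerNorm without gain/bias (simplification for the sketch)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def logit_lens(residuals, E, k=3):
    """residuals: list of (T, d_model) arrays, one per layer (hypothetical input).
    Returns, per layer, a (T, k) array of the top-k token ids at each position."""
    out = []
    for h in residuals:
        logits = layer_norm(h) @ E.T                        # reuse the tied unembedding
        out.append(np.argsort(-logits, axis=-1)[:, :k])     # highest-scoring ids first
    return out

rng = np.random.default_rng(0)
E = rng.normal(0.0, 0.02, size=(64, 64))                     # tied embedding (toy)
residuals = [rng.standard_normal((5, 64)) for _ in range(3)] # fake per-layer states
tops = logit_lens(residuals, E)
print([t.shape for t in tops])                               # [(5, 3), (5, 3), (5, 3)]
```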

The full Mini-GPT forward, with tied head

tokens (T,)   →   E[tokens]              shape (T, d_model)
             block_0                     shape (T, d_model)
             block_1                     shape (T, d_model)
             LN_final                    shape (T, d_model)
             @ E.T                       shape (T, vocab_size)
             logits → (softmax in loss; not part of model proper)

Why no softmax inside MiniGPT.forward?

The forward returns logits, not probabilities. The softmax happens:

  • Inside the loss function (Phase 18) — fused with the loss for numerical stability (the cross_entropy_from_logits trick from Phase 05).
  • Inside sampling (Phase 21) — possibly with temperature scaling.

Decoupling forward from softmax means:

  • Numerical stability: the loss can use the log-sum-exp trick instead of computing log(softmax(...)) directly.
  • Flexibility: at inference, you can apply temperature, top-k, or top-p to the same logits without re-running the model.

Mini-GPT's .forward() returns logits: (T, V). Anything downstream chooses what to do with them.
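A minimal sketch of fusing the softmax into the loss, assuming the loss is a mean over positions; this `cross_entropy_from_logits` is illustrative and the Phase 05 version may differ in signature and details. The naive version is included only to show the failure the fused form avoids.

```python
import numpy as np

def cross_entropy_from_logits(logits, targets):
    """Mean next-token loss. logits: (T, V) floats; targets: (T,) int token ids.
    Uses log softmax(x)_y = x_y - logsumexp(x), with the row max shifted out first."""
    m = logits.max(axis=-1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=-1))
    picked = logits[np.arange(len(targets)), targets]
    return float(np.mean(lse - picked))

def naive_cross_entropy(logits, targets):
    # Softmax first, then log: exp() overflows for large logits.
    with np.errstate(over="ignore", divide="ignore", invalid="ignore"):
        p = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
        return float(np.mean(-np.log(p[np.arange(len(targets)), targets])))

logits = np.array([[1000.0, 0.0, -1000.0]])
targets = np.array([1])
print(cross_entropy_from_logits(logits, targets))  # 1000.0
print(naive_cross_entropy(logits, targets))        # inf: exp(1000) overflowed
```

On well-scaled logits the two agree; only the fused form survives large ones.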

A subtle point: gradient flow with tying

When you tie \(W_\text{LM} = E\), the gradient \(\partial \mathcal{L} / \partial E\) has two contributions: one from the input embedding lookup (only the rows of tokens that appear in the input) and one from the output projection (every row). Autograd handles this automatically if you wire \(E\) as a single Parameter referenced twice, exactly as the code above does. If you accidentally make two copies (e.g., self.W_LM = self.E.copy()), you've untied them and lost the property.
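A small finite-difference check of the two-contribution claim, as a sketch under simplifying assumptions: no transformer blocks or LayerNorm (so \(h = E[\text{tokens}]\) directly), a toy vocabulary, and hand-written gradients in place of the course's autograd. The head term touches every row of \(E\); the lookup term is scatter-added only into rows of tokens present in the input, accumulating when a token repeats.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 8, 4
E = rng.normal(0.0, 0.5, size=(V, d))   # the single tied matrix (toy size)
tokens = np.array([1, 3, 3, 5])         # input ids (token 3 repeats)
targets = np.array([3, 3, 5, 2])        # next-token labels

def loss_fn(E):
    h = E[tokens]                       # embedding lookup (no blocks/LN in this toy)
    logits = h @ E.T                    # tied LM head
    m = logits.max(axis=-1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=-1))
    return np.mean(lse - logits[np.arange(len(targets)), targets])

# Analytic gradient, built from its two contributions.
h = E[tokens]
logits = h @ E.T
p = np.exp(logits - logits.max(axis=-1, keepdims=True))
p /= p.sum(axis=-1, keepdims=True)
dlogits = p.copy()
dlogits[np.arange(len(targets)), targets] -= 1.0
dlogits /= len(targets)                 # the loss is a mean over positions

grad_head = dlogits.T @ h               # via logits = h @ E.T: every row of E
dh = dlogits @ E                        # gradient arriving at the looked-up vectors
grad_embed = np.zeros_like(E)
np.add.at(grad_embed, tokens, dh)       # scatter-add into looked-up rows only
grad_E = grad_head + grad_embed         # one matrix used twice: contributions sum

# Central finite differences on every entry of E.
eps = 1e-5
num = np.zeros_like(E)
for i in range(V):
    for j in range(d):
        Ep, Em = E.copy(), E.copy()
        Ep[i, j] += eps
        Em[i, j] -= eps
        num[i, j] = (loss_fn(Ep) - loss_fn(Em)) / (2 * eps)

print(np.max(np.abs(num - grad_E)))     # tiny: the summed gradient matches
```

Dropping `grad_embed` (as an accidental copy would, from the head's point of view) makes the check fail.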

Phase 18 will look at the gradient flow into the tied \(E\) as a sanity check: the two contributions are summed into a single gradient for \(E\), and the optimizer applies that combined update.

Initialization

For tied embeddings, initialize \(E\) once with the embedding-style init (small Gaussian, typically \(\mathcal{N}(0, 0.02^2)\)). The output projection inherits this init — no separate init needed. This is the GPT-2 default.

For untied LM heads, you'd typically initialize the head with the same scale, since it's effectively another linear layer.
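A sketch of the initialization, assuming NumPy's `default_rng`; `init_embedding` is a hypothetical helper. With tying there is exactly one call; the untied line only illustrates the alternative mentioned above and is not used in Mini-GPT.

```python
import numpy as np

def init_embedding(vocab_size, d_model, std=0.02, seed=0):
    # Small Gaussian init, N(0, std^2); std = 0.02 is the GPT-2 convention cited above.
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, std, size=(vocab_size, d_model))

E = init_embedding(64, 64)              # tied: this is also the LM head, nothing else to init
# Untied alternative (not used in Mini-GPT): a second matrix at the same scale.
W_LM = init_embedding(64, 64, seed=1)

print(E.shape, round(float(E.std()), 4))   # std lands near 0.02
```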

The final softmax — where it lives at inference

In Phase 21 (sampling), the model's logits become probabilities via:

\[q_t = \text{softmax}(\text{logits}_t / \tau)\]

where \(\tau\) is the temperature (\(\tau = 1\) is "raw," \(\tau < 1\) sharpens, \(\tau > 1\) flattens). Phase 21 covers temperature, top-k, and top-p (nucleus) sampling. None of this is part of Phase 17; Phase 17's job ends at logits.
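A minimal sketch of temperature scaling only (top-k and top-p are left to Phase 21), assuming NumPy; `softmax_with_temperature` is an illustrative name, not the Phase 21 code. Dividing by \(\tau\) before the softmax is the whole operation; the max shift only guards against overflow.

```python
import numpy as np

def softmax_with_temperature(logits, tau=1.0):
    # q = softmax(logits / tau); tau must be positive
    if tau <= 0:
        raise ValueError("tau must be positive")
    z = logits / tau
    z = z - z.max(axis=-1, keepdims=True)   # stability shift; does not change q
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 0.0])
for tau in (0.5, 1.0, 2.0):
    q = softmax_with_temperature(logits, tau)
    print(tau, np.round(q, 3))
```

For these logits the top probability is about 0.87 at \(\tau = 0.5\), 0.67 at \(\tau = 1\), and 0.51 at \(\tau = 2\).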

What this file does NOT cover

  • Sampling strategies. Phase 21.
  • The cross-entropy loss using logits. Phase 18 (with the numerical-stability trick from Phase 05).
  • Untied LM heads. Mentioned for completeness; not used.

Next: ../lab/00-block-by-hand.md