Lab 02 — The broadcasting trap

Goal: intentionally produce the (N,) * (N,1) → (N,N) bug, then fix it. Then catalog three more broadcasting situations so they never surprise you again.

Estimated time: 45–60 minutes.

Prereqs: lab 00.


What you produce

A directory experiments/06-broadcasting-trap/ containing:

  • bug.py — script that reproduces the bug and prints proof (wrong shape, wrong number).
  • fix.py — same computation done correctly; prints proof.
  • catalog.py — three more broadcasting situations; for each, prints input shapes, expected output shape, actual output shape (or ValueError).
  • manifest.json — standard schema.
  • README.md — one paragraph per situation in catalog.py explaining why the broadcast resolves as it does.

TODOs

Block A — reproduce the classic bug

In bug.py:

  • Generate "predictions" y_pred = np.arange(10, dtype=np.float32) — shape (10,).
  • Generate "targets" y_true = (np.arange(10, dtype=np.float32) + 0.1).reshape(10, 1) — shape (10, 1).
  • Compute err = y_pred - y_true. Print err.shape. (Should print (10, 10), not (10,).)
  • Compute mse_wrong = (err ** 2).mean(). Print value.
  • Also compute the intended MSE mse_right (use a manual loop or ((y_pred - y_true.squeeze()) ** 2).mean()). Print value.
  • Assert mse_wrong != mse_right to make the bug undeniable.
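To see what the (N, N) result actually contains, here is a minimal sketch with N = 3 instead of the lab's N = 10. The small size, the values, and the names pairwise and intended are illustrative assumptions, not part of the required bug.py. Every entry err[i, j] is y_pred[j] - y_true[i]; only the diagonal holds the differences you meant to compute.

```python
import numpy as np

# Illustrative N = 3 inputs (the lab itself uses N = 10).
y_pred = np.array([0.0, 1.0, 2.0], dtype=np.float32)                 # shape (3,)
y_true = np.array([0.5, 1.5, 2.5], dtype=np.float32).reshape(3, 1)   # shape (3, 1)

# (3,) is treated as (1, 3); then both operands stretch to (3, 3).
err = y_pred - y_true
print(err.shape)  # (3, 3)

# err[i, j] == y_pred[j] - y_true[i]: an all-pairs difference table.
pairwise = np.array(
    [[p - t for p in y_pred] for t in y_true[:, 0]], dtype=np.float32
)
assert np.array_equal(err, pairwise)

# Only the diagonal pairs prediction i with target i.
intended = np.diagonal(err)  # shape (3,)
print(intended)              # every entry is -0.5
```

The off-diagonal entries are the cross terms that inflate the mean in bug.py.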

Block B — fix it three ways

In fix.py, show three idiomatic fixes for the bug above:

  1. Match shapes explicitly: y_true_flat = y_true.squeeze(). Then ((y_pred - y_true_flat) ** 2).mean().
  2. Match the other way: y_pred_col = y_pred[:, None]. Then ((y_pred_col - y_true) ** 2).mean().
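  3. Be paranoid: squeeze the targets, assert y_pred.shape == y_true.squeeze().shape, and only then compute on the squeezed array, so a shape mismatch fails loudly instead of broadcasting silently.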

Print all three results. They should agree to within fp32 precision (~1e-7).
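The three fixes can be sketched as follows. This is one possible layout, not the reference solution: the names mse_1, mse_2, and mse_3 are assumptions, and squeeze(axis=-1) is used instead of a bare squeeze() so that only the trailing axis is removed.

```python
import numpy as np

y_pred = np.arange(10, dtype=np.float32)                          # shape (10,)
y_true = (np.arange(10, dtype=np.float32) + 0.1).reshape(10, 1)   # shape (10, 1)

# Fix 1: flatten the targets so both operands are (10,).
y_true_flat = y_true.squeeze(axis=-1)
mse_1 = ((y_pred - y_true_flat) ** 2).mean()

# Fix 2: lift the predictions to a column so both operands are (10, 1).
y_pred_col = y_pred[:, None]
mse_2 = ((y_pred_col - y_true) ** 2).mean()

# Fix 3: assert the shapes match before subtracting, then compute.
assert y_pred.shape == y_true_flat.shape, (y_pred.shape, y_true_flat.shape)
mse_3 = ((y_pred - y_true_flat) ** 2).mean()

print(mse_1, mse_2, mse_3)
assert np.allclose([mse_1, mse_2, mse_3], mse_1, rtol=1e-6)
```

Fixes 1 and 3 produce a (10,) error vector and fix 2 a (10, 1) column; the mean is the same either way because no cross terms are created.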

Block C — catalog three more broadcasting situations

In catalog.py, for each of the following, write a section: input shapes → predicted output shape → actual output shape → 1-paragraph explanation citing the broadcast rule.

Situation 1: a.shape = (3, 4), b.shape = (4,). Compute a + b.

Situation 2: a.shape = (3, 4), b.shape = (3,). Compute a + b. (Hint: this is the trap of "wanted to broadcast over rows but didn't reshape".)

Situation 3: a.shape = (B, 1, N, D), b.shape = (1, H, N, D) (attention-shaped). Compute a + b.

For each, your code must:

  • Print the input shapes.
  • Record your prediction of the output shape before computing (a code comment next to the computation is enough).
  • Print the actual output shape.
  • Catch any ValueError and print it instead.
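One way to structure these requirements is a small helper that does the printing and catches only ValueError. The helper name try_add and the concrete sizes chosen for B, H, N, D are assumptions for illustration; the predictions are left as comments for you to replace with your own.

```python
import numpy as np

def try_add(name, a, b):
    """Print the input shapes, then the shape of a + b or the ValueError raised."""
    print(f"{name}: a.shape={a.shape}, b.shape={b.shape}")
    try:
        out = a + b
    except ValueError as exc:  # broadcasting mismatch only; anything else crashes
        print(f"  ValueError: {exc}")
        return None
    print(f"  actual output shape: {out.shape}")
    return out.shape

# Illustrative sizes for the attention-shaped case (assumed, not from the lab).
B, H, N, D = 2, 3, 5, 4

# prediction: write yours here before running
s1 = try_add("situation 1", np.zeros((3, 4), np.float32), np.zeros((4,), np.float32))

# prediction: write yours here before running
s2 = try_add("situation 2", np.zeros((3, 4), np.float32), np.zeros((3,), np.float32))

# prediction: write yours here before running
s3 = try_add(
    "situation 3",
    np.zeros((B, 1, N, D), np.float32),
    np.zeros((1, H, N, D), np.float32),
)
```

Returning None on failure keeps the caller simple; the error text printed by NumPy already names both offending shapes.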

Block D — bind it to memory

  • In README.md, write three rules of thumb in your own words. They should map directly to: (i) align right, (ii) dims match or are 1, (iii) result is pairwise max. If you can't write the rule, you don't own it yet.
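If it helps to check your wording, the three rules can be written as a few lines of plain Python. This is a sketch for building intuition, not NumPy's actual implementation; broadcast_shape is a hypothetical helper, and the cross-check uses np.broadcast_shapes, which is assumed to be available (NumPy 1.20 or later).

```python
from itertools import zip_longest

import numpy as np

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: the shape of a two-operand broadcast, or ValueError."""
    result = []
    # Rule (i): align right. Walk both shapes from the last axis; a missing
    # leading axis is treated as size 1.
    for da, db in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        # Rule (ii): the two sizes must match, or one of them must be 1.
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
        # Rule (iii): keep the size that is not 1 (the pairwise max for
        # non-empty axes).
        result.append(db if da == 1 else da)
    return tuple(reversed(result))

# The lab's classic bug and two catalog situations.
assert broadcast_shape((10,), (10, 1)) == (10, 10)
assert broadcast_shape((3, 4), (4,)) == (3, 4)
assert broadcast_shape((2, 1, 5, 4), (1, 3, 5, 4)) == (2, 3, 5, 4)

# Cross-check against NumPy's own answer.
assert broadcast_shape((10,), (10, 1)) == np.broadcast_shapes((10,), (10, 1))

try:
    broadcast_shape((3, 4), (3,))
except ValueError as exc:
    print(exc)  # trailing axes are 4 and 3: neither matches nor is 1
```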

Constraints

  • fp32. Match later phases.
  • No try / except Exception. Catch the specific ValueError you expect; let unexpected exceptions crash so you notice.
  • Print, don't log. This lab is interactive; logging structure isn't the point here. (Exception to the lab 00 rule.)

Expected results

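  • bug.py prints err.shape = (10, 10) and two different MSE values. The "wrong" one is about 16.51 and the "right" one about 0.01, so the wrong value is roughly 1,650 times too large. Each entry is err[i, j] = j - i - 0.1; the 90 off-diagonal cross terms have magnitudes up to about 9 and dominate the average, contributing a mean of (j - i)² equal to 16.5 on top of the intended 0.01.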
  • fix.py prints three matching MSE values.
  • catalog.py situation 1: (3, 4). Situation 2: ValueError. Situation 3: (B, H, N, D).

Stop conditions

Done when:

  1. All three scripts run successfully (for bug.py, success means the assertion mse_wrong != mse_right passes, which is the proof of the bug).
  2. catalog.py situation 2 raises ValueError and your code catches it gracefully.
  3. README.md has the three rules of thumb in your own words.

Pitfalls

  • reshape(-1, 1) vs [:, None]. Functionally identical for 1-D arrays. Pick one and use it consistently. [:, None] is more idiomatic for inserting axes.
  • squeeze with no axis argument. Removes all size-1 axes. Dangerous if you intended to remove a specific axis. Be explicit: squeeze(axis=-1).
  • (N, 1) - (N, 1) does NOT broadcast to (N, N). Both shapes are (N, 1), fully compatible per-axis, result is (N, 1). The bug only happens when one is (N,) and the other is (N, 1).
  • Higher-dim shapes confuse your eye. Write out the right-aligned shapes before computing:
    (B, 1, N, D)
    (1, H, N, D)   ← right-aligned
    
    Then per-axis: (B, H, N, D). Five seconds of paper saves an hour of debugging.
  • NumPy's error message. When broadcasting fails, NumPy prints the two shapes — read them. The shape that surprises you is the one that needs [:, None] or .squeeze().
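The first three pitfalls are quick to confirm at the prompt. The sketch below uses small made-up arrays; the names batch_of_one and squeeze_failed are illustrative only.

```python
import numpy as np

x = np.arange(4, dtype=np.float32)  # shape (4,)

# reshape(-1, 1) and [:, None] give the same column shape for a 1-D array.
assert x.reshape(-1, 1).shape == x[:, None].shape == (4, 1)

# (N, 1) - (N, 1) stays (N, 1): no cross terms appear.
col = x[:, None]
assert (col - col).shape == (4, 1)

# A bare squeeze() removes every size-1 axis, including one you wanted to keep.
batch_of_one = np.zeros((1, 1), dtype=np.float32)
assert batch_of_one.squeeze().shape == ()             # both axes gone
assert batch_of_one.squeeze(axis=-1).shape == (1,)    # only the last axis gone

# With an explicit axis, squeeze refuses to touch an axis that is not size 1.
squeeze_failed = False
try:
    np.zeros((3, 4), dtype=np.float32).squeeze(axis=-1)
except ValueError as exc:
    squeeze_failed = True
    print(exc)
```

The explicit-axis form turns a silent shape change into an immediate error, which is the behavior you want in a loss function.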

When to consult solutions/

After all three scripts work and README.md contains the rules in your own words. solutions/02-broadcasting-trap-ref.md (written at phase open) provides the reference explanations.


Next lab: lab/03-vectorization-budget.md.