Lab 02: The broadcasting trap
Goal: produce, intentionally, the `(N,) * (N, 1) → (N, N)` bug. Then fix it. Then catalog three more broadcasting surprises so they never surprise you again.
Estimated time: 45–60 minutes.
Prereqs: lab 00.
What you produce
A directory experiments/06-broadcasting-trap/ containing:
- `bug.py`: script that reproduces the bug and prints proof (wrong shape, wrong number).
- `fix.py`: same computation done correctly; prints proof.
- `catalog.py`: three more broadcasting situations; for each, prints input shapes, expected output shape, actual output shape (or `ValueError`).
- `manifest.json`: standard schema.
- `README.md`: one paragraph per situation in `catalog.py` explaining why the broadcast resolves as it does.
TODOs
Block A: reproduce the classic bug
In bug.py:
- Generate "predictions" `y_pred = np.arange(10, dtype=np.float32)`, shape `(10,)`.
- Generate "targets" `y_true = (np.arange(10, dtype=np.float32) + 0.1).reshape(10, 1)`, shape `(10, 1)`.
- Compute `err = y_pred - y_true`. Print `err.shape`. (Should print `(10, 10)`, not `(10,)`.)
- Compute `mse_wrong = (err ** 2).mean()`. Print value.
- Also compute the intended MSE `mse_right` (use a manual loop or `((y_pred - y_true.squeeze()) ** 2).mean()`). Print value.
- Assert `mse_wrong != mse_right` to make the bug undeniable.
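The steps above can be sketched roughly as follows. This is a minimal illustration under the lab's stated inputs, not the reference solution; the print labels are arbitrary, and `squeeze(axis=-1)` is used in place of a bare `squeeze()` purely as a matter of taste.

```python
import numpy as np

# Same inputs as the steps above: (10,) predictions against (10, 1) targets.
y_pred = np.arange(10, dtype=np.float32)
y_true = (np.arange(10, dtype=np.float32) + 0.1).reshape(10, 1)

# (10,) is right-aligned against (10, 1) and treated as (1, 10),
# so the subtraction becomes an outer difference of shape (10, 10).
err = y_pred - y_true
print("err.shape =", err.shape)

# Mean over all 100 entries, most of which pair the wrong rows.
mse_wrong = float((err ** 2).mean())

# Intended MSE: compare element i with element i only.
mse_right = float(((y_pred - y_true.squeeze(axis=-1)) ** 2).mean())

print("mse_wrong =", mse_wrong)
print("mse_right =", mse_right)

# The two values differ; that difference is the bug.
assert mse_wrong != mse_right
```

Nothing raises here, which is what makes the bug easy to miss: the only visible symptom is the shape.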
Block B: fix it three ways
In fix.py, show three idiomatic fixes for the bug above:
- Match shapes explicitly: `y_true_flat = y_true.squeeze()`. Then `((y_pred - y_true_flat) ** 2).mean()`.
- Match the other way: `y_pred_col = y_pred[:, None]`. Then `((y_pred_col - y_true) ** 2).mean()`.
- Be paranoid: `assert y_pred.shape == y_true.squeeze().shape`. Then proceed.
Print all three results. They should agree to within fp32 precision (a relative difference on the order of 1e-7).
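One possible shape for the three fixes, as a sketch only. The names `mse_1`, `mse_2`, `mse_3` and the `rtol=1e-5` tolerance are choices made for this illustration, and the third fix is assumed to continue with the squeezed targets.

```python
import numpy as np

y_pred = np.arange(10, dtype=np.float32)
y_true = (np.arange(10, dtype=np.float32) + 0.1).reshape(10, 1)

# Fix 1: bring the targets down to (10,).
y_true_flat = y_true.squeeze(axis=-1)
mse_1 = ((y_pred - y_true_flat) ** 2).mean()

# Fix 2: lift the predictions up to (10, 1).
y_pred_col = y_pred[:, None]
mse_2 = ((y_pred_col - y_true) ** 2).mean()

# Fix 3: refuse to continue unless the shapes already agree.
assert y_pred.shape == y_true_flat.shape
mse_3 = ((y_pred - y_true_flat) ** 2).mean()

print(mse_1, mse_2, mse_3)

# Loose fp32 tolerance; the three values should be essentially identical.
assert np.allclose([mse_1, mse_2, mse_3], mse_1, rtol=1e-5)
```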
Block C: catalog three more broadcasting situations
In catalog.py, for each of the following, write a section: input shapes → predicted output shape → actual output shape → 1-paragraph explanation citing the broadcast rule.
Situation 1: a.shape = (3, 4), b.shape = (4,). Compute a + b.
Situation 2: a.shape = (3, 4), b.shape = (3,). Compute a + b. (Hint: this is the trap of "wanted to broadcast over rows but didn't reshape".)
Situation 3: a.shape = (B, 1, N, D), b.shape = (1, H, N, D) (attention-shaped). Compute a + b.
For each, your code must:
- Print the input shapes.
- Print your predicted output shape before computing (note your reasoning in a comment).
- Print the actual output shape.
- Catch any `ValueError` and print it instead.
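The four requirements above could be wrapped in one small helper. In this sketch `try_add` is a hypothetical name, and the concrete values for `B`, `H`, `N`, `D` are placeholders picked only so that situation 3 can run.

```python
import numpy as np

def try_add(name, a_shape, b_shape, predicted):
    """Hypothetical helper: print shapes, the prediction, then the outcome."""
    a = np.zeros(a_shape, dtype=np.float32)
    b = np.zeros(b_shape, dtype=np.float32)
    print(f"{name}: a.shape={a.shape}, b.shape={b.shape}")
    print(f"  predicted: {predicted}")
    try:
        actual = (a + b).shape
    except ValueError as exc:
        # Only the failure we expect is caught; anything else should crash.
        print(f"  actual: ValueError: {exc}")
        return None
    print(f"  actual: {actual}")
    return actual

# Placeholder sizes for the symbolic dimensions (assumed for this example).
B, H, N, D = 2, 3, 5, 7

r1 = try_add("situation 1", (3, 4), (4,), (3, 4))
r2 = try_add("situation 2", (3, 4), (3,), "ValueError")
r3 = try_add("situation 3", (B, 1, N, D), (1, H, N, D), (B, H, N, D))
```

Returning `None` on failure keeps the script running through all three situations, so one failing broadcast does not hide the others.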
Block D: bind it to memory
- In `README.md`, write three rules of thumb in your own words. They should map directly to: (i) align right, (ii) dims match or are 1, (iii) result is pairwise max. If you can't write the rule, you don't own it yet.
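One way to check that the three rules are complete is to turn them into code. The function below is an illustrative re-implementation for two shapes, not NumPy's own code; `broadcast_shape` is a name invented here, and zero-length axes are deliberately out of scope.

```python
from itertools import zip_longest

def broadcast_shape(shape_a, shape_b):
    """Illustrative sketch of the three rules for two non-empty shapes."""
    result = []
    # (i) Align right: walk both shapes from the last axis backwards;
    #     an axis missing on the left is treated as size 1.
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    for da, db in pairs:
        # (ii) Each aligned pair must match, or one of the two must be 1.
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
        # (iii) The output size on this axis is the larger of the pair.
        result.append(max(da, db))
    return tuple(reversed(result))

print(broadcast_shape((10,), (10, 1)))  # the classic trap from Block A
print(broadcast_shape((3, 4), (4,)))
```

If your README rules are enough to write a function like this from memory, they are probably the right rules.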
Constraints
- fp32. Match later phases.
- No `try / except Exception`. Catch the specific `ValueError` you expect; let unexpected exceptions crash so you notice.
- Print, don't log. This lab is interactive; logging structure isn't the point here. (Exception to the lab 00 rule.)
Expected results
- `bug.py` prints `err.shape = (10, 10)` and two different MSE values. The "wrong" one is far larger than the "right" one, roughly 16.51 versus 0.01, because you're averaging 100 entries where 90 are cross terms with magnitudes between about 0.9 and 9.1.
- `fix.py` prints three matching MSE values.
- `catalog.py` situation 1: `(3, 4)`. Situation 2: `ValueError`. Situation 3: `(B, H, N, D)`.
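The size of the gap between the two MSE values can be checked with exact arithmetic, independent of fp32 rounding. This sketch uses Python's `fractions` module and the same integers 0 to 9 with a 0.1 offset; the variable names are illustrative.

```python
from fractions import Fraction

n = 10
offset = Fraction(1, 10)

# Wrong MSE: every prediction j against every target i + 0.1 (n * n pairs).
wrong = sum(
    (Fraction(j) - (i + offset)) ** 2 for i in range(n) for j in range(n)
) / (n * n)

# Right MSE: prediction i against target i + 0.1 only (n pairs).
right = sum((Fraction(i) - (i + offset)) ** 2 for i in range(n)) / n

print(wrong, right, wrong / right)
```

The exact values are 1651/100 and 1/100, so the broadcast version overstates the error by a factor of 1651. The fp32 results from `bug.py` should land close to 16.51 and 0.01.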
Stop conditions
Done when:
- All three scripts run successfully (where success for `bug.py` includes the `mse_wrong != mse_right` assertion holding: that mismatch is the bug).
- `catalog.py` situation 2 raises `ValueError` and your code catches it gracefully.
- `README.md` has the three rules of thumb in your own words.
Pitfalls
- `reshape(-1, 1)` vs `[:, None]`. Functionally identical for 1-D arrays. Pick one and use it consistently. `[:, None]` is more idiomatic for inserting axes.
- `squeeze` with no axis argument. Removes all size-1 axes. Dangerous if you intended to remove a specific axis. Be explicit: `squeeze(axis=-1)`.
- `(N, 1) - (N, 1)` does NOT broadcast to `(N, N)`. Both shapes are `(N, 1)`, fully compatible per-axis, result is `(N, 1)`. The bug only happens when one is `(N,)` and the other is `(N, 1)`.
- Higher-dim shapes confuse your eye. Write out the right-aligned shapes before computing, for example `(B, 1, N, D)` directly above `(1, H, N, D)`. Then per-axis: `(B, H, N, D)`. Five seconds of paper saves an hour of debugging.
- NumPy's error message. When broadcasting fails, NumPy prints the two shapes; read them. The shape that surprises you is the one that needs `[:, None]` or `.squeeze()`.
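The first three pitfalls are quick to confirm interactively. A short sketch follows; the array names and the `(1, 10, 1)` example shape are chosen here for illustration.

```python
import numpy as np

x = np.arange(10, dtype=np.float32)

# reshape(-1, 1) and [:, None] build the same (10, 1) column from a 1-D array.
assert np.array_equal(x.reshape(-1, 1), x[:, None])

# A bare squeeze() drops every size-1 axis, including a batch axis of size 1.
batch_of_one = np.zeros((1, 10, 1), dtype=np.float32)
print(batch_of_one.squeeze().shape)         # all size-1 axes removed
print(batch_of_one.squeeze(axis=-1).shape)  # only the last axis removed

# Two columns subtract elementwise; there is no (N, N) blow-up.
col = x[:, None]
print((col - col).shape)
```

The three printed shapes are `(10,)`, `(1, 10)` and `(10, 1)`. The first one is the case to watch for: a batch of size 1 silently loses its batch axis.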
When to consult solutions/
After all three scripts work and README.md contains the rules in your own words. solutions/02-broadcasting-trap-ref.md (written at phase open) provides the reference explanations.
Next lab: lab/03-vectorization-budget.md.