Solutions Appendix
Chapter 15

Training Transformers

20 Solutions

Detailed solutions for the exercises in Chapter 15. Try solving them yourself before checking the answers.

Exercise 1 (Pen & Paper)
Write the causal LM loss; explain the one-position shift; why does one pass give T signals?

Solution

The loss is −Σ_t log P(x_{t+1} | x_{≤t}), usually averaged over positions. At each position the target is the NEXT token, so the inputs (x_1 … x_T) and targets (x_2 … x_{T+1}) are the same sequence shifted by one. Because the causal mask lets every position predict its successor simultaneously in a single forward pass, one sequence of T input positions yields T separate next-token training signals. This dense supervision is what makes language-model pretraining so sample-efficient per forward pass.
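The shift can be made concrete with a minimal pure-Python sketch. The `next_token_probs` callable is a hypothetical stand-in for a causal model's softmax output. A real Transformer scores all positions in one batched forward pass; the loop below walks them one at a time only for readability.

```python
import math


def causal_lm_loss(tokens, next_token_probs):
    """Mean next-token negative log-likelihood (nats) and number of signals.

    tokens: token ids x_1 .. x_{T+1}.
    next_token_probs: hypothetical stand-in for a causal model; given the prefix
        x_1 .. x_t it returns a dict {token_id: probability} for x_{t+1}.
    """
    inputs, targets = tokens[:-1], tokens[1:]  # the one-position shift
    nlls = []
    for t in range(len(inputs)):
        probs = next_token_probs(inputs[: t + 1])  # causal: sees only x_<=t
        nlls.append(-math.log(probs[targets[t]]))
    return sum(nlls) / len(nlls), len(nlls)


def uniform_model(prefix, vocab_size=4):
    """Toy model that ignores the prefix and spreads probability evenly."""
    return {v: 1.0 / vocab_size for v in range(vocab_size)}


loss, n_signals = causal_lm_loss([0, 2, 1, 3, 3, 0], uniform_model)
print(f"loss={loss:.4f} nats, signals={n_signals}")  # loss=1.3863 nats, signals=5
```

With a uniform model over 4 tokens every position costs log 4 nats, and a 6-token sequence provides 5 input positions, hence 5 training signals.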

Exercise 2 (Pen & Paper)
Loss 1.6 nats → perplexity? Interpret perplexity 5.

Solution

Perplexity = e^{loss} = e^{1.6} ≈ 4.95 ≈ 5. A perplexity of 5 means the model is on average as uncertain as if it were choosing uniformly among ~5 equally-likely next tokens — its effective branching factor. Lower perplexity = sharper, more confident (and usually more accurate) next-token predictions.
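The conversion is a single exponential (for a loss measured in bits it would be 2^loss instead). The short sketch below also checks the branching-factor reading: a model that is uniform over k tokens has cross-entropy log k, hence perplexity k.

```python
import math


def perplexity(loss_nats):
    """Perplexity = exp(mean per-token cross-entropy measured in nats)."""
    return math.exp(loss_nats)


print(f"{perplexity(1.6):.2f}")  # 4.95

# Branching-factor check: uniform over k tokens -> loss log(k) -> perplexity k.
for k in (2, 5, 50):
    uniform_loss = -math.log(1.0 / k)
    print(k, round(perplexity(uniform_loss), 6))
```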

Exercise 3 (Pen & Paper)
Why do Transformers need warmup but plain SGD on a CNN often doesn't? Connect to Adam.

Solution

Adam divides each step by √v, and in the first steps the second-moment estimate v is built from only a handful of gradients, so the normalized updates can be erratically large. The deep Transformer, with its stacked LayerNorms and residual paths, is sensitive to such early jolts and can spike or diverge. Warmup ramps the learning rate up gradually so the moment estimates stabilize before large steps are taken. Plain SGD has no adaptive denominator to mis-estimate, and CNNs are less sensitive to early large steps, so they often train fine without warmup.
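One way to see why the first Adam steps are risky: with zero-initialized moments and the usual bias correction, the step-1 update is lr·g/(|g| + eps), i.e. roughly lr·sign(g) for every parameter, however tiny or noisy its gradient. The toy scalar sketch below shows this; it is a hand-written version of the textbook Adam update with commonly used default hyperparameters (assumed here), not any library's implementation.

```python
import math


def adam_first_update(g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Update applied by Adam at step t = 1 for a scalar gradient g."""
    m = (1 - beta1) * g          # first moment, starting from m_0 = 0
    v = (1 - beta2) * g * g      # second moment, starting from v_0 = 0
    m_hat = m / (1 - beta1)      # bias correction for t = 1 recovers g
    v_hat = v / (1 - beta2)      # ... and g**2
    return -lr * m_hat / (math.sqrt(v_hat) + eps)


for g in (1e-6, 1e-2, 1.0, 1e3):
    print(f"grad={g:<8g} update={adam_first_update(g):+.3e}")
```

Plain SGD would move the parameter with gradient 1e-6 by only lr·1e-6; Adam moves it by almost lr. Early steps are therefore large everywhere until v has averaged over enough gradients, which is the time warmup buys.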

Exercise 4 (Derive)
Write warmup + cosine decay as a piecewise function; verify continuity at the boundary.

Solution

For step t with warmup W and total T: lr(t) = lr_max·(t/W) for t ≤ W (linear warmup), and lr(t) = lr_min + ½(lr_max−lr_min)(1 + cos(π·(t−W)/(T−W))) for t > W (cosine decay). At t = W, the warmup branch gives lr_max, and the cosine branch gives lr_min + ½(lr_max−lr_min)(1+cos0) = lr_min + (lr_max−lr_min) = lr_max. The two branches agree at the boundary, so the schedule is continuous.
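The same piecewise schedule written out, with the continuity check and, for completeness, the end-point value at t = T:

```latex
\mathrm{lr}(t) =
\begin{cases}
  \mathrm{lr}_{\max}\,\dfrac{t}{W}, & 0 \le t \le W,\\[8pt]
  \mathrm{lr}_{\min} + \tfrac12\,(\mathrm{lr}_{\max}-\mathrm{lr}_{\min})
  \Bigl(1+\cos\bigl(\pi\,\tfrac{t-W}{T-W}\bigr)\Bigr), & W < t \le T,
\end{cases}

\begin{aligned}
\lim_{t\to W^{+}} \mathrm{lr}(t)
  &= \mathrm{lr}_{\min} + \tfrac12(\mathrm{lr}_{\max}-\mathrm{lr}_{\min})(1+\cos 0)
   = \mathrm{lr}_{\max} = \mathrm{lr}(W),\\
\mathrm{lr}(T)
  &= \mathrm{lr}_{\min} + \tfrac12(\mathrm{lr}_{\max}-\mathrm{lr}_{\min})(1+\cos\pi)
   = \mathrm{lr}_{\min}.
\end{aligned}
```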

Exercise 5 (Pen & Paper)
Adam+L2 vs AdamW: show large-gradient params get less effective decay in Adam+L2.

Solution

In Adam+L2 the weight-decay term λw is added to the gradient and then divided by √v (the per-parameter second-moment estimate). Parameters with large gradients have large v, so their effective decay (roughly λw/√v) is shrunk: they get less regularization precisely where you might want more. AdamW DECOUPLES weight decay, applying it directly to the weights (w ← w(1−ηλ)) outside the adaptive scaling, so every parameter shrinks by the same factor each step regardless of its gradient magnitude. This decoupling is a widely cited reason AdamW tends to generalize better.
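A back-of-the-envelope comparison makes the asymmetry concrete. The sketch assumes Adam's second moment has settled at v ≈ g² for a parameter whose typical gradient magnitude is g, and it ignores the small contribution of the decay term itself to v. Under those simplifications it isolates the part of each step that comes from weight decay.

```python
import math


def decay_step_adam_l2(w, grad_rms, lr, wd, eps=1e-8):
    """Decay part of the update when wd*w is folded into the gradient (Adam+L2).

    Assumes v has converged to grad_rms**2, so the decay term is divided by sqrt(v).
    """
    v = grad_rms ** 2
    return lr * wd * w / (math.sqrt(v) + eps)


def decay_step_adamw(w, lr, wd):
    """Decay part of the update in AdamW: w <- w - lr*wd*w, outside the scaling."""
    return lr * wd * w


lr, wd, w = 1e-3, 0.1, 1.0
for grad_rms in (0.01, 1.0, 100.0):
    print(
        f"|g|~{grad_rms:<6g} Adam+L2 decay={decay_step_adam_l2(w, grad_rms, lr, wd):.1e}"
        f"  AdamW decay={decay_step_adamw(w, lr, wd):.1e}"
    )
```

Under Adam+L2 the parameter with |g| ≈ 100 is decayed about 10,000× less than the one with |g| ≈ 0.01; under AdamW both shrink by the same relative amount per step.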

Exercise 6 (Pen & Paper)
Why decay weight matrices but not LayerNorm gains/biases? What would decaying a gain toward 0 do?

Solution

Weight matrices benefit from shrinkage (regularization toward simpler functions). LayerNorm gains (γ) and biases set the SCALE and SHIFT of the normalized activations; decaying γ toward 0 would shrink the layer's output magnitude toward zero, suppressing the very signal the normalization is meant to pass through and harming the model. Biases add only a few parameters and little capacity, so regularizing them gains almost nothing. Decay is therefore applied to the large weight matrices only and excluded from norm parameters and biases.
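A quick numeric illustration of the γ argument, using a hand-written LayerNorm on one token's features. The sketch sets β = 0 and uses a single shared γ value to stand in for a gain vector that weight decay has (hypothetically) shrunk over many steps.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over one feature vector: normalize, then scale by gamma, shift by beta."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [g * (xi - mean) / math.sqrt(var + eps) + b
            for xi, g, b in zip(x, gamma, beta)]


def rms(vec):
    return math.sqrt(sum(v * v for v in vec) / len(vec))


x = [2.0, -1.0, 0.5, 3.5]
for gain in (1.0, 0.1, 0.01):  # gamma after progressively more (hypothetical) decay
    y = layer_norm(x, [gain] * len(x), [0.0] * len(x))
    print(f"gamma={gain:<5g} output RMS={rms(y):.4f}")
```

Because the normalized vector has (almost exactly) unit RMS, the output magnitude is set by γ alone: shrinking γ by 100× shrinks everything the next layer sees by 100×.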

Exercise 7 (Pen & Paper)
Global-norm clipping scales gradients by max_norm/g_norm; prove it preserves direction.

Solution

When the total norm g_norm exceeds the threshold, every gradient is multiplied by the same positive scalar c = max_norm/g_norm < 1. Scaling all components of a vector by one positive constant changes its magnitude (to exactly max_norm) but not its direction (it still points the same way). So clipping rescales the step without changing where it points — it tames magnitude while keeping the descent direction intact.
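The argument in symbols, with g the concatenation of all parameter gradients and ‖·‖ the Euclidean norm:

```latex
\tilde g = c\,g,\qquad c = \frac{\texttt{max\_norm}}{\lVert g\rVert}\in(0,1)
\quad\text{when } \lVert g\rVert > \texttt{max\_norm},

\lVert \tilde g\rVert = c\,\lVert g\rVert = \texttt{max\_norm},
\qquad
\cos\angle(\tilde g, g)
  = \frac{\langle c\,g,\ g\rangle}{\lVert c\,g\rVert\,\lVert g\rVert}
  = \frac{c\,\lVert g\rVert^{2}}{c\,\lVert g\rVert^{2}} = 1.
```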

Exercise 8 (Pen & Paper)
Why has bf16 replaced fp16 for LLM training? Exponent range and loss scaling.

Solution

bf16 keeps fp32's 8-bit exponent (the same dynamic range, roughly 10^±38) but has only 7 mantissa bits versus fp32's 23, so it rarely overflows or underflows and typically needs NO loss scaling. fp16 has only a 5-bit exponent (largest finite value 65504, smallest positive subnormal ≈ 6×10^−8), so small gradients underflow to zero and require dynamic loss scaling to survive. bf16's wide range makes mixed-precision training simpler and more robust, which is why it became the default on modern hardware despite its lower precision (fp16 has 10 mantissa bits).
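The range difference can be checked with the standard library alone. Python's `struct` module has a native IEEE half-precision format code (`'e'`). It has no bf16 code, so the sketch emulates bf16 by keeping the top 16 bits of the fp32 encoding (sign, the full 8-bit exponent, 7 mantissa bits). That truncation is an assumption of this sketch: hardware normally rounds to nearest even, which affects the last bit of precision but not the range.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE fp16 via struct's 'e' format."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # |x| beyond fp16's largest finite value (65504)
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x):
    """Approximate bf16 by truncating the fp32 bit pattern to its top 16 bits."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


for value in (1e-8, 3e-5, 1.0, 7e4):
    print(f"{value:>8g}  fp16={to_fp16(value):<12g}  bf16={to_bf16(value):g}")
```

A value of 1e-8 flushes to zero in fp16 (below half its smallest subnormal) but survives in bf16, and 7e4 overflows fp16 but is merely rounded in bf16. Loss scaling exists to shift fp16 gradients away from the first of these failures.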

Exercise 9 (Derive)
Why does gradient accumulation divide each micro-batch loss by accumulation_steps?

Solution

The goal is for K micro-batches to produce the same gradient as one batch K times larger, i.e. the AVERAGE gradient over all examples. Summed micro-batch gradients without division would be K× too large (a sum, not a mean), effectively multiplying the learning rate by K. Dividing each micro-batch loss by K makes the accumulated gradients average correctly. Forgetting the division inflates the effective step size by K, often causing divergence.
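The bookkeeping, assuming K equal micro-batches B_1, …, B_K of m examples each, B their union, and a per-example loss ℓ_i, with each micro-batch loss L_k a mean over its own examples:

```latex
L_k = \frac{1}{m}\sum_{i\in B_k}\ell_i,
\qquad
\sum_{k=1}^{K}\nabla\!\Bigl(\frac{L_k}{K}\Bigr)
  = \frac{1}{Km}\sum_{k=1}^{K}\sum_{i\in B_k}\nabla\ell_i
  = \nabla\!\Bigl(\frac{1}{Km}\sum_{i\in B}\ell_i\Bigr),

\text{whereas}\qquad
\sum_{k=1}^{K}\nabla L_k
  = K\,\nabla\!\Bigl(\frac{1}{Km}\sum_{i\in B}\ell_i\Bigr).
```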

Exercise 10 (Pen & Paper)
Explain the critical batch size. Why do large models benefit from very large batches?

Solution

The critical batch size is the point below which gradient noise dominates (more samples per step help a lot, and doubling the batch nearly halves the number of steps needed) and above which returns diminish (the gradient estimate is already accurate, so larger batches mostly cost more compute per step without cutting the number of steps much). Large models tolerate, and benefit from, larger critical batch sizes: empirically the critical batch size grows as the training loss falls, and large models reach lower loss. That lets them use massive data parallelism efficiently. Below the critical size you are noise-limited; above it you are curvature-limited.

Exercise 11 (Code)
Implement warmup + cosine decay; plot lr over 100k steps for warmup ∈ {500,2000,5000}.

Solution

The implemented schedule rises linearly to lr_max over the warmup window, then follows a cosine down to lr_min. Longer warmup delays the peak and gives a gentler start (more stable but slightly slower early progress); shorter warmup reaches full lr sooner (faster but riskier). The three curves visualize this trade-off.
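A minimal stdlib sketch of the schedule, transcribing the piecewise formula from Exercise 4. The peak and floor values `lr_max=3e-4` and `lr_min=3e-5` are illustrative assumptions, and instead of plotting, it prints the rate at a few landmark steps for each warmup length.

```python
import math


def lr_at(step, warmup, total, lr_max=3e-4, lr_min=3e-5):
    """Linear warmup to lr_max over `warmup` steps, then cosine decay to lr_min at `total`."""
    if step <= warmup:
        return lr_max * step / warmup
    progress = (step - warmup) / (total - warmup)  # runs from 0 to 1 during decay
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


TOTAL = 100_000
for warmup in (500, 2000, 5000):
    landmarks = [0, warmup // 2, warmup, 25_000, 50_000, 75_000, TOTAL]
    row = ", ".join(f"{s}:{lr_at(s, warmup, TOTAL):.2e}" for s in landmarks)
    print(f"warmup={warmup:<5} {row}")
```

To draw the three curves, pass `[lr_at(s, w, TOTAL) for s in range(TOTAL + 1)]` for each warmup `w` to any plotting library.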

Exercise 12 (Code)
Implement configure_optimizer with decay/no-decay groups; print which params land where.

Solution

Iterating parameters and routing 2-D weight matrices to the decay group and 1-D tensors (LayerNorm gains, biases) to the no-decay group reproduces the rule of Exercise 6. Printing the groups confirms norms and biases are excluded from weight decay — the standard, correct optimizer configuration.
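A framework-free sketch of the routing rule. Parameters are represented as hypothetical `(name, shape)` pairs; in PyTorch the same test would be `p.dim() >= 2` over `model.named_parameters()`. Under this rank rule the token-embedding matrix is decayed too, a common but not universal choice.

```python
def split_decay_groups(named_shapes):
    """Route parameters by rank: matrices (rank >= 2) decay, vectors do not."""
    decay, no_decay = [], []
    for name, shape in named_shapes:
        (decay if len(shape) >= 2 else no_decay).append(name)
    return decay, no_decay


# Hypothetical parameter list for one block of a small GPT.
params = [
    ("tok_emb.weight", (50304, 384)),
    ("blocks.0.ln1.weight", (384,)),
    ("blocks.0.ln1.bias", (384,)),
    ("blocks.0.attn.qkv.weight", (1152, 384)),
    ("blocks.0.attn.qkv.bias", (1152,)),
    ("ln_f.weight", (384,)),
]
decay, no_decay = split_decay_groups(params)
print("decay:   ", decay)
print("no_decay:", no_decay)
```

In PyTorch the two lists would become two parameter groups, e.g. `[{"params": decay_params, "weight_decay": 0.1}, {"params": no_decay_params, "weight_decay": 0.0}]`, passed to `torch.optim.AdamW`; the 0.1 is an illustrative value.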

Exercise 13 (Code)
Implement global-norm gradient clipping from scratch; verify it matches clip_grad_norm_.

Solution

Compute the total norm across all parameter gradients, and if it exceeds max_norm, scale every gradient by max_norm/total_norm (Exercise 7). The result matches PyTorch's clip_grad_norm_ on random gradients, confirming the implementation.
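A from-scratch sketch on plain Python lists, one list per parameter tensor. It follows the same recipe as `torch.nn.utils.clip_grad_norm_`: compute the global norm, scale by a coefficient clamped at 1, and return the pre-clipping norm. The 1e-6 epsilon in the denominator is assumed here to mirror PyTorch's.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Global-norm clipping over a list of per-parameter gradient lists.

    Returns (clipped_grads, total_norm), where total_norm is measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    coef = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    clipped = [[g * coef for g in grad] for grad in grads]
    return clipped, total_norm


grads = [[3.0, 4.0], [0.0, 12.0]]  # two "parameters"; global norm = 13
clipped, total = clip_grad_norm(grads, max_norm=1.0)
new_norm = math.sqrt(sum(g * g for grad in clipped for g in grad))
print(total, round(new_norm, 6))  # 13.0 1.0
```

To compare with PyTorch, load the same values into each parameter's `.grad`, call `torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)`, and check that the returned norm and the rescaled `.grad` tensors agree to floating-point tolerance.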

Exercise 14 (Code Lab)
Train a small GPT with and without warmup; show no-warmup trains worse or diverges.

Solution

Without warmup, the early large Adam steps (unreliable second moments, Exercise 3) cause a loss spike or divergence; with warmup the loss descends smoothly. The two curves demonstrate empirically why warmup is standard for Transformer training.

Exercise 15 (Code)
Compare fp32, fp16+GradScaler, bf16: peak memory, tokens/sec, final loss.

Solution

fp32 uses the most memory and is slowest; fp16 and bf16 roughly halve activation memory and speed up throughput on supported hardware. fp16 needs a GradScaler to avoid underflow; bf16 does not (Exercise 8). Final loss is essentially the same across all three, confirming mixed precision is a free speed/memory win when done correctly.

Exercise 16 (Code)
Implement gradient accumulation; verify micro=4, accum=4 ≈ batch of 16.

Solution

Accumulating gradients over 4 micro-batches of 4 (with the 1/accum loss scaling of Exercise 9) and then stepping produces gradients nearly identical to a single batch of 16. This lets you simulate large batches under limited memory — verified by the matching gradients.
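The equivalence can be checked without a framework on a one-parameter least-squares model whose gradient is written out by hand; the data, model, and `grad_mse` helper are illustrative assumptions. With autograd the `/ accum_steps` would be applied to the loss before calling `backward()`. Since the gradient is linear in the loss, dividing the gradient here amounts to the same thing.

```python
import random


def grad_mse(w, batch):
    """d/dw of mean((w*x - y)**2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


rng = random.Random(0)
data = []
for _ in range(16):
    x = rng.uniform(-1.0, 1.0)
    data.append((x, 3.0 * x + rng.gauss(0.0, 0.1)))
w = 0.5

full_grad = grad_mse(w, data)  # one batch of 16

accum_steps, micro = 4, 4
accumulated = 0.0
for k in range(accum_steps):
    micro_batch = data[k * micro:(k + 1) * micro]
    accumulated += grad_mse(w, micro_batch) / accum_steps  # the 1/accum scaling

print(full_grad, accumulated)  # equal up to floating-point rounding
```

Dropping the `/ accum_steps` would make `accumulated` four times `full_grad`, the K× inflation described in Exercise 9.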

Exercise 17 (Code Lab)
Measure gradient checkpointing: deep model with/without it; confirm ~33% overhead.

Solution

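Checkpointing lowers peak memory markedly while increasing wall-clock time by roughly a third. The overhead is the extra recomputation: a backward pass costs about twice a forward pass, so a normal step costs about 3 forward-equivalents, and recomputing the forward activations during backward adds about 1 more (4/3 ≈ 1.33). This reproduces the compute/memory trade-off of Chapter 11's Exercise 10 in a real training run.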

Exercise 18 (Code)
Build a monitoring dashboard: loss, perplexity, grad-norm, lr, update-to-weight ratio.

Solution

Logging and plotting these five signals reveals training health: perplexity = e^{loss}; grad-norm should be stable (spikes signal instability); lr should follow the schedule; and the update-to-weight ratio ‖Δw‖/‖w‖ (the size of the applied update relative to the weights, which for plain SGD is lr·‖g‖/‖w‖) should, by a common rule of thumb, sit around 1e−3. Much higher suggests instability; much lower suggests learning has stalled. The dashboard is the practitioner's instrument panel.
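A framework-free sketch of the per-step scalars, computed on flattened parameter lists; the function names and example numbers are hypothetical. The update-to-weight ratio is measured from the weights before and after the optimizer step, so it already includes the learning rate.

```python
import math


def l2(vec):
    return math.sqrt(sum(v * v for v in vec))


def step_metrics(loss, grads, weights_before, weights_after, lr):
    """Dashboard scalars for one optimizer step."""
    update = [a - b for a, b in zip(weights_after, weights_before)]
    return {
        "loss": loss,
        "perplexity": math.exp(loss),
        "grad_norm": l2(grads),
        "lr": lr,
        "update_to_weight": l2(update) / l2(weights_before),
    }


w = [0.5, -1.2, 2.0, 0.3]
g = [0.1, -0.3, 0.2, 0.05]
lr = 1e-2
w_new = [wi - lr * gi for wi, gi in zip(w, g)]  # one plain SGD step
metrics = step_metrics(loss=3.2, grads=g, weights_before=w, weights_after=w_new, lr=lr)
for name, value in metrics.items():
    print(f"{name:>16}: {value:.4g}")
```

In a real run these values would be logged every step (the ratio usually per parameter tensor) and plotted over time.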

Exercise 19 (Code)
Implement save/load checkpoint with optimizer and RNG state; verify bit-identical resume.

Solution

Saving model weights, optimizer moments, scheduler step, and RNG state — then reloading and continuing — should produce a continuation bit-identical to an uninterrupted run. Verifying this confirms the checkpoint captures all training state, essential for long runs that span restarts.
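The resume-equivalence check can be rehearsed with the standard library: a toy "training" loop whose updates depend on a random generator, a checkpoint that stores the training state together with `random.Random.getstate()`, and a comparison against an uninterrupted run. In PyTorch the corresponding pieces would come from `model.state_dict()`, `optimizer.state_dict()`, the scheduler's `state_dict()`, and `torch.get_rng_state()` (plus `torch.cuda.get_rng_state_all()` on GPU); the toy dictionary below only stands in for them.

```python
import pickle
import random


def train(state, rng, steps):
    """Toy training loop: noisy updates to one weight, plus a step counter."""
    for _ in range(steps):
        state["w"] -= state["lr"] * rng.gauss(0.0, 1.0)
        state["step"] += 1
    return state


def save_checkpoint(state, rng):
    """Serialize training state and RNG state together."""
    return pickle.dumps({"state": dict(state), "rng": rng.getstate()})


def load_checkpoint(blob):
    ckpt = pickle.loads(blob)
    rng = random.Random()
    rng.setstate(ckpt["rng"])
    return ckpt["state"], rng


# Reference: 100 uninterrupted steps.
ref = train({"w": 1.0, "lr": 0.1, "step": 0}, random.Random(42), 100)

# Interrupted: 60 steps, checkpoint, simulated crash, resume for 40 more.
rng = random.Random(42)
state = train({"w": 1.0, "lr": 0.1, "step": 0}, rng, 60)
blob = save_checkpoint(state, rng)
del state, rng
state, rng = load_checkpoint(blob)
resumed = train(state, rng, 40)

print(resumed == ref)  # True: bit-identical continuation
```

Reseeding on resume instead of restoring the saved RNG state would replay the first random draws and make `resumed == ref` fail, which is exactly the silent divergence the bit-identical test is meant to catch.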

Exercise 20 (Code, Challenge)
Assemble the full training loop; pretrain a small GPT on TinyStories/Shakespeare; coherent sample; post-mortem.

Solution

Combining the schedule, optimizer groups, clipping, mixed precision, accumulation, monitoring, and checkpointing into one loop and training a ~10M-param model yields coherent samples. The post-mortem should note observed instabilities (early loss spikes, grad-norm bursts) and the fixes (warmup, clipping, lr tuning) — the integrated payoff of the whole chapter.