Solutions Appendix
Chapter 20

Efficient Training Techniques

20 Solutions

Detailed solutions for the exercises in Chapter 20. Try solving them yourself before checking the answers.

Exercise 1 (Pen & Paper)
Define MFU; for a 13B model at 25k tokens/s on an H100 (~990 TFLOP/s bf16), compute it using 6N.

Solution

MFU = (achieved useful FLOP/s)/(peak FLOP/s), with useful work counted as 6N FLOPs per token (2N forward + 4N backward). Achieved = 6·13×10⁹·25,000 = 1.95×10¹⁵ FLOP/s; dividing by 9.9×10¹⁴ gives ≈1.97. An MFU above 1 is impossible on a single device, so 25k tokens/s must be an aggregate over several GPUs (or the figure is purely illustrative). On one H100 a realistic training MFU is 30–50%; at 40%, one GPU sustains 0.4·9.9×10¹⁴ / (6·13×10⁹) ≈ 5,100 tokens/s, so 25k tokens/s corresponds to roughly five GPUs. The exercise's value is in the method: MFU = 6N·throughput / peak.
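The arithmetic above fits in a few lines of Python. This is a minimal sketch: the 9.9×10¹⁴ FLOP/s peak is the exercise's figure, and the 40% MFU used to invert the formula is an assumed, typical value rather than a measurement.

```python
def mfu(n_params: float, tokens_per_s: float, peak_flops: float) -> float:
    """Model FLOPs utilization using the 6N-FLOPs-per-token approximation."""
    useful_flops_per_s = 6 * n_params * tokens_per_s  # 2N forward + 4N backward
    return useful_flops_per_s / peak_flops

H100_BF16_PEAK = 9.9e14  # dense bf16 peak (~990 TFLOP/s), as given in the exercise

ratio = mfu(13e9, 25_000, H100_BF16_PEAK)
print(f"MFU = {ratio:.2f}")  # above 1, so this cannot be a single GPU

# Invert the formula: tokens/s one GPU sustains at an assumed 40% MFU.
per_gpu_tokens = 0.40 * H100_BF16_PEAK / (6 * 13e9)
print(f"~{per_gpu_tokens:,.0f} tokens/s per GPU at 40% MFU")
print(f"~{25_000 / per_gpu_tokens:.1f} GPUs needed for 25k tokens/s")
```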

Exercise 2 (Pen & Paper)
Compute-bound vs memory-bound via arithmetic intensity; classify matmul, GELU, LayerNorm, softmax.

Solution

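Arithmetic intensity = FLOPs per byte moved to or from memory. Compare it with the hardware's ridge point (peak FLOP/s ÷ memory bandwidth, roughly 300 FLOP/byte for an H100 SXM's bf16 peak over its ~3.35 TB/s HBM): above it an op is compute-bound (it saturates the math units), below it memory-bound (limited by bandwidth). Large matmuls are compute-bound, since each loaded element is reused across many multiply-adds. GELU, LayerNorm, and softmax are memory-bound: they do a handful of arithmetic operations per element but must stream the whole tensor through memory. This is why fusing the memory-bound ops (Exercise 4) yields big speedups.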
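A rough way to apply this classification in code. The per-element FLOP counts for GELU, LayerNorm, and softmax are illustrative assumptions (exact counts depend on the implementation), and the H100 peak and bandwidth are approximate SXM figures; only the order-of-magnitude comparison with the ridge point matters.

```python
BYTES = 2  # bf16 element size

def matmul_intensity(m: int, k: int, n: int) -> float:
    flops = 2 * m * k * n                          # one multiply-add per (i, j, p)
    bytes_moved = BYTES * (m * k + k * n + m * n)  # read A and B once, write C once
    return flops / bytes_moved

def elementwise_intensity(flops_per_elem: float, reads: int = 1, writes: int = 1) -> float:
    # Per element: read each input tensor once, write each output tensor once.
    return flops_per_elem / (BYTES * (reads + writes))

# Ridge point: FLOP/byte needed to keep the math units busy (assumed H100 SXM figures).
PEAK_FLOPS = 9.9e14          # bf16 dense FLOP/s
HBM_BW = 3.35e12             # bytes/s
RIDGE = PEAK_FLOPS / HBM_BW  # ~295 FLOP/byte

ops = {
    "matmul 4096^3": matmul_intensity(4096, 4096, 4096),
    # Rough per-element FLOP counts (assumptions for illustration only).
    "GELU": elementwise_intensity(10),
    "LayerNorm": elementwise_intensity(8),
    "softmax": elementwise_intensity(5),
}
for name, ai in ops.items():
    kind = "compute-bound" if ai > RIDGE else "memory-bound"
    print(f"{name:>14}: {ai:8.1f} FLOP/byte -> {kind}")
```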

Exercise 3 (Pen & Paper)
Sketch the GPU memory hierarchy with sizes/bandwidths; why does keeping data in SRAM matter?

Solution

Registers (about 256 KB per SM, fastest) → SRAM/shared memory (roughly 200 KB per SM, a few tens of MB across the chip, ~10–100 TB/s aggregate) → HBM (tens of GB, ~1.5–3.5 TB/s) → NVLink (inter-GPU, hundreds of GB/s) → InfiniBand (inter-node, ~tens of GB/s). SRAM offers roughly an order of magnitude more bandwidth than HBM (and much lower latency), so for memory-bound ops the key to speed is doing as much work as possible while data sits in SRAM: exactly the principle behind FlashAttention and kernel fusion.

Exercise 4 (Pen & Paper)
Quantify HBM traffic for unfused vs fused GELU(x@W+b); explain the speedup.

Solution

Unfused: the matmul writes its M×N result to HBM; the bias-add reads it and writes a new tensor; GELU reads that and writes the output. That is five activation-sized HBM transfers (three writes, two reads), not counting the matmul's own reads of x and W. Fused: the matmul output stays in registers/SRAM, the bias and GELU are applied before the result leaves the chip, and only the final result is written to HBM: one transfer instead of five. Because the bias-add and GELU are memory-bound, their cost is set by HBM traffic rather than arithmetic, so removing those round-trips gives a large speedup for that part of the chain.
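Counting the traffic explicitly makes the 5× difference concrete. A minimal sketch assuming bf16 activations and an illustrative 8192×16384 output; the helper name hbm_bytes is hypothetical.

```python
def hbm_bytes(m: int, n: int, bytes_per_elem: int = 2, fused: bool = False) -> int:
    """Activation-sized HBM traffic for GELU(x @ W + b), ignoring the matmul's
    reads of x and W (identical in both cases) and the tiny bias vector."""
    act = m * n * bytes_per_elem  # size of one M x N activation in bytes
    if fused:
        return act                # bias + GELU applied on-chip: a single write
    matmul_out = act              # write x @ W
    bias_add = 2 * act            # read x @ W, write x @ W + b
    gelu = 2 * act                # read x @ W + b, write GELU(...)
    return matmul_out + bias_add + gelu

m, n = 8192, 16384  # illustrative sizes (e.g. tokens x hidden)
unfused = hbm_bytes(m, n)
fused = hbm_bytes(m, n, fused=True)
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB, "
      f"ratio {unfused / fused:.0f}x")
```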

Exercise 5 (Derive)
For checkpointing every √L layers, derive O(√L) memory and ~33% extra compute; state assumptions.

Solution

Split the L layers into segments of s layers and store activations only at the L/s segment boundaries. During backward, recompute the activations of the current segment from its boundary, which needs s layers' worth of memory at a time. Peak activation memory is therefore about L/s + s, minimized at s = √L, giving 2√L = O(√L) instead of O(L). Recomputation re-runs each segment's forward once, so in total it costs about one extra forward pass. Since backward is ≈2× forward, the original step costs ≈3 forwards and the added forward is ≈⅓ of it: hence ~33% compute overhead. Assumptions: layers of roughly equal cost and activation size, backward ≈ 2× forward, and only activation memory counted (not weights or optimizer state).
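A small numeric check of the derivation under the same assumptions (equal-size layers, backward ≈ 2× forward). The helper activation_memory is hypothetical and counts memory in units of one layer's activations.

```python
import math

def activation_memory(num_layers: int, segment: int) -> int:
    """Peak stored activations (in layer-units) when a checkpoint is kept
    every `segment` layers; assumes every layer's activations are equal-size."""
    checkpoints = math.ceil(num_layers / segment)  # boundaries kept all along
    return checkpoints + segment                   # plus one segment recomputed in backward

L = 64
for s in (1, 4, 8, 16, 64):
    print(f"segment={s:>2}: {activation_memory(L, s):>3} layer-units")

best = min(range(1, L + 1), key=lambda s: activation_memory(L, s))
print("best segment:", best, "| sqrt(L) =", math.isqrt(L))

# Compute: recomputation adds ~one forward; backward ~= 2x forward (assumed).
forward, backward, recompute = 1.0, 2.0, 1.0
overhead = recompute / (forward + backward)
print(f"overhead = {overhead:.0%}")
```

The U-shape in the printed table (65, 20, 16, 20, 65 for L = 64) is the L/s + s trade-off: fewer checkpoints mean longer segments to recompute, and vice versa.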

Exercise 6 (Pen & Paper)
Explain the online-softmax trick; why does a running max and sum allow block-by-block attention?

Solution

Softmax needs the max and the sum over the whole row for normalization, which seems to require the full score row. Online softmax processes the row in blocks while maintaining a running max m and a running sum ℓ of exp(s − m). When a block brings a larger max m′, the accumulated sum and output are rescaled by exp(m − m′) before the block's terms are added: ℓ ← ℓ·exp(m − m′) + Σ_j exp(s_j − m′). After the last block, dividing the accumulated output by ℓ gives the exact softmax result, so attention can be computed block-by-block without ever holding the full row. This is the core of FlashAttention's memory savings.
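The recurrence can be checked in plain Python for one query row, using scalar values in place of value vectors to keep the sketch short. The block size and sample numbers are arbitrary assumptions; the point is that the blocked result matches the full softmax.

```python
import math

def full_softmax_weighted_sum(scores, values):
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def online_softmax_weighted_sum(scores, values, block=4):
    m = -math.inf  # running max
    l = 0.0        # running sum of exp(s - m)
    acc = 0.0      # running unnormalized output: sum of exp(s - m) * v
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # rescale old stats to the new max (0 on the first block)
        l = l * scale + sum(math.exp(s - m_new) for s in s_blk)
        acc = acc * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_blk, v_blk))
        m = m_new
    return acc / l  # normalize once at the end

scores = [0.3, 2.1, -1.0, 5.0, 0.0, 4.2, -3.3, 1.7, 9.0, 0.5]
values = [1.0, -2.0, 0.5, 3.0, 2.0, -1.0, 4.0, 0.0, 1.5, -0.5]
exact = full_softmax_weighted_sum(scores, values)
online = online_softmax_weighted_sum(scores, values, block=3)
print(exact, online)
```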

Exercise 7 (Pen & Paper)
Compare fp8 E4M3 and E5M2; why E4M3 for forward, E5M2 for gradients? What does per-tensor scaling do?

Solution

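E4M3 has 4 exponent and 3 mantissa bits: more precision, narrower range (max finite value ±448 in the common variant without infinities). E5M2 has 5 exponent and 2 mantissa bits: less precision, more range (max ±57,344, and it reaches much smaller magnitudes before underflowing). Forward activations and weights have a comparatively modest dynamic range and benefit from the extra mantissa bit → E4M3. Gradients span a much wider dynamic range and need the larger exponent to avoid overflow and underflow → E5M2. Per-tensor scaling multiplies each tensor by a factor chosen from its absolute maximum (amax), e.g. s = 448/amax for E4M3, so its values fill the representable range before casting; results are divided by s afterwards. This prevents overflow and underflow and is what makes fp8 training viable.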
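The ranges quoted above follow from the bit layouts. This pure-Python sketch computes them and derives a per-tensor scale. It assumes the common E4M3 variant without infinities (only the all-ones pattern is NaN) and IEEE-style E5M2, and it does not perform the actual rounding to fp8.

```python
def fp8_limits(exp_bits: int, man_bits: int, ocp_e4m3: bool = False):
    """Return (max normal, min normal, min subnormal) for an 8-bit float layout."""
    bias = 2 ** (exp_bits - 1) - 1
    if ocp_e4m3:
        # No infinities: the top exponent still holds normal numbers,
        # except mantissa all-ones (NaN), so the largest mantissa is 1.110b.
        max_exp = (2 ** exp_bits - 1) - bias
        max_mantissa = 1 + (2 ** man_bits - 2) / 2 ** man_bits
    else:
        # IEEE-style: the top exponent is reserved for inf/NaN.
        max_exp = (2 ** exp_bits - 2) - bias
        max_mantissa = 1 + (2 ** man_bits - 1) / 2 ** man_bits
    max_normal = max_mantissa * 2 ** max_exp
    min_normal = 2.0 ** (1 - bias)
    min_subnormal = 2.0 ** (1 - bias - man_bits)
    return max_normal, min_normal, min_subnormal

E4M3 = fp8_limits(4, 3, ocp_e4m3=True)
E5M2 = fp8_limits(5, 2)
print("E4M3:", E4M3)
print("E5M2:", E5M2)

def per_tensor_scale(values, fmt_max=E4M3[0]):
    """Scale so the tensor's amax maps to the format's largest finite value."""
    amax = max(abs(v) for v in values)
    return fmt_max / amax if amax > 0 else 1.0

acts = [0.002, -0.5, 1300.0, 7.25]  # 1300 would overflow E4M3 without scaling
s = per_tensor_scale(acts)
scaled = [v * s for v in acts]      # cast these to fp8; divide results by s afterwards
print("scale:", s, "max |scaled|:", max(abs(v) for v in scaled))
```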

Exercise 8 (Pen & Paper)
What is a graph break in torch.compile and why does it hurt? Two causes and fixes.

Solution

torch.compile (via TorchDynamo) traces the model's Python into graphs that are then fused and optimized. A graph break is a point where tracing cannot continue, so that code runs in eager Python and a new graph starts afterwards; this fragments the graph and loses fusion and optimization opportunities across the break. Common causes: (1) data-dependent Python control flow (e.g. an if on a tensor value); fix by expressing the logic with tensor ops such as torch.where, or with torch.cond. (2) Calls to untraceable or opaque Python (printing, .item(), unsupported libraries); fix by removing them from the hot path. Fewer breaks → larger fused graphs → more speedup.
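A typical instance of cause (1) and its fix, as a sketch: the first function branches in Python on a tensor value, the second computes both branches and selects with torch.where. Whether a given pattern breaks the graph can vary across PyTorch versions, so confirm with torch._dynamo.explain; the check below only verifies that the rewrite is equivalent in eager mode.

```python
import torch

def gated_branchy(x: torch.Tensor) -> torch.Tensor:
    # Python `if` on a tensor value: the tracer must know the outcome at
    # trace time, a typical source of a graph break.
    if x.sum() > 0:
        return x * 2
    return x - 1

def gated_tensor_op(x: torch.Tensor) -> torch.Tensor:
    # Same logic as a tensor op: both branches are computed and torch.where
    # selects between them, so no Python decision depends on tensor data.
    cond = x.sum() > 0  # 0-dim bool tensor, broadcast by torch.where
    return torch.where(cond, x * 2, x - 1)

torch.manual_seed(0)
x = torch.randn(4, 8)
pos, neg = x.abs() + 1, -x.abs() - 1  # force each branch
print(torch.equal(gated_branchy(pos), gated_tensor_op(pos)))
print(torch.equal(gated_branchy(neg), gated_tensor_op(neg)))
```

Computing both branches costs some extra arithmetic, which is usually cheap compared with splitting the compiled graph.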

Exercise 9 (Derive)
For LoRA rank r on a d×d weight, derive 2dr trainable params; for d=4096, r=16, the reduction factor.

Solution

LoRA freezes W and learns a low-rank update ΔW = BA, with B (d×r) and A (r×d): trainable params = dr + rd = 2dr. For d=4096, r=16: 2·4096·16 = 131,072. Full fine-tuning would train d² = 16.78M. The reduction is d²/(2dr) = d/(2r) = 4096/32 = 128× fewer trainable parameters — the efficiency that makes adapter fine-tuning practical.
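The same count in code, generalized to a d_out×d_in weight. The helper lora_params is hypothetical.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    # B: d_out x r and A: r x d_in, so delta_W = B @ A has shape d_out x d_in
    return d_out * r + r * d_in

d, r = 4096, 16
trainable = lora_params(d, d, r)  # 2 * d * r for a square weight
full = d * d
print(trainable, full, full / trainable)
```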

Exercise 10 (Pen & Paper)
List five efficiency techniques; for each, whether it saves memory, compute, or memory traffic; how they compose.

Solution

Mixed precision (memory + traffic + compute). Activation checkpointing (memory, at a compute cost). Kernel fusion / FlashAttention (memory traffic; FlashAttention also saves memory by never materializing the T×T score matrix). torch.compile (compute + traffic via fusion). LoRA (gradient and optimizer-state memory, plus some compute, for fine-tuning). They compose because they target different bottlenecks: e.g. FlashAttention (traffic) + checkpointing (memory) + compile (fusion) + bf16 (all) largely compound, though rarely perfectly multiplicatively, since removing one bottleneck exposes the next. This is why production training enables several at once (Exercise 20).

Exercise 11 (Code)
Compute MFU for a real step: time it, count tokens, apply 6N, divide by peak; identify the bottleneck.

Solution

Time a step (synchronizing the GPU before reading the clock, since CUDA kernels launch asynchronously), multiply the token throughput by 6N for useful FLOP/s, and divide by the GPU's peak to get MFU (Exercise 1). A low MFU with the profiler showing time in attention/normalization points to a memory-bandwidth bottleneck; time in communication points to a parallelism bottleneck. Either diagnosis guides which optimization to apply.
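A minimal timing harness for this measurement, as a sketch: measure_mfu is a hypothetical helper, and in real use step_fn would run one full training step (forward, backward, optimizer) with sync=torch.cuda.synchronize so that queued GPU work is inside the timed region.

```python
import time

def measure_mfu(step_fn, tokens_per_step, n_params, peak_flops,
                warmup=3, iters=10, sync=lambda: None):
    """Time `step_fn` and return (tokens/s, MFU) using 6N FLOPs per token.
    `sync` must block until queued device work finishes (e.g. torch.cuda.synchronize)."""
    for _ in range(warmup):  # exclude compilation and allocator warm-up
        step_fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
    sync()
    elapsed = time.perf_counter() - start
    tokens_per_s = tokens_per_step * iters / elapsed
    return tokens_per_s, 6 * n_params * tokens_per_s / peak_flops

# Stand-in step that just sleeps 10 ms (illustration only, not a real model).
tps, util = measure_mfu(lambda: time.sleep(0.01), tokens_per_step=4096,
                        n_params=125e6, peak_flops=9.9e14, iters=5)
print(f"{tps:,.0f} tokens/s, MFU {util:.2%}")
```

With a real model, tokens_per_step is batch size × sequence length, and peak_flops should match the dtype you train in.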

Exercise 12 (Code)
Profile a step with the PyTorch profiler; identify the top three ops and classify each.

Solution

The profiler's table ranks operators by time; typically matmuls (compute-bound), attention (memory-bound, unless FlashAttention), and elementwise/normalization (memory-bound) dominate. Classifying each by arithmetic intensity (Exercise 2) tells you whether to fuse (memory-bound) or improve utilization (compute-bound).
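A minimal profiling sketch on a toy MLP (the model and sizes are arbitrary assumptions). It runs on CPU so it works anywhere; on a GPU, add ProfilerActivity.CUDA to the activities and sort by the device-time column your PyTorch version reports.

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Sequential(
    torch.nn.Linear(512, 2048),
    torch.nn.GELU(),
    torch.nn.LayerNorm(2048),
    torch.nn.Linear(2048, 512),
)
x = torch.randn(64, 512)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    loss = model(x).pow(2).mean()  # forward
    loss.backward()                # backward

# Top three operators by total CPU time; classify each by arithmetic intensity.
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=3)
print(table)
```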

Exercise 13 (Code Lab)
Measure activation recomputation: deep model with/without checkpointing; confirm ~33% overhead.

Solution

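Enabling gradient checkpointing markedly lowers peak activation memory while increasing wall-clock time by roughly a third, confirming the ~33% compute overhead derived in Exercise 5. Note that checkpointing every layer, as Hugging Face Transformers' gradient_checkpointing does, keeps one boundary activation per layer rather than √L of them, so its memory saving differs from the O(√L) scheme; the recompute cost is the same single extra forward.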
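Measuring peak memory needs a GPU (e.g. torch.cuda.max_memory_allocated), but the mechanics can be checked anywhere. This sketch applies torch.utils.checkpoint.checkpoint_sequential with √L segments to a toy 16-layer MLP (an assumed stand-in for a deep model) and confirms the gradients match the non-checkpointed run; wrapping both versions in a timer would expose the extra-forward cost.

```python
import math
import torch
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
L, width = 16, 256
model = torch.nn.Sequential(*[
    torch.nn.Sequential(torch.nn.Linear(width, width), torch.nn.GELU())
    for _ in range(L)
])
x = torch.randn(32, width)

# Baseline: every layer's activations are kept for backward.
model(x).sum().backward()
ref_grads = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Checkpointed: sqrt(L) segments; only segment inputs are stored, and each
# segment's forward is re-run during backward.
segments = math.isqrt(L)
out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()

same = all(torch.allclose(g, p.grad) for g, p in zip(ref_grads, model.parameters()))
print("gradients match:", same)
```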

Exercise 14 (Code)
Demonstrate fusion: GELU(x@W+b) as separate ops vs one fused expression; measure wall-clock and memory.

Solution

The fused version avoids the extra HBM round-trips of Exercise 4, running faster and using less peak memory on a large tensor — a concrete measurement of why operator fusion matters for memory-bound chains.

Exercise 15 (Code Lab)
Write a Triton vector-add kernel; verify vs torch; extend to fused bias+GELU; benchmark.

Solution

A basic Triton kernel matches torch's add; fusing bias and GELU into one kernel (one read, one write) beats the unfused PyTorch sequence by avoiding intermediate HBM traffic — a hands-on demonstration of writing custom fused kernels, the lowest level of the efficiency stack.

Exercise 16 (Code)
Apply torch.compile to a small model; measure speedup vs eager; find graph breaks with dynamo.explain.

Solution

torch.compile fuses and optimizes the graph, giving a measurable speedup over eager mode for both training and inference. torch._dynamo.explain surfaces any graph breaks (Exercise 8); removing them widens the fused graph and increases the speedup.

Exercise 17 (Code)
Compare from-scratch attention to SDPA (FlashAttention) at T=1k/4k/8k; plot speed and peak memory.

Solution

The naive implementation's time and memory grow quadratically with T because it materializes the T×T score matrix, while FlashAttention's memory grows only linearly in T and it stays faster (SDPA selects its FlashAttention backend on supported GPUs and dtypes). The plot makes the IO-aware advantage of Exercises 3 and 6 vivid across context lengths.
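Before timing, it is worth checking that the two implementations agree. This sketch compares a from-scratch causal attention with torch.nn.functional.scaled_dot_product_attention at T = 1024; the shapes are arbitrary assumptions. For the speed and memory plot, rerun it on CUDA at T = 1k/4k/8k in bf16 or fp16 and record torch.cuda.max_memory_allocated for each variant.

```python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v, causal=True):
    # Materializes the full T x T score matrix: O(T^2) memory per head.
    T = q.shape[-2]
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if causal:
        future = torch.ones(T, T, dtype=torch.bool, device=q.device).triu(1)
        scores = scores.masked_fill(future, float("-inf"))
    return scores.softmax(dim=-1) @ v

torch.manual_seed(0)
B, H, T, D = 2, 4, 1024, 64
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

ref = naive_attention(q, k, v)
fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print("max abs diff:", (ref - fast).abs().max().item())
print("naive score matrix:", B * H * T * T * 4 / 2**20, "MiB in fp32")
```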

Exercise 18 (Code)
Enable fused AdamW (fused=True); measure optimizer-step time vs default across parameter counts.

Solution

The fused AdamW kernel performs the moment updates in a single fused operation rather than many small elementwise kernels, reducing launch overhead and memory traffic; the speedup over the default grows with parameter count — a cheap, drop-in efficiency win.

Exercise 19 (Code)
Implement a minimal LoRA layer: freeze W, add trainable BA; verify only 2dr params get gradients.

Solution

Freezing the base linear and adding a low-rank BA update (Exercise 9) means only the 2dr adapter parameters receive gradients; verifying the gradient count and that the layer still trains demonstrates parameter-efficient fine-tuning in a few lines.
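A minimal LoRA layer as a sketch. LoRALinear is a hypothetical name, and the α/r scaling and zero initialization of B follow the original LoRA paper's convention rather than anything stated in the exercise.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base linear plus a trainable low-rank update: y = W x + (alpha / r) * B A x."""

    def __init__(self, d_in: int, d_out: int, r: int = 16, alpha: float = 16.0):
        super().__init__()
        self.base = nn.Linear(d_in, d_out, bias=False)
        self.base.weight.requires_grad_(False)              # freeze W
        self.A = nn.Parameter(torch.randn(r, d_in) * 0.01)  # r x d_in
        self.B = nn.Parameter(torch.zeros(d_out, r))        # d_out x r; zero init makes delta_W = 0 at start
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scaling

torch.manual_seed(0)
layer = LoRALinear(64, 64, r=4)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print("trainable:", trainable)  # 512 = 2 * d * r

x, target = torch.randn(8, 64), torch.randn(8, 64)
loss = (layer(x) - target).pow(2).mean()
loss.backward()
print("W grad:", layer.base.weight.grad)  # frozen, so it stays None
print("A, B grads present:", layer.A.grad is not None, layer.B.grad is not None)
```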

Exercise 20 (Code, Challenge)
Efficiency ablation: (a) eager bf16, (b) +FlashAttention, (c) +torch.compile, (d) +recompute; report MFU/throughput/memory.

Solution

Layering the techniques and measuring each configuration shows their cumulative effect: FlashAttention cuts attention memory/time, torch.compile fuses the rest, and checkpointing lowers peak memory (at some compute cost). The write-up should identify the biggest win for the given model/hardware — usually FlashAttention or compile for compute, checkpointing for memory — and explain why, tying back to which bottleneck each addresses.