Efficient Training Techniques
Detailed solutions for the exercises in Chapter 20. Try solving them yourself before checking the answers.
Solution
MFU = (achieved useful FLOP/s) / (peak FLOP/s), where useful FLOPs ≈ 6N per token for a model with N parameters. Achieved = 6·13×10⁹·25000 = 1.95×10¹⁵ FLOP/s; dividing by the 9.9×10¹⁴ FLOP/s peak gives ≈1.97. A ratio above 1 is impossible on a single device, so 25k tokens/s must be an aggregate over several GPUs (or the figure is illustrative). On one H100 a realistic training MFU is 30–50%, which for a 13B model works out to roughly 4,000–6,000 tokens/s. The exercise's value is in the method: MFU = 6N·throughput / peak.
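The arithmetic above can be checked with a few lines of Python. This is a minimal sketch: the function names `mfu` and `tokens_per_s_at` are illustrative, and the peak value is the 9.9×10¹⁴ FLOP/s figure used in the solution.

```python
def mfu(n_params: float, tokens_per_s: float, peak_flops: float) -> float:
    """Model FLOPs utilization under the 6N FLOPs-per-token approximation."""
    achieved_flops = 6.0 * n_params * tokens_per_s
    return achieved_flops / peak_flops


def tokens_per_s_at(target_mfu: float, n_params: float, peak_flops: float) -> float:
    """Invert the formula: throughput implied by a target MFU on one device."""
    return target_mfu * peak_flops / (6.0 * n_params)


PEAK = 9.9e14  # peak FLOP/s figure from the solution above

ratio = mfu(13e9, 25_000, PEAK)
print(f"MFU at 25k tokens/s: {ratio:.2f}")  # above 1, so not a single device

for target in (0.3, 0.5):
    print(f"MFU {target:.0%}: {tokens_per_s_at(target, 13e9, PEAK):,.0f} tokens/s")
```

Running the inverse form first is a quick sanity check on any reported throughput before trusting an MFU number.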
Solution
Arithmetic intensity = FLOPs per byte moved. High intensity → compute-bound (saturates the math units); low → memory-bound (limited by bandwidth). Large matmuls are compute-bound (lots of FLOPs reusing each loaded byte). GELU, LayerNorm, and softmax are memory-bound — they do little arithmetic per element but must stream the whole tensor through memory. This is why fusing the memory-bound ops (Exercise 4) yields big speedups.
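To make the classification concrete, the sketch below estimates arithmetic intensity for a square matmul and for a generic elementwise op, then compares each with a roofline ridge point (peak FLOP/s divided by memory bandwidth). The bandwidth of 3 TB/s, the 10 FLOPs per element, and the 2-byte element size are assumptions chosen for illustration, not measured values.

```python
def matmul_intensity(m: int, k: int, n: int, bytes_per_el: int = 2) -> float:
    """FLOPs per byte for (m x k) @ (k x n), reading both inputs and writing the output once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_el * (m * k + k * n + m * n)
    return flops / bytes_moved


def elementwise_intensity(flops_per_el: float, bytes_per_el: int = 2) -> float:
    """FLOPs per byte for an op that reads each element once and writes it once."""
    return flops_per_el / (2 * bytes_per_el)


PEAK_FLOPS = 9.9e14      # peak FLOP/s figure used in Exercise 1
HBM_BANDWIDTH = 3.0e12   # assumed: ~3 TB/s, the top of the HBM range quoted in Exercise 3
RIDGE = PEAK_FLOPS / HBM_BANDWIDTH  # intensity above which an op is compute-bound


def classify(intensity: float) -> str:
    return "compute-bound" if intensity > RIDGE else "memory-bound"


mm = matmul_intensity(4096, 4096, 4096)
ew = elementwise_intensity(flops_per_el=10)  # assumed cost for a GELU-like op
print(f"ridge point: {RIDGE:.0f} FLOP/byte")
print(f"matmul:      {mm:.1f} FLOP/byte -> {classify(mm)}")
print(f"elementwise: {ew:.1f} FLOP/byte -> {classify(ew)}")
```

Small or skinny matmuls can fall below the ridge point too, so the classification depends on shape as well as operator type.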
Solution
Registers (KB, fastest) → SRAM/shared memory (tens of MB on-chip, ~10–100 TB/s) → HBM (tens of GB, ~1–3 TB/s) → NVLink (inter-GPU, hundreds of GB/s) → InfiniBand (inter-node, ~tens of GB/s). SRAM is at least an order of magnitude faster than HBM, so for memory-bound ops the key to speed is doing as much work as possible while data sits in SRAM, which is exactly the principle behind FlashAttention and kernel fusion.
Solution
Unfused: the matmul writes its result to HBM, the bias-add reads and writes it, and GELU reads and writes it again — multiple full passes over the activation through slow HBM. Fused: the matmul output stays in registers/SRAM, bias and GELU are applied in place, and only the final result is written to HBM — one write instead of several read/write round-trips. For a memory-bound chain this can be a large speedup, since the bottleneck is HBM traffic, not arithmetic.
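The pass counting above can be turned into a rough traffic estimate. The sketch below counts only full passes over the output activation (it ignores the matmul's input reads and the small bias vector), and the tensor shape and 2-byte element size are assumptions for illustration.

```python
def hbm_bytes(n_elements: int, passes: int, bytes_per_el: int = 2) -> int:
    """Bytes moved through HBM for a given number of full passes over a tensor."""
    return n_elements * passes * bytes_per_el


# Unfused: matmul write, bias read + write, GELU read + write.
UNFUSED_PASSES = 5
# Fused: only the final result is written.
FUSED_PASSES = 1

n = 8192 * 16384  # assumed activation shape: 8192 tokens x 16384 features

unfused = hbm_bytes(n, UNFUSED_PASSES)
fused = hbm_bytes(n, FUSED_PASSES)
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB")
print(f"activation traffic reduced {unfused // fused}x")
```

The 5× figure is an upper bound on the speedup of the bias and GELU portion only; the matmul itself still has to run.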
Solution
Place a checkpoint every √L layers (√L checkpoints in total) and store activations only there, which costs O(√L) memory. During backward, recompute the activations within the current segment of √L layers, which needs O(√L) more at a time. Total activation memory is therefore O(√L) instead of O(L). The recomputation is one extra forward pass over each segment, or about one extra forward pass overall; since backward costs ≈2× forward, the added forward is ≈⅓ of the original forward+backward, hence ~33% compute overhead (assuming roughly equal-cost layers).
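A short numeric check of both claims, as a sketch: memory is counted in units of one layer's activations, all layers are assumed to cost the same, and `peak_activations` is an illustrative helper rather than a library function.

```python
import math


def peak_activations(num_layers: int, segment: int) -> int:
    """Stored checkpoints plus the live activations of one recomputed segment."""
    checkpoints = math.ceil(num_layers / segment)
    return checkpoints + segment


L = 64
# Search every segment length and keep the one with the lowest peak.
best_segment = min(range(1, L + 1), key=lambda k: peak_activations(L, k))
print(f"best segment length: {best_segment} (sqrt(L) = {math.isqrt(L)})")
print(f"peak: {peak_activations(L, best_segment)} vs {L} without checkpointing")

# Compute cost in units of one forward pass: forward 1, backward 2, recompute 1.
forward, backward, recompute = 1.0, 2.0, 1.0
overhead = (forward + backward + recompute) / (forward + backward) - 1.0
print(f"compute overhead: {overhead:.1%}")
```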
Solution
Softmax needs the max and sum over the whole row for normalization, which seems to require the full score row. Online softmax processes the row in blocks while maintaining a running max m and a running (rescaled) sum: when a new block raises the max from m_old to m_new, the accumulated sum and output are multiplied by exp(m_old − m_new) before the block's terms are added. This yields the exact softmax incrementally, so attention can be computed block-by-block without ever holding the full row, which is the core of FlashAttention's memory savings.
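The rescaling rule can be sketched in plain Python for a single query row. For simplicity the values here are scalars (in attention they are vectors, and the same rescaling applies per component); the function names and block size are illustrative.

```python
import math


def softmax(row):
    """Reference softmax that needs the full row at once."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def online_softmax_weighted_sum(scores, values, block=4):
    """Compute sum_i softmax(scores)_i * values[i] one block at a time."""
    running_max = -math.inf
    running_sum = 0.0
    acc = 0.0  # unnormalized output accumulator
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        new_max = max(running_max, max(s_blk))
        # Rescale what was accumulated under the old max (factor is 0.0 on the first block).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc *= scale
        for s, v in zip(s_blk, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc += w * v
        running_max = new_max
    return acc / running_sum


scores = [0.5, 2.0, -1.0, 3.5, 0.0, 7.0, 1.5, -2.0, 4.0]
values = [1.0, -2.0, 0.5, 3.0, 4.0, -1.0, 2.0, 0.0, 1.5]

reference = sum(p * v for p, v in zip(softmax(scores), values))
blockwise = online_softmax_weighted_sum(scores, values, block=4)
print(abs(reference - blockwise) < 1e-9)
```

Only one block of scores and three running quantities are live at any time, regardless of the row length.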
Solution
E4M3 has 4 exponent and 3 mantissa bits (more precision, narrower range); E5M2 has 5 exponent and 2 mantissa bits (more range, less precision). Forward activations are bounded and benefit from precision → E4M3. Gradients span a wide dynamic range and need the larger exponent to avoid underflow → E5M2. Per-tensor scaling multiplies each tensor by a factor to center its values in the representable range before casting, preventing overflow/underflow — making fp8 training viable.
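The effect of per-tensor scaling can be sketched without any fp8 hardware. The constants below are the largest finite values of the commonly used E4M3 and E5M2 encodings and the smallest E5M2 subnormal; the sample gradient values and the helper `per_tensor_scale` are made up for illustration, and the actual cast to fp8 is not simulated.

```python
E4M3_MAX = 448.0               # largest finite value in the common E4M3 encoding
E5M2_MAX = 57344.0             # largest finite value in E5M2
E5M2_MIN_SUBNORMAL = 2.0**-16  # smallest positive E5M2 value


def per_tensor_scale(values, fp8_max):
    """Scale factor that maps the tensor's largest magnitude onto the fp8 maximum."""
    amax = max(abs(v) for v in values)
    return fp8_max / amax


grads = [1e-7, -4e-6, 2.5e-5]  # hypothetical small gradient values

# Without scaling, the smallest gradients sit below the smallest representable value.
lost = [g for g in grads if abs(g) < E5M2_MIN_SUBNORMAL]
print(f"{len(lost)} of {len(grads)} values below the E5M2 subnormal floor")

scale = per_tensor_scale(grads, E5M2_MAX)
scaled = [g * scale for g in grads]
print(f"scale = {scale:.3e}")
print(f"scaled magnitudes: {[round(abs(s), 1) for s in scaled]}")
# After the fp8 computation, results are divided by `scale` to undo the scaling.
```

Mapping the maximum exactly onto the format's limit leaves no headroom, so practical recipes often keep a safety margin or track the maximum over recent steps.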
Solution
torch.compile captures the model as a graph and compiles it into fused kernels; a graph break is a point where it cannot trace (it falls back to eager mode, then resumes), fragmenting the graph and losing fusion/optimization opportunities. Common causes: (1) data-dependent Python control flow (e.g. an if statement on a tensor value), fixed by using tensor ops or torch.cond; (2) calls to untraceable or opaque Python (printing, .item(), unsupported libraries), fixed by removing them from the hot path. Fewer breaks → larger fused graphs → more speedup.
Solution
LoRA freezes W and learns a low-rank update ΔW = BA, with B (d×r) and A (r×d): trainable params = dr + rd = 2dr. For d=4096, r=16: 2·4096·16 = 131,072. Full fine-tuning would train d² = 16.78M. The reduction is d²/(2dr) = d/(2r) = 4096/32 = 128× fewer trainable parameters — the efficiency that makes adapter fine-tuning practical.
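The parameter counts are easy to reproduce. A minimal sketch, with `lora_params` as an illustrative helper:

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters in a rank-r update BA, with A (r x d_in) and B (d_out x r)."""
    return r * d_in + d_out * r


d, r = 4096, 16
trainable = lora_params(d, d, r)
full = d * d

print(f"LoRA trainable params: {trainable:,}")
print(f"full fine-tuning:      {full:,}")
print(f"reduction:             {full // trainable}x")
```

The reduction factor d/(2r) depends only on the ratio of width to rank, so wider layers benefit more at a fixed rank.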
Solution
Mixed precision (memory + traffic + compute). Activation checkpointing (memory, costs compute). Kernel fusion / FlashAttention (memory traffic). torch.compile (compute + traffic via fusion). LoRA (memory + compute for fine-tuning). They compose because they target different bottlenecks: e.g. FlashAttention (traffic) + checkpointing (memory) + compile (fusion) + bf16 (memory, traffic, and compute) can be enabled together and their gains largely stack, which is why production training enables several at once (Exercise 20).
Solution
Timing a step, multiplying tokens by 6N for useful FLOPs, and dividing by the GPU's peak gives the MFU (Exercise 1). A low MFU with the profiler showing time in attention/normalization points to a memory-bandwidth bottleneck; time in communication points to a parallelism bottleneck — guiding which optimization to apply.
Solution
The profiler's table ranks operators by time; typically matmuls (compute-bound), attention (memory-bound, unless FlashAttention), and elementwise/normalization (memory-bound) dominate. Classifying each by arithmetic intensity (Exercise 2) tells you whether to fuse (memory-bound) or improve utilization (compute-bound).
Solution
Enabling gradient_checkpointing markedly lowers peak memory while increasing wall-clock time by roughly a third, consistent with the O(√L) memory / ~33% compute trade-off derived in Exercise 5.
Solution
The fused version avoids the extra HBM round-trips of Exercise 4, running faster and using less peak memory on a large tensor — a concrete measurement of why operator fusion matters for memory-bound chains.
Solution
A basic Triton kernel matches torch's add; fusing bias and GELU into one kernel (one read, one write) beats the unfused PyTorch sequence by avoiding intermediate HBM traffic — a hands-on demonstration of writing custom fused kernels, the lowest level of the efficiency stack.
Solution
torch.compile fuses and optimizes the graph, giving a measurable speedup over eager mode for both training and inference. torch._dynamo.explain surfaces any graph breaks (Exercise 8); removing them widens the fused graph and increases the speedup.
Solution
The naive implementation's time and memory grow steeply with T because it materializes the T×T score matrix, while FlashAttention scales far better in memory and stays faster. The plot makes the IO-aware advantage from Exercises 3 and 6 vivid across context lengths.
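The quadratic term is easy to quantify before running any benchmark. The sketch below computes the size of one head's T×T score matrix, assuming 2-byte elements; batch size and head count multiply this further.

```python
def naive_score_bytes(seq_len: int, bytes_per_el: int = 2) -> int:
    """Memory for one head's full T x T attention score matrix."""
    return seq_len * seq_len * bytes_per_el


for T in (1024, 4096, 16384):
    mib = naive_score_bytes(T) / 2**20
    print(f"T = {T:>6}: {mib:>6.0f} MiB per head")

# Quadrupling T multiplies the score-matrix memory by 16.
growth = naive_score_bytes(16384) // naive_score_bytes(4096)
print(f"4x longer context -> {growth}x more score memory")
```

FlashAttention never materializes this matrix, so its extra memory grows linearly with T instead.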
Solution
The fused AdamW kernel performs the moment updates in a single fused operation rather than many small elementwise kernels, reducing launch overhead and memory traffic; the speedup over the default grows with parameter count — a cheap, drop-in efficiency win.
Solution
Freezing the base linear and adding a low-rank BA update (Exercise 9) means only the 2dr adapter parameters receive gradients; verifying the gradient count and that the layer still trains demonstrates parameter-efficient fine-tuning in a few lines.
Solution
Layering the techniques and measuring each configuration shows their cumulative effect: FlashAttention cuts attention memory/time, torch.compile fuses the rest, and checkpointing lowers peak memory (at some compute cost). The write-up should identify the biggest win for the given model/hardware — usually FlashAttention or compile for compute, checkpointing for memory — and explain why, tying back to which bottleneck each addresses.