Part III: The Transformer
Chapter 15

Building a GPT from Scratch

Implement, train, and generate from a real language model
20 Exercises
15.1

With the architecture built (Chapter 13) and text tokenized (Chapter 14), we can finally train. The pretraining objective for a decoder-only Transformer is the simplest imaginable: given a sequence of tokens, predict the next one at every position. This single objective, scaled up, produces the core capabilities of modern LLMs.

The Loss

textCausal language-modeling loss
For a sequence x₁ ... x_{T+1} (T inputs plus one final target):

L = -(1/T) Σ_{t=1}^{T} log Pθ(x_{t+1} | x₁ ... x_t)

# Cross-entropy (Chapter 4) at every position,
# averaged over the sequence. That is the entire objective.

Because of the causal mask (Chapter 12), every position predicts its successor using only itself and earlier tokens, and all T predictions are computed in a single forward pass. This is the efficiency that makes the decoder-only design so attractive: one forward pass yields T training signals, one per position.

Teacher Forcing and the Shift

During training we feed the true tokens and ask the model to predict each next token; this is teacher forcing. The implementation reduces to a one-position shift: the inputs are tokens x₁...x_T and the targets are x₂...x_{T+1}. The model's logits at position t are scored against the actual token at position t+1.

PythonThe training objective in code
import torch
import torch.nn.functional as F

def lm_loss(model, tokens):  # tokens: (B, T+1)
    """Causal LM loss via the one-position shift."""
    inputs  = tokens[:, :-1]        # x_1 ... x_T
    targets = tokens[:, 1:]        # x_2 ... x_{T+1}

    logits = model(inputs)           # (B, T, V)
    # Flatten to (B*T, V) and (B*T,) for cross-entropy
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
    )
    return loss

# Perplexity = exp(loss) is the standard human-readable metric (Chapter 4).
# A loss of 2.0 nats ≈ perplexity 7.4: the model is effectively choosing
# among ~7 equally-likely next tokens at each step.
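To see the one-position shift and the perplexity reading without a GPU, here is a framework-free sketch. It assumes a hypothetical `next_token_probs(prefix)` callable that returns a probability for every vocabulary id. A "model" that ignores its input and predicts uniformly over V tokens should score a loss of ln V, that is, a perplexity of exactly V, which makes a handy sanity check for a freshly initialized network.

```python
import math

def causal_lm_loss(next_token_probs, tokens):
    """Mean negative log-likelihood of the T targets in a (T+1)-token sequence.

    next_token_probs(prefix) is a hypothetical callable that returns a list of
    probabilities indexed by token id, i.e. the model's next-token distribution.
    """
    T = len(tokens) - 1                  # T inputs, T targets
    nll = 0.0
    for t in range(T):
        prefix = tokens[: t + 1]         # everything up to and including position t
        target = tokens[t + 1]           # the token one position later (the shift)
        nll -= math.log(next_token_probs(prefix)[target])
    return nll / T

V = 8

def uniform(prefix):
    """A 'model' that ignores its input and spreads probability evenly over V tokens."""
    return [1.0 / V] * V

loss = causal_lm_loss(uniform, [3, 1, 4, 1, 5, 2])
print(f"loss = {loss:.4f} nats, perplexity = {math.exp(loss):.2f}")
# loss = 2.0794 nats, perplexity = 8.00
```

A real model at initialization should land near this uniform baseline, ln V; a starting loss far above it usually points to a bug in the initialization or in the shift.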
Train Note: Document Packing
Real pretraining concatenates many documents, separated by an end-of-text token, into long token streams and slices those streams into fixed-length sequences. This 'packing' avoids wasting compute on padding: every position in every sequence carries a real training signal.
Some pipelines reset the attention mask at document boundaries so tokens cannot attend across unrelated documents; others let attention cross boundaries for simplicity. Both are used in practice; the former is cleaner, the latter is simpler and works well at scale.
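A minimal sketch of packing, assuming documents arrive as lists of token ids and using a hypothetical end-of-text id `EOT = 0`. It appends `EOT` after each document, slices the stream into non-overlapping chunks of `seq_len + 1` tokens (the extra token supplies the last target for the one-position shift), and drops the leftover tail. The second function builds the "reset at document boundaries" variant of the causal mask: position i may attend to position j only if j ≤ i and both belong to the same document.

```python
EOT = 0  # hypothetical end-of-text token id

def pack(docs, seq_len, eot=EOT):
    """Concatenate docs (each followed by eot) and cut into (seq_len + 1)-token chunks."""
    stream = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eot)
    block = seq_len + 1                       # T inputs + 1 extra target
    n = len(stream) // block                  # the leftover tail is dropped
    return [stream[i * block:(i + 1) * block] for i in range(n)]

def document_causal_mask(chunk, eot=EOT):
    """allowed[i][j] is True iff j <= i and tokens i and j belong to the same document."""
    doc_id, ids = 0, []
    for tok in chunk:
        ids.append(doc_id)
        if tok == eot:                        # the EOT closes the current document
            doc_id += 1
    n = len(chunk)
    return [[j <= i and ids[i] == ids[j] for j in range(n)] for i in range(n)]

chunks = pack([[5, 6, 7], [8, 9], [10, 11, 12, 13]], seq_len=4)
print(chunks)  # [[5, 6, 7, 0, 8], [9, 0, 10, 11, 12]]
```

The "cross boundaries" variant simply uses the plain causal mask (`j <= i`) on the same packed chunks.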
15.2

Transformers are unusually sensitive to the learning rate early in training. Starting at the full learning rate often causes immediate divergence. The fix — learning-rate warmup — ramps the rate up linearly from near zero over the first few thousand steps, then decays it. This single trick is essential for stable Transformer training.

Why Early Training Is Fragile

At initialization the model's predictions are random and the gradients are large and noisy. Adam's second-moment estimate (the running average of squared gradients) has not yet stabilized, so early updates can be wildly miscalibrated. A large learning rate amplifies these bad early updates, pushing the model into a region from which it cannot recover. Warmup gives Adam's statistics time to settle before applying full-strength updates.

Train Note: Warmup and Adam's Variance Estimate
Adam divides each update by √v̂, the square root of its running average of squared gradients (an uncentered variance estimate). In the first few steps this estimate is based on almost no data and is unreliable: it can be far too small, producing enormous effective step sizes.
Warmup compensates: by keeping the learning rate tiny while Adam's variance estimate accumulates evidence, it prevents the catastrophic early steps. This is why warmup matters more for adaptive optimizers like Adam than for plain SGD.
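The effect is visible in Adam's very first step. With bias correction, m̂ = g and v̂ = g² after step 1, so each coordinate moves by lr · g/(|g| + ε), roughly a full step of size lr however small or noisy that first gradient was. Put differently, the effective step size lr/√v̂ is about lr/|g|, which is huge for a coordinate whose first gradient happened to be tiny. A scalar sketch, using the standard Adam update with illustrative hyperparameters:

```python
import math

def first_adam_step(g, lr=3e-4, b1=0.9, b2=0.95, eps=1e-8):
    """Size of Adam's first update for a scalar gradient g (moments start at zero)."""
    m = (1 - b1) * g                  # first moment after one step
    v = (1 - b2) * g * g              # second moment after one step
    m_hat = m / (1 - b1)              # bias correction at t = 1 -> m_hat = g
    v_hat = v / (1 - b2)              # bias correction at t = 1 -> v_hat = g**2
    return lr * m_hat / (math.sqrt(v_hat) + eps)

for g in (1e-6, 1e-3, 10.0):
    print(f"g = {g:g}: update = {first_adam_step(g):.3e}")
# each update is close to lr = 3e-4, although g spans seven orders of magnitude
```

Warmup shrinks `lr` itself during exactly these steps, so the oversized normalized updates do little damage.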

Warmup + Cosine Decay

The standard modern schedule combines linear warmup with cosine decay: ramp up linearly to a peak learning rate over the warmup steps, then decay smoothly following a cosine curve down to a small final rate. This schedule trains the model fast once warmed up, then anneals to a low rate for fine convergence.

textWarmup + cosine decay schedule
if step < warmup:
    lr = peak_lr · (step / warmup)              # linear ramp up
else:
    progress = (step - warmup) / (total - warmup)
    lr = min_lr + 0.5(peak_lr - min_lr)(1 + cos(π · progress))
PythonWarmup + cosine decay schedule from scratch
import math

def lr_schedule(step, peak_lr, warmup, total, min_lr=0.0):
    """Linear warmup then cosine decay to min_lr."""
    if step < warmup:                    # ramp up
        return peak_lr * step / warmup
    if step > total:                     # past the end
        return min_lr
    progress = (step - warmup) / (total - warmup)  # 0 -> 1
    cosine = 0.5 * (1 + math.cos(math.pi * progress))  # 1 -> 0
    return min_lr + (peak_lr - min_lr) * cosine

# Typical large-model schedule
peak, warmup, total = 3e-4, 2000, 100000
for s in [0, 1000, 2000, 50000, 100000]:
    print(f"step {s:>6}: lr = {lr_schedule(s, peak, warmup, total):.2e}")
# step      0: lr = 0.00e+00   (start from zero)
# step   1000: lr = 1.50e-04   (halfway through warmup)
# step   2000: lr = 3.00e-04   (peak, warmup complete)
# step  50000: lr = 1.55e-04   (cosine decay)
# step 100000: lr = 0.00e+00   (annealed to zero)
15.3

Chapter 2 derived Adam; here we focus on AdamW, the variant used to train most modern LLMs. AdamW's key correction, decoupling weight decay from the gradient update, makes regularization behave correctly with adaptive learning rates.

The Adam Update, Recalled

textAdam: per-parameter adaptive steps
m ← β₁ m + (1-β₁) g           # 1st moment (momentum)
v ← β₂ v + (1-β₂) g²          # 2nd moment (mean of squared gradients)
m̂ = m / (1-β₁ᵗ),  v̂ = v / (1-β₂ᵗ)   # bias correction
θ ← θ - lr · m̂ / (√v̂ + ε)       # adaptive update

The W in AdamW: Decoupled Weight Decay

Plain Adam with L2 regularization adds λθ to the gradient — but then this term gets divided by √v̂ along with everything else, so parameters with large gradients get less regularization. AdamW (Loshchilov & Hutter, 2017) fixes this by applying weight decay directly to the parameters, separately from the adaptive gradient step.

textL2 regularization vs decoupled weight decay
Adam + L2:   g ← g + λθ;  then adaptive step (decay gets scaled by 1/√v̂)

AdamW:       θ ← θ - lr(m̂/(√v̂+ε) + λθ)   # decay applied directly
             # weight decay is NOT scaled by the adaptive denominator
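A small simulation makes the difference concrete. The sketch below (an illustration, not PyTorch's implementation) runs Adam on one scalar parameter starting at θ = 1 with a fixed alternating gradient +G, -G, ... for 100 steps, once with no regularization and once with λ = 0.1 applied either as L2 (added to the gradient) or as decoupled decay (the AdamW formula above). It then measures how much further the regularizer pulls θ toward zero. Under L2 that pull depends strongly on the gradient scale G; under decoupled decay it does not.

```python
import math

def run_adam(G, steps=100, lr=1e-3, b1=0.9, b2=0.95, eps=1e-8, l2=0.0, wd=0.0):
    """Scalar Adam from theta = 1 with gradients +G, -G, +G, ...; returns final theta."""
    theta, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = (G if t % 2 else -G) + l2 * theta      # Adam + L2: decay enters the gradient
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        # AdamW: wd * theta sits outside the adaptive ratio (decoupled)
        theta -= lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * theta)
    return theta

def decay_pull(G, mode, lam=0.1):
    """How much further theta moves toward 0 because of the regularizer."""
    baseline = run_adam(G)
    if mode == "l2":
        return baseline - run_adam(G, l2=lam)
    return baseline - run_adam(G, wd=lam)

for G in (0.01, 10.0):
    print(f"G = {G:5}: L2 pull = {decay_pull(G, 'l2'):.2e}, "
          f"decoupled pull = {decay_pull(G, 'adamw'):.2e}")
```

With L2, the parameter seeing large gradients (G = 10) is barely regularized because λθ is divided by √v̂ ≈ G; with decoupled decay both parameters shrink by the same amount.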
| Hyperparameter | Typical value | Role |
|---|---|---|
| β₁ | 0.9 | Momentum: smooths the gradient direction |
| β₂ | 0.95 (LLMs) / 0.999 (PyTorch default) | Second moment: how fast the gradient-scale estimate adapts |
| ε | 1e-8 | Numerical floor in the denominator |
| weight decay λ | 0.1 | Regularization strength (decoupled) |
| peak lr | 1e-4 to 6e-4 | Scaled down as models grow |
Train Note: Don't Decay Everything
Standard practice is to apply weight decay only to the weight matrices, NOT to biases or LayerNorm/RMSNorm gains; some implementations also exempt embeddings. Decaying a norm gain pulls the layer's output scale toward zero, fighting the normalization it is meant to tune; biases have few parameters and contribute little to overfitting, so decaying them buys nothing.
Implementations split parameters into two groups, 'decay' (tensors with two or more dimensions, i.e. the weight matrices) and 'no_decay' (1D parameters: biases and norm gains), and pass them to AdamW with different weight_decay values. Embedding tables are 2D, so a dimension-based split decays them unless they are excluded by name. Getting this split right is a small but real quality lever.
PythonConfiguring AdamW with parameter groups
import torch

def configure_optimizer(model, lr=3e-4, wd=0.1):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad: continue
        if p.ndim >= 2:                 # matrices (embeddings too): decay
            decay.append(p)
        else:                            # biases, norms: no decay
            no_decay.append(p)
    groups = [
        {'params': decay,    'weight_decay': wd},
        {'params': no_decay, 'weight_decay': 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95), eps=1e-8)

# betas=(0.9, 0.95) is the usual LLM choice -- a lower β₂ than PyTorch's
# default 0.999, so the second-moment estimate adapts faster to changing gradient scales.
15.4

Even with warmup and a good optimizer, training can produce occasional huge gradients — from a difficult batch, a rare token, or accumulated instability. A single enormous update can undo thousands of steps of progress or send the loss to NaN. Gradient clipping caps the gradient's magnitude, keeping every update bounded.

Clipping by Global Norm

The standard method clips by global norm: compute the L2 norm of all gradients concatenated together, and if it exceeds a threshold, scale every gradient down by the same factor so the total norm equals the threshold. This preserves the gradient's direction while bounding its size.

textGlobal-norm gradient clipping
g_norm = √(Σ_all_params ‖gᵢ‖²)        # total gradient norm

if g_norm > max_norm:
    gᵢ ← gᵢ · (max_norm / g_norm)  for all i   # scale down, keep direction
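Written out without a framework, the rule takes only a few lines. This sketch treats each parameter's gradient as a plain list of floats; PyTorch's `clip_grad_norm_` works on tensors in place and adds a tiny epsilon to the denominator, which the sketch omits.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """grads: one list of floats per parameter tensor. Returns (clipped grads, pre-clip norm)."""
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))  # one norm over ALL params
    if total <= max_norm:
        return grads, total                                            # already within bounds
    scale = max_norm / total                                           # same factor for every tensor
    return [[g * scale for g in tensor] for tensor in grads], total

clipped, norm = clip_by_global_norm([[3.0], [0.0, 4.0]], max_norm=1.0)
print(norm, clipped)  # pre-clip norm 5.0; gradients rescaled to total norm 1.0
```

Because every tensor is multiplied by the same scalar, the concatenated gradient keeps its direction; clipping each tensor separately would not.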
PythonGradient clipping in the training loop
import torch.nn as nn

# Standard placement: after backward, before optimizer.step
loss.backward()
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# clip_grad_norm_ RETURNS the pre-clip norm -- log it to monitor training
# Healthy: grad_norm hovers around a stable value (e.g. 0.5-2.0)
# Warning: frequent clipping (grad_norm >> max_norm) signals instability
# Spike:   a sudden 100x grad_norm is a loss spike being clipped away
Train Note: Grad Norm Is a Vital Sign
The gradient norm is one of the most informative training diagnostics. Logged over time, it reveals the health of the run: a stable norm means smooth training; frequent large spikes mean the model is repeatedly hitting difficult regions; a sustained climb often precedes divergence.
Many large-model training runs plot grad-norm alongside loss. When a loss spike occurs, the grad-norm spike usually precedes it — clipping is what stands between a recoverable bump and a ruined run.
⚠️
Pitfall: Clipping Masks but Does Not Cure
Gradient clipping prevents a single bad update from destroying the model, but if you are clipping constantly, something is wrong upstream — the learning rate is too high, the warmup too short, or the data contains pathological batches. Clipping is a safety net, not a fix.
If grad-norm consistently exceeds your clip threshold by a large margin, lower the learning rate or lengthen warmup rather than just clipping harder. Persistent heavy clipping rescales each step's gradient by a different factor, which distorts Adam's moment estimates and hides the underlying problem.
15.5

Chapter 5 introduced the floating-point formats and the mechanics of mixed precision. Here we apply them to Transformer training. The goal: do the bulk of computation in 16-bit precision (bf16 or fp16) for speed and memory, while keeping a master copy of the weights and the optimizer states in 32-bit for accuracy.

bf16 vs fp16 for Training

| fp16 (half) | bf16 (bfloat16) |
|---|---|
| 10-bit mantissa, 5-bit exponent | 7-bit mantissa, 8-bit exponent |
| More precision, less range | Less precision, full fp32 range |
| Overflows easily (max ~65504) | Rarely overflows (max ~3.4e38) |
| NEEDS loss scaling | No loss scaling needed |
| Older GPUs (V100) | Modern GPUs (A100, H100, TPU) |
| Fiddly to stabilize | Drop-in, robust: the modern default |

For LLM training, bf16 has decisively won. Its full fp32 exponent range means gradients and activations almost never overflow, eliminating the loss-scaling machinery that fp16 requires. The cost, fewer mantissa bits and therefore less precision, matters little in practice, because training is more sensitive to range than to precision.
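The range-versus-precision trade can be checked with Python's standard `struct` module. In this sketch, struct's `'e'` format rounds to IEEE fp16 and raises `OverflowError` beyond its range (a GPU cast would produce inf instead). There is no struct format for bf16, so the helper emulates it by rounding an fp32 bit pattern to its top 16 bits; NaN handling is deliberately omitted.

```python
import struct

def to_fp16(x):
    """Round x to IEEE half precision (5-bit exponent, 10-bit mantissa)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def to_bf16(x):
    """Round x to bfloat16 by keeping the top 16 bits of its fp32 pattern.

    Round-to-nearest-even; NaN handling is omitted in this sketch.
    """
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)          # round half to even
    bits = ((bits >> 16) << 16) & 0xFFFFFFFF     # drop the low 16 mantissa bits
    return struct.unpack('<f', struct.pack('<I', bits))[0]

# Precision: fp16 resolves 1 + 2^-8 exactly; bf16 (7 mantissa bits) rounds it to 1.0
print(to_fp16(1 + 2**-8), to_bf16(1 + 2**-8))    # 1.00390625 1.0

# Range: 70000 exceeds fp16's max of 65504 but fits easily in bf16
print(to_bf16(70000.0))                          # 70144.0 (bf16 spacing near 7e4 is 512)
try:
    to_fp16(70000.0)
except OverflowError:
    print("fp16 overflow")
```

The bf16 result is off by about 0.2%, a precision loss training tolerates; the fp16 failure is an overflow, which it does not.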

PythonMixed-precision training loop (bf16)
import torch

model = model.cuda()
opt   = configure_optimizer(model, lr=3e-4)

for tokens in dataloader:
    tokens = tokens.cuda()
    opt.zero_grad()

    # Autocast: forward + loss in bf16, no scaler needed
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        loss = lm_loss(model, tokens)

    loss.backward()                        # gradients accumulate in fp32
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()

# Compare to fp16, which additionally requires:
#   scaler = torch.cuda.amp.GradScaler()
#   scaler.scale(loss).backward(); scaler.unscale_(opt)
#   scaler.step(opt); scaler.update()
# bf16's full range makes all of that unnecessary.
Train Note: Keep Master Weights and Optimizer States in fp32
Mixed precision does NOT mean everything is 16-bit. The master weights, the Adam moment estimates (m and v), and the loss accumulation stay in fp32. Only the forward/backward matrix math runs in 16-bit; whether gradients are stored in 16-bit or fp32 depends on the framework (with PyTorch autocast and fp32 parameters, as in the loop above, they are fp32). This preserves precision where it matters, in the slow accumulation of many small weight updates, while gaining speed where it does not.
This is why optimizer state dominates training memory: fp32 master weights plus two fp32 Adam moments cost 12 bytes per parameter, three times an fp32 copy of the model and six times a bf16 copy. For a 7B model that is 7B × 12 bytes ≈ 84GB of optimizer state alone, which is why distributed training (Chapter 18) shards it across devices.
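The 84GB figure is simple arithmetic, sketched here as a small helper. The byte counts assume fp32 (4 bytes) for the master weights and for each Adam moment; the helper's name and breakdown are illustrative, and real runs add gradients, activations, and temporary buffers on top.

```python
def adam_state_bytes(n_params, master_bytes=4, moment_bytes=4):
    """Optimizer-state memory for mixed-precision Adam(W): master copy + m + v."""
    per_param = master_bytes + 2 * moment_bytes
    return {"per_param_bytes": per_param, "total_gb": n_params * per_param / 1e9}

print(adam_state_bytes(7e9))  # {'per_param_bytes': 12, 'total_gb': 84.0}
```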
15.6

Two complementary techniques extend what fits on a given GPU. Gradient checkpointing (introduced in Chapter 11) trades compute for activation memory; gradient accumulation trades steps for a larger effective batch size. Together they let you train models and batch sizes that would otherwise not fit.

Gradient Accumulation

To use a large batch size that does not fit in memory, split it into smaller micro-batches, accumulate their gradients, and step the optimizer only after processing all of them. The effective batch size is micro_batch × accumulation_steps, achieving the statistics of a large batch with the memory of a small one.

textGradient accumulation
effective_batch = micro_batch × accumulation_steps

for i in 1 ... accumulation_steps:
    loss = forward(micro_batch_i) / accumulation_steps
    loss.backward()                  # gradients ADD up
optimizer.step(); optimizer.zero_grad()   # one step per effective batch
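The division by accumulation_steps in the pseudocode can be checked numerically. This sketch uses a toy scalar regression, mean((w·x - y)²), as a stand-in for the LM loss and assumes equal-size micro-batches: summing the scaled micro-batch gradients reproduces the full-batch mean gradient exactly, while forgetting the division inflates it by accumulation_steps.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one batch: the 'full-batch mean' gradient."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, accum_steps, divide=True):
    """Sum the gradients of accum_steps equal-size micro-batches, as loss.backward() would."""
    micro = len(xs) // accum_steps                   # assumes len(xs) % accum_steps == 0
    total = 0.0
    for k in range(accum_steps):
        part = slice(k * micro, (k + 1) * micro)
        g = mse_grad(w, xs[part], ys[part])
        total += (g / accum_steps) if divide else g  # the /accumulation_steps scaling
    return total

xs = [0.5 * k for k in range(16)]
ys = [2.0 * x + 1.0 for x in xs]
w = 0.3
full = mse_grad(w, xs, ys)
print(full, accumulated_grad(w, xs, ys, 4))                 # equal up to float rounding
print(accumulated_grad(w, xs, ys, 4, divide=False) / full)  # ~4.0: inflated by accum_steps
```

If micro-batches contain different numbers of target tokens (for example with padding), this equality only holds approximately.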
PythonGradient accumulation + checkpointing
import torch

accum_steps = 8                  # effective batch = 8x micro-batch
opt.zero_grad()

for i, micro_batch in enumerate(dataloader):
    with torch.autocast('cuda', dtype=torch.bfloat16):
        loss = lm_loss(model, micro_batch.cuda()) / accum_steps  # scale!
    loss.backward()                      # accumulate gradients

    if (i + 1) % accum_steps == 0:       # time to step
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step(); opt.zero_grad()

# Note the /accum_steps: gradients ADD, so each micro-batch's loss
# must be divided so the accumulated gradient equals the full-batch mean.

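# Gradient checkpointing is enabled once, BEFORE training. Hugging Face models
# expose model.gradient_checkpointing_enable(); a from-scratch GPT can instead wrap
# each block call in torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False).
# Cost: one extra forward pass (~33% more compute) for a large activation-memory saving.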
| Technique | Trades | When to use |
|---|---|---|
| Mixed precision (bf16) | Precision → speed + memory | Always, on modern GPUs |
| Gradient checkpointing | Compute → activation memory | When activations dominate memory |
| Gradient accumulation | Steps → effective batch size | When target batch won't fit |
| All three together | (all of the above) | Standard for large-model training |
15.7

Batch size and learning rate are not independent. A larger batch gives a less noisy gradient estimate, which permits — and often requires — a larger learning rate. Getting their relationship right is central to efficient training, and the relationship has been studied extensively.

Scaling Rules

Two heuristics relate learning rate to batch size. The linear scaling rule (Goyal et al., 2017) says lr should scale linearly with batch size. The square-root rule says lr should scale with √(batch size): gradient noise shrinks as 1/√(batch size), so this keeps the noise in each update roughly constant. Neither rule is exact; in practice, the square-root rule tends to hold over a wider range for adaptive optimizers such as Adam.

textBatch-size / learning-rate scaling rules
Linear rule:      lr ∝ batch_size
Square-root rule: lr ∝ √(batch_size)

# Both break down past a 'critical batch size' beyond which
# larger batches give diminishing returns (McCandlish et al., 2018).
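Applied to a concrete change of batch size, the two rules give quite different answers. A sketch with hypothetical numbers (an lr of 3e-4 tuned at 0.5M tokens per batch, moving to 2M tokens); both outputs are heuristics for a fixed model, not guarantees:

```python
import math

def scale_lr(base_lr, base_batch, new_batch, rule="sqrt"):
    """Rescale a tuned learning rate when the batch size changes (heuristic, not a law)."""
    ratio = new_batch / base_batch
    factor = ratio if rule == "linear" else math.sqrt(ratio)
    return base_lr * factor

print(f"{scale_lr(3e-4, 0.5e6, 2e6, 'linear'):.1e}")  # 1.2e-03
print(f"{scale_lr(3e-4, 0.5e6, 2e6, 'sqrt'):.1e}")    # 6.0e-04
```

Beyond the critical batch size neither rescaling recovers the lost efficiency, so treat the result as a starting point for a short sweep.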
Train Note: The Critical Batch Size
There is a regime where doubling the batch size lets you halve the number of steps for the same result — perfect scaling, ideal for parallelism. But beyond a 'critical batch size', larger batches stop helping: you do more compute per step without proportionally fewer steps.
The critical batch size grows as the loss falls, so it increases as training progresses and is larger for bigger models, which reach lower loss. This is why huge-batch training is feasible for large models, and why measuring it informs how much parallelism is worthwhile (Chapter 18).

Practical Defaults

| Setting | Small model (~100M) | Large model (~10B+) |
|---|---|---|
| Peak learning rate | 3e-4 to 6e-4 | 1e-4 to 3e-4 |
| Batch size (tokens) | ~0.5M | 2M to 16M |
| Warmup steps | ~1-2k | ~2-5k |
| Weight decay | 0.1 | 0.1 |
| Adam betas | (0.9, 0.95) | (0.9, 0.95) |
| Grad clip | 1.0 | 1.0 |

Note the inverse relationship: larger models use smaller peak learning rates. This reflects the empirical finding (and theoretical work on μP, Chapter 10) that the optimal learning rate decreases as width grows. Hyperparameter transfer techniques aim to tune on a small model and transfer the settings to a large one, saving enormous compute.

15.8

At scale, training runs span weeks and cost millions of dollars, and they do not always go smoothly. Loss spikes — sudden jumps in the loss — are common, and full divergence to NaN can destroy a run. Knowing how to diagnose and recover is essential practical knowledge.

The Loss-Spike Taxonomy

| Symptom | Likely cause | Response |
|---|---|---|
| Single spike, recovers | Hard batch / rare data | Usually fine; clipping handled it |
| Spike, doesn't recover | LR too high for current state | Restart from checkpoint, lower LR |
| Slow divergence | Accumulating instability | Lower LR, check data, more warmup |
| Immediate NaN | Bad init / fp16 overflow | Check init, switch to bf16 |
| Repeated spikes | Pathological data shards | Skip/clean the offending data |
| Loss plateau then NaN | Dead/saturated layers | Inspect activations, norms |

The Skip-and-Restart Recipe

The standard recovery for an unrecoverable spike, used in the training of models like PaLM and OPT: roll back to the last good checkpoint, skip the batches that caused the spike, and resume, sometimes with a temporarily lower learning rate. A spike is often triggered by particular batches interacting with the current model state, so skipping those batches usually lets training continue cleanly.
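The rollback itself is simple bookkeeping. A minimal sketch, in which the spike test (loss above `factor` times the recent mean), the window of skipped batches, and the function names are all hypothetical choices rather than values taken from the PaLM or OPT reports; it also assumes one batch per optimizer step, so a batch index equals a step number.

```python
from collections import deque

def is_spike(recent_losses, loss, factor=1.5):
    """True if loss exceeds `factor` times the mean of a full window of recent losses."""
    if len(recent_losses) < recent_losses.maxlen:
        return False                                   # not enough history yet
    return loss > factor * sum(recent_losses) / len(recent_losses)

def plan_restart(spike_step, last_ckpt_step, window=100):
    """Resume from the last checkpoint and skip the `window` batches up to the spike."""
    first_skipped = max(last_ckpt_step, spike_step - window + 1)
    return {"resume_from": last_ckpt_step,
            "skip_batches": range(first_skipped, spike_step + 1)}

history = deque([2.0] * 50, maxlen=50)   # recent per-step losses
print(is_spike(history, 2.1), is_spike(history, 4.0))  # False True
print(plan_restart(spike_step=1250, last_ckpt_step=1000))
```

In a real run the skip list is applied by advancing the data loader past those batches after reloading the checkpoint.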

Train Note: Checkpoint Often
Large training runs checkpoint the full state — weights, optimizer moments, data position, RNG state, and step count — frequently (e.g. every few hundred steps). When a spike ruins the run, you roll back to the last checkpoint rather than starting over. At million-dollar scale, the checkpoint cadence is a direct lever on how much compute a failure costs.
The OPT-175B training logbook (Zhang et al., 2022) is a famous public record of how messy real large-model training is: dozens of restarts, hardware failures, loss spikes, and manual interventions over months. It is essential reading for anyone who imagines large-model training is a clean, automated process.
PythonA robust checkpoint
import torch

def save_checkpoint(path, model, opt, scheduler, step, rng_state):
    torch.save({
        'model':      model.state_dict(),
        'optimizer':  opt.state_dict(),      # Adam moments!
        'scheduler':  scheduler.state_dict() if scheduler is not None else None,
        'step':       step,                  # resume position
        'rng':        rng_state,            # reproducibility
    }, path)

# A checkpoint missing the optimizer state would restart Adam's moments
# from zero -- effectively re-triggering the warmup fragility of Section 15.2.
# Always save the FULL training state, not just the weights.
15.9

A training run is only as good as your ability to see inside it. Experienced practitioners watch a small set of signals that, together, reveal the health of the optimization. Logging these from the start turns debugging from guesswork into diagnosis.

| Signal | Healthy pattern | Warning sign |
|---|---|---|
| Training loss | Smooth decrease | Spikes, plateaus, increases |
| Validation loss | Tracks train loss down | Rises while train loss falls (overfitting) |
| Gradient norm | Stable, modest | Spikes or sustained climb |
| Learning rate | Follows the schedule | Deviates from the schedule (sanity-check it) |
| Update / weight ratio | ~1e-3 per step | Too high: unstable; too low: stuck |
| Activation/grad stats | Stable distributions | Exploding or vanishing |
| Tokens/sec throughput | Stable | Drops signal hardware/IO issues |
Train Note: The Update-to-Weight Ratio
A subtle but powerful diagnostic (popularized by Karpathy) is the ratio of update magnitude to weight magnitude: ‖lr·update‖ / ‖weight‖, per layer. A healthy value is around 1e-3 — each step changes weights by about 0.1%.
If the ratio is much higher, the learning rate is too aggressive for that layer; much lower, and the layer is barely learning. Watching it per-layer can reveal that, say, the embedding layer wants a different learning rate than the attention layers — motivating per-group learning rates.
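The ratio only needs a snapshot of the weights before the optimizer step. A framework-free sketch on one hypothetical layer stored as a flat list; in PyTorch the same quantity per parameter is `(p - before).norm() / before.norm()`, with `before = p.detach().clone()` taken just before `opt.step()`.

```python
import math

def update_to_weight_ratio(w_before, w_after):
    """||w_after - w_before|| / ||w_before|| for one layer's weights (flat lists)."""
    update = math.sqrt(sum((a - b) ** 2 for a, b in zip(w_after, w_before)))
    weight = math.sqrt(sum(b * b for b in w_before))
    return update / weight

w = [0.02 * (-1) ** k for k in range(1000)]        # hypothetical layer weights
w_new = [x - 1e-3 * x for x in w]                  # a step that shrinks each weight by 0.1%
print(f"{update_to_weight_ratio(w, w_new):.1e}")   # 1.0e-03
```

Logging this per layer every few hundred steps costs one extra copy of the weights at logging time.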

Evaluation During Training

Beyond loss, periodically evaluate on held-out validation data and on downstream benchmarks. Validation perplexity confirms generalization; downstream task metrics (even simple ones) confirm that decreasing loss translates into useful capability. Chapter 21 covers evaluation during pretraining in depth.

15.10

We now assemble everything — the objective, schedule, optimizer, clipping, mixed precision, accumulation, and checkpointing — into a single training loop. This is the practical recipe for training a GPT-style model, and it is the loop that, scaled up, trains real LLMs.

PythonCode Lab: the complete training loop
import torch, math
import torch.nn as nn, torch.nn.functional as F

def train(model, dataloader, *, total_steps, peak_lr=3e-4,
          warmup=2000, accum=8, clip=1.0, wd=0.1):
    model = model.cuda()
    if hasattr(model, 'gradient_checkpointing_enable'):  # e.g. Hugging Face models;
        model.gradient_checkpointing_enable()             # a from-scratch GPT checkpoints in forward()
    opt   = configure_optimizer(model, lr=peak_lr, wd=wd)

    step, window_loss = 0, 0.0        # window_loss: mean loss of one accumulation window
    opt.zero_grad()
    for i, batch in enumerate(dataloader):
        # 1. Forward + loss in bf16, scaled for accumulation
        with torch.autocast('cuda', dtype=torch.bfloat16):
            loss = lm_loss(model, batch.cuda()) / accum
        # 2. Backward (gradients accumulate)
        loss.backward()
        window_loss += loss.item()    # already divided by accum, so this sums to the mean

        if (i + 1) % accum == 0:
            # 3. Set learning rate from schedule
            lr = lr_schedule(step, peak_lr, warmup, total_steps)
            for g in opt.param_groups: g['lr'] = lr
            # 4. Clip, step, reset
            gnorm = nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step(); opt.zero_grad()
            step += 1
            avg_loss, window_loss = window_loss, 0.0

            # 5. Monitor (mean over the whole effective batch, not the last micro-batch)
            if step % 100 == 0:
                print(f"step {step}: loss={avg_loss:.3f} "
                      f"lr={lr:.2e} gnorm={gnorm:.2f} ppl={math.exp(avg_loss):.1f}")
            # 6. Checkpoint
            if step % 1000 == 0:
                save_checkpoint(f'ckpt_{step}.pt', model, opt, None, step, torch.get_rng_state())
            if step >= total_steps: break

# This loop -- with the architecture of Ch.13 and data of Ch.14 -- is
# a complete, working LLM pretraining pipeline. Scale it up and add
# distributed training (Ch.18) and you have the real thing.
You Can Now Train a Transformer
This loop brings together everything from the book so far: cross-entropy (Ch.4), numerical stability (Ch.5), Adam (Ch.2) and its AdamW variant, the architecture (Ch.13), tokenization (Ch.14), and the training techniques of this chapter. It is a real, working pretraining pipeline with the same basic structure used to train GPT and LLaMA; what changes at their scale is mainly the infrastructure around it.
What remains for frontier-scale training is distribution across thousands of GPUs (Part IV) and large-scale data curation (Chapter 17). But the core training logic, the loop above, you now understand end to end.
15.11

Training Recipe Quick-Reference

| Ingredient | Standard choice | Why |
|---|---|---|
| Objective | Next-token cross-entropy | Self-supervised; needs only raw text |
| Optimizer | AdamW (0.9, 0.95) | Adaptive + decoupled decay |
| LR schedule | Warmup + cosine decay | Stability, then fine convergence |
| Grad clipping | Global norm 1.0 | Bounds updates during loss spikes |
| Precision | bf16 mixed | Faster, less memory, no loss scaling |
| Memory | Checkpointing + accumulation | Fit bigger models / batches |
| Weight decay | 0.1 on matrices only | Regularize without harming norms |

Exercises

Exercises 1–10 are pen-and-paper or derivations; 11–20 require code.

Exercise 1: Pen & Paper
Write the causal LM loss and explain the one-position shift between inputs and targets. Why does one forward pass yield T training signals?
Exercise 2: Pen & Paper
A model reaches a loss of 1.6 nats. Compute the perplexity. Interpret what perplexity 5 means about the model's next-token uncertainty.
Exercise 3: Pen & Paper
Explain why Transformers need warmup but plain SGD on a CNN often does not. Connect your answer to Adam's second-moment estimate.
Exercise 4: Derive
Write the warmup + cosine decay schedule as a piecewise function. Verify it is continuous at the warmup/decay boundary.
Exercise 5: Pen & Paper
Explain the difference between Adam+L2 and AdamW. Show why, in Adam+L2, parameters with large gradients receive less effective weight decay.
Exercise 6: Pen & Paper
Why is weight decay applied to weight matrices but not to LayerNorm gains or biases? What would decaying a LayerNorm gain toward zero do?
Exercise 7: Pen & Paper
Global-norm clipping scales all gradients by max_norm/g_norm when g_norm exceeds the threshold. Prove this preserves the gradient direction.
Exercise 8: Pen & Paper
Why has bf16 replaced fp16 for LLM training? Explain in terms of exponent range and the need (or not) for loss scaling.
Exercise 9: Pen & Paper
Derive why gradient accumulation requires dividing each micro-batch loss by accumulation_steps. What goes wrong if you forget the division?
Exercise 10: Pen & Paper
Explain the critical batch size. Why does training large models tolerate (and benefit from) very large batches?
Exercise 11: Code
Implement the warmup + cosine decay schedule. Plot the learning rate over 100k steps for warmup ∈ {500, 2000, 5000}. Describe the effect.
Exercise 12: Code
Implement configure_optimizer with decay/no-decay parameter groups. Print which parameters land in each group for a small Transformer.
Exercise 13: Code
Implement global-norm gradient clipping from scratch (without nn.utils). Verify it matches PyTorch's clip_grad_norm_ on random gradients.
Exercise 14: Code Lab
Train a small GPT on a text corpus WITHOUT warmup and WITH warmup. Plot both loss curves. Demonstrate that no-warmup diverges or trains worse.
Exercise 15: Code
Compare fp32, fp16 (with GradScaler), and bf16 training on the same model. Measure peak memory, tokens/sec, and final loss for each.
Exercise 16: Code
Implement gradient accumulation. Verify that micro_batch=4 with accum=4 produces nearly identical gradients to a single batch of 16.
Exercise 17: Code Lab
Measure the effect of gradient checkpointing: train a deep model with and without it, reporting peak memory and wall-clock time. Confirm the ~33% compute overhead.
Exercise 18: Code
Build the monitoring dashboard: log loss, perplexity, grad-norm, learning rate, and the update-to-weight ratio. Plot all five over a training run.
Exercise 19: Code
Implement save/load checkpoint including optimizer and RNG state. Verify that resuming from a checkpoint produces bit-identical continuation to an uninterrupted run (enable deterministic algorithms, since many GPU kernels are nondeterministic by default).
Exercise 20: Code (Challenge)
Assemble the complete training loop from Section 15.10 and pretrain a small GPT (e.g. 10M params) on a corpus like TinyStories or Shakespeare. Achieve a coherent sample. Log all monitoring signals, demonstrate a checkpoint-resume, and write a short post-mortem of any instabilities you encountered.

Further reading: “Decoupled Weight Decay Regularization” (Loshchilov & Hutter, 2017) for AdamW. “Attention Is All You Need” (Vaswani et al., 2017) for the original warmup schedule. “An Empirical Model of Large-Batch Training” (McCandlish et al., 2018) for the critical batch size. “OPT: Open Pre-trained Transformer Language Models” (Zhang et al., 2022) and its training logbook for a candid account of real large-model training. Andrej Karpathy's nanoGPT for a complete, readable reference implementation of this chapter's loop.


Next → Chapter 16: Scaling Laws

You can now train a Transformer — but how big should it be, how much data does it need, and how much compute will it take? Chapter 16 answers these with scaling laws: the remarkably precise power-law relationships between model size, dataset size, compute, and loss. We will derive the Chinchilla compute-optimal recipe, understand the 6N rule for training FLOPs, and see how scaling laws turn the art of choosing model size into a predictable science — the quantitative foundation that guides every frontier training run. This chapter closes Part III.
