Building a GPT from Scratch
With the architecture built (Chapter 13) and text tokenized (Chapter 14), we can finally train. The pretraining objective for a decoder-only Transformer is strikingly simple: given a sequence of tokens, predict the next one at every position. This single objective, scaled up, is the foundation on which the capabilities of modern LLMs are built.
The Loss
For a sequence of T+1 tokens x₁ ... x_{T+1}:

    L = -(1/T) Σₜ₌₁ᵀ log Pθ(xₜ₊₁ | x₁ ... xₜ)
    # Cross-entropy (Chapter 4) at every position,
    # averaged over the sequence. That is the entire objective.

Because of the causal mask (Chapter 12), every position predicts its successor using only earlier tokens, and all T predictions are computed in a single forward pass. This is the efficiency that makes the decoder-only design so attractive: one forward pass yields T training signals, one per predicted token.
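To make the objective concrete, here is a minimal pure-Python sketch that scores a toy sequence by hand. The three-token vocabulary, the token ids, and the probability table `probs` are all invented for illustration; a real model would produce these distributions from its logits.

```python
import math

# Hypothetical toy setup: vocabulary of 3 token ids, a sequence of 4 tokens,
# hence 3 next-token predictions.
tokens = [0, 2, 1, 2]

# probs[t][v] stands in for P(next token = v | tokens[0..t]); values are invented.
probs = [
    [0.10, 0.20, 0.70],  # after seeing token 0
    [0.20, 0.50, 0.30],  # after seeing tokens 0, 2
    [0.25, 0.25, 0.50],  # after seeing tokens 0, 2, 1
]

def next_token_loss(tokens, probs):
    """Mean negative log-likelihood of each token given its prefix."""
    targets = tokens[1:]  # the target at each position is the following token
    nll = [-math.log(p[target]) for p, target in zip(probs, targets)]
    return sum(nll) / len(nll)

loss = next_token_loss(tokens, probs)
perplexity = math.exp(loss)
print(f"loss = {loss:.4f} nats, perplexity = {perplexity:.3f}")
```

Each position contributes one term to the average, which is exactly the "one training signal per predicted token" described above.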
Teacher Forcing and the Shift
During training we feed the true tokens and ask the model to predict each next token; this is teacher forcing. Implementation reduces to a one-position shift: from a window of T+1 tokens, the inputs are x₁...x_T and the targets are x₂...x_{T+1}. The model's logits at position t are scored against the actual token at position t+1.

    import torch
    import torch.nn.functional as F

    def lm_loss(model, tokens):          # tokens: (B, T+1)
        """Causal LM loss via the one-position shift."""
        inputs = tokens[:, :-1]          # x_1 ... x_T
        targets = tokens[:, 1:]          # x_2 ... x_{T+1}
        logits = model(inputs)           # (B, T, V)
        # Flatten to (B*T, V) and (B*T,) for cross-entropy
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
        )
        return loss

    # Perplexity = exp(loss) is the standard human-readable metric (Chapter 4).
    # A loss of 2.0 nats ≈ perplexity 7.4: the model is effectively choosing
    # among ~7 equally-likely next tokens at each step.

Transformers are unusually sensitive to the learning rate early in training. Starting at the full learning rate often causes immediate divergence. The fix, learning-rate warmup, ramps the rate up linearly from near zero over the first few thousand steps, then decays it. This single trick is essential for stable Transformer training.
Why Early Training Is Fragile
At initialization the model's predictions are random and the gradients are large and noisy. Adam's second-moment estimate (the running average of squared gradients) has not yet stabilized, so early updates can be wildly miscalibrated. A large learning rate amplifies these bad early updates, pushing the model into a region from which it cannot recover. Warmup gives Adam's statistics time to settle before applying full-strength updates.
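One way to see the fragility is to work through Adam's very first update by hand. The sketch below is a scalar re-implementation of the standard Adam step (not any library's code) with illustrative hyperparameters. It shows that at step 1 the bias-corrected update has magnitude close to lr whether the gradient is tiny or huge, so a single noisy gradient is acted on at full strength.

```python
import math

def adam_step(g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95, eps=1e-8):
    """One scalar Adam step; returns (update, new_m, new_v)."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)   # bias correction
    v_hat = v / (1 - beta2 ** t)
    update = lr * m_hat / (math.sqrt(v_hat) + eps)
    return update, m, v

# Gradients spanning six orders of magnitude, each taken as the first step.
for g in (1e-3, 1.0, 1e3):
    update, _, _ = adam_step(g, m=0.0, v=0.0, t=1)
    print(f"gradient {g:>8.0e} -> first update {update:.3e}")
```

With no history to average over, the ratio m̂/√v̂ is essentially the sign of the one gradient seen so far. Warmup shrinks lr during exactly this period.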
Warmup + Cosine Decay
The standard modern schedule combines linear warmup with cosine decay: ramp up linearly to a peak learning rate over the warmup steps, then decay smoothly following a cosine curve down to a small final rate. This schedule trains the model fast once warmed up, then anneals to a low rate for fine convergence.

    if step < warmup:
        lr = peak_lr · (step / warmup)                    # linear ramp up
    else:
        progress = (step - warmup) / (total - warmup)
        lr = min_lr + 0.5 · (peak_lr - min_lr) · (1 + cos(π · progress))

    import math

    def lr_schedule(step, peak_lr, warmup, total, min_lr=0.0):
        """Linear warmup then cosine decay to min_lr."""
        if step < warmup:                                   # ramp up
            return peak_lr * step / warmup
        if step > total:                                    # past the end
            return min_lr
        progress = (step - warmup) / (total - warmup)       # 0 -> 1
        cosine = 0.5 * (1 + math.cos(math.pi * progress))   # 1 -> 0
        return min_lr + (peak_lr - min_lr) * cosine

    # Typical large-model schedule
    peak, warmup, total = 3e-4, 2000, 100000
    for s in [0, 1000, 2000, 50000, 100000]:
        print(f"step {s:>6}: lr = {lr_schedule(s, peak, warmup, total):.2e}")
    # step      0: lr = 0.00e+00   (start from zero)
    # step   1000: lr = 1.50e-04   (halfway through warmup)
    # step   2000: lr = 3.00e-04   (peak, warmup complete)
    # step  50000: lr = 1.55e-04   (cosine decay)
    # step 100000: lr = 0.00e+00   (annealed to zero)

Chapter 2 derived Adam; here we focus on AdamW, the variant used to train essentially every modern LLM. AdamW's key correction, decoupling weight decay from the gradient update, makes regularization behave correctly with adaptive learning rates.
The Adam Update, Recalled

    m ← β₁ m + (1-β₁) g                     # 1st moment (momentum)
    v ← β₂ v + (1-β₂) g²                    # 2nd moment (variance)
    m̂ = m / (1-β₁ᵗ),  v̂ = v / (1-β₂ᵗ)       # bias correction
    θ ← θ - lr · m̂ / (√v̂ + ε)               # adaptive update

The W in AdamW: Decoupled Weight Decay
Plain Adam with L2 regularization adds λθ to the gradient — but then this term gets divided by √v̂ along with everything else, so parameters with large gradients get less regularization. AdamW (Loshchilov & Hutter, 2017) fixes this by applying weight decay directly to the parameters, separately from the adaptive gradient step.

    Adam + L2:  g ← g + λθ; then adaptive step     (decay gets scaled by 1/√v̂)
    AdamW:      θ ← θ - lr(m̂/(√v̂+ε) + λθ)          # decay applied directly
    # weight decay is NOT scaled by the adaptive denominator

| Hyperparameter | Typical value | Role |
|---|---|---|
| β₁ | 0.9 | Momentum: smooths the gradient direction |
| β₂ | 0.95 (LLMs) / 0.999 | Variance: how fast the scale adapts |
| ε | 1e-8 | Numerical floor in the denominator |
| weight decay λ | 0.1 | Regularization strength (decoupled) |
| peak lr | 1e-4 to 6e-4 | Scaled down as models grow |
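The effect of decoupling is easiest to see numerically. The sketch below is a simplified scalar illustration, not an optimizer implementation: it assumes the second-moment estimate has settled at the squared total gradient, so √v̂ ≈ |g|, and asks how much of one update is attributable to weight decay for a parameter with a large gradient versus one with a small gradient. All values are illustrative.

```python
lr, lam, eps = 3e-4, 0.1, 1e-8
theta = 1.0  # same weight value for both hypothetical parameters

def decay_with_l2(g):
    """Shrinkage from the lam*theta term when it is folded into the gradient (Adam + L2)."""
    g_total = g + lam * theta
    sqrt_v_hat = abs(g_total)  # assumption: v-hat has converged to g_total squared
    return lr * lam * theta / (sqrt_v_hat + eps)

def decay_with_adamw():
    """Shrinkage from decoupled decay: independent of the gradient."""
    return lr * lam * theta

for g in (10.0, 0.01):
    print(f"g={g:<5}  Adam+L2 decay={decay_with_l2(g):.2e}  "
          f"AdamW decay={decay_with_adamw():.2e}")
```

Under this assumption the large-gradient parameter receives far less regularization with Adam + L2 than the small-gradient one, while AdamW shrinks both by the same lr · λ · θ.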

    import torch

    def configure_optimizer(model, lr=3e-4, wd=0.1):
        decay, no_decay = [], []
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if p.ndim >= 2:              # matrices: decay
                decay.append(p)
            else:                        # biases, norms: no decay
                no_decay.append(p)
        groups = [
            {'params': decay, 'weight_decay': wd},
            {'params': no_decay, 'weight_decay': 0.0},
        ]
        return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95), eps=1e-8)

    # betas=(0.9, 0.95) is the LLM default -- lower β₂ than the 0.999 used
    # for smaller models, giving faster adaptation to changing gradient scales.

Even with warmup and a good optimizer, training can produce occasional huge gradients, whether from a difficult batch, a rare token, or accumulated instability. A single enormous update can undo thousands of steps of progress or send the loss to NaN. Gradient clipping caps the gradient's magnitude, keeping every update bounded.
Clipping by Global Norm
The standard method clips by global norm: compute the L2 norm of all gradients concatenated together, and if it exceeds a threshold, scale every gradient down by the same factor so the total norm equals the threshold. This preserves the gradient's direction while bounding its size.
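As a worked example of the rule just described, the following pure-Python sketch clips a small set of gradients by their joint norm. It is a simplified stand-in for framework routines (which typically also add a tiny epsilon to the denominator), and the gradient values are made up.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale a list of gradient vectors so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, pre_clip_norm).
    """
    total_norm = math.sqrt(sum(x * x for g in grads for x in g))
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    clipped = [[x * scale for x in g] for g in grads]
    return clipped, total_norm

# Two hypothetical parameter tensors, flattened to lists.
grads = [[3.0, 4.0], [12.0]]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)      # 13.0
print(clipped)   # every entry scaled by the same factor 1/13
```

Because every entry is multiplied by the same factor, the ratios between components (and hence the update direction) are unchanged; gradients already under the threshold pass through untouched.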

    g_norm = √(Σ_all_params ‖gᵢ‖²)                   # total gradient norm
    if g_norm > max_norm:
        gᵢ ← gᵢ · (max_norm / g_norm)   for all i    # scale down, keep direction

    import torch.nn as nn

    # Standard placement: after backward, before optimizer.step
    loss.backward()
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    # clip_grad_norm_ RETURNS the pre-clip norm -- log it to monitor training
    # Healthy: grad_norm hovers around a stable value (e.g. 0.5-2.0)
    # Warning: frequent clipping (grad_norm >> max_norm) signals instability
    # Spike:   a sudden 100x grad_norm is a loss spike being clipped away

Chapter 5 introduced the floating-point formats and the mechanics of mixed precision. Here we apply them to Transformer training. The goal: do the bulk of computation in 16-bit precision (bf16 or fp16) for speed and memory, while keeping a master copy of the weights and the optimizer states in 32-bit for accuracy.
bf16 vs fp16 for Training
| fp16 (half) | bf16 (bfloat16) |
|---|---|
| 10-bit mantissa, 5-bit exponent | 7-bit mantissa, 8-bit exponent |
| More precision, less range | Less precision, full fp32 range |
| Overflows easily (max ~65504) | Rarely overflows (max ~3.4e38) |
| NEEDS loss scaling | No loss scaling needed |
| Older GPUs (V100) | Modern accelerators (A100, H100, TPU) |
| Fiddly to stabilize | Drop-in, robust — the modern default |
For LLM training, bf16 has decisively won. Its full fp32 exponent range means gradients and activations almost never overflow, eliminating the loss-scaling machinery that fp16 requires. The cost is fewer mantissa bits and therefore less precision, which in practice matters little for training: the fp32 master weights and optimizer states absorb most of the rounding error, and range is more important than precision.
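The range-versus-precision trade-off can be checked with nothing but the standard library. The sketch below round-trips two numbers through IEEE half precision (Python's `struct` format `'e'`) and through an emulated bfloat16. The emulation is an assumption of this example: it simply keeps the top 16 bits of the float32 encoding (truncation), whereas real hardware typically rounds to nearest.

```python
import struct

def to_fp16(x):
    """Round-trip x through IEEE half precision; overflow is reported as inf."""
    try:
        return struct.unpack('<e', struct.pack('<e', x))[0]
    except OverflowError:
        return float('inf')

def to_bf16(x):
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding."""
    high_bytes = struct.pack('>f', x)[:2]   # sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack('>f', high_bytes + b'\x00\x00')[0]

for x in (70000.0, 1.001):
    print(f"{x}: fp16 -> {to_fp16(x)}, bf16 -> {to_bf16(x)}")
```

The value 70000 exceeds fp16's maximum and overflows, yet survives in bf16 with a coarse mantissa. Conversely 1.001 keeps part of its fractional detail in fp16 and collapses to 1.0 in bf16.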

    import torch

    model = model.cuda()
    opt = configure_optimizer(model, lr=3e-4)
    for tokens in dataloader:
        tokens = tokens.cuda()
        opt.zero_grad()
        # Autocast: forward + loss in bf16, no scaler needed
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            loss = lm_loss(model, tokens)
        loss.backward()                  # gradients accumulate in fp32
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

    # Compare to fp16, which additionally requires:
    #   scaler = torch.cuda.amp.GradScaler()
    #   scaler.scale(loss).backward(); scaler.unscale_(opt)
    #   scaler.step(opt); scaler.update()
    # bf16's full range makes all of that unnecessary.

Two complementary techniques extend what fits on a given GPU. Gradient checkpointing (introduced in Chapter 11) trades compute for activation memory; gradient accumulation trades steps for a larger effective batch size. Together they let you train models and batch sizes that would otherwise not fit.
Gradient Accumulation
To use a large batch size that does not fit in memory, split it into smaller micro-batches, accumulate their gradients, and step the optimizer only after processing all of them. The effective batch size is micro_batch × accumulation_steps, achieving the statistics of a large batch with the memory of a small one.

    effective_batch = micro_batch × accumulation_steps
    for i in 1 ... accumulation_steps:
        loss = forward(micro_batch_i) / accumulation_steps
        loss.backward()                        # gradients ADD up
    optimizer.step(); optimizer.zero_grad()    # one step per effective batch

    import torch

    accum_steps = 8                            # effective batch = 8x micro-batch
    opt.zero_grad()
    for i, micro_batch in enumerate(dataloader):
        with torch.autocast('cuda', dtype=torch.bfloat16):
            loss = lm_loss(model, micro_batch.cuda()) / accum_steps   # scale!
        loss.backward()                        # accumulate gradients
        if (i + 1) % accum_steps == 0:         # time to step
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            opt.zero_grad()

    # Note the /accum_steps: gradients ADD, so each micro-batch's loss
    # must be divided so the accumulated gradient equals the full-batch mean.
    # Enable checkpointing on the model to cut activation memory sharply.
    # This method exists on Hugging Face-style models; a plain nn.Module
    # would instead wrap its blocks with torch.utils.checkpoint.
    model.gradient_checkpointing_enable()      # ~33% more compute, big memory win

| Technique | Trades | When to use |
|---|---|---|
| Mixed precision (bf16) | Precision → speed + memory | Always, on modern GPUs |
| Gradient checkpointing | Compute → activation memory | When activations dominate memory |
| Gradient accumulation | Steps → effective batch size | When target batch won't fit |
| All three together | — | Standard for large-model training |
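The division by the number of accumulation steps can be verified by hand on a one-parameter model. The sketch below uses a scalar least-squares loss with an analytic gradient; the data, the weight `w`, and the split into four micro-batches of two are all invented for illustration, and the equality relies on the micro-batches being the same size.

```python
import math

def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given examples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Hypothetical toy data: 8 examples, split into 4 micro-batches of 2.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 3.9, 6.1, 8.2, 9.8, 12.1, 14.0, 15.9]
w = 0.5
micro, accum_steps = 2, 4

full_batch_grad = grad_mse(w, xs, ys)

accumulated = 0.0
for i in range(accum_steps):
    lo, hi = i * micro, (i + 1) * micro
    # Dividing each micro-batch loss by accum_steps divides its gradient too.
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accum_steps

print(full_batch_grad, accumulated)   # equal up to floating-point rounding
```

Dropping the division would make the accumulated gradient `accum_steps` times too large, which is equivalent to silently multiplying the learning rate for SGD-style updates.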
Batch size and learning rate are not independent. A larger batch gives a less noisy gradient estimate, which permits — and often requires — a larger learning rate. Getting their relationship right is central to efficient training, and the relationship has been studied extensively.
Scaling Rules
Two heuristics relate learning rate to batch size. The linear scaling rule (Goyal et al., 2017) says lr should scale linearly with batch size. The square-root rule says lr should scale with √(batch size), reflecting the fact that the standard deviation of the mini-batch gradient shrinks as 1/√(batch size). In practice, the square-root rule holds over a wider range for adaptive optimizers such as Adam.
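The two rules give noticeably different answers once the batch grows by more than a small factor. A minimal sketch follows; the helper `scale_lr` and the starting point (3e-4 at batch size 256) are hypothetical choices for illustration, not recommended values.

```python
import math

def scale_lr(base_lr, base_batch, new_batch, rule="sqrt"):
    """Rescale a learning rate tuned at base_batch for use at new_batch."""
    ratio = new_batch / base_batch
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule}")

base_lr, base_batch = 3e-4, 256   # illustrative starting point
for new_batch in (512, 1024, 4096):
    linear = scale_lr(base_lr, base_batch, new_batch, "linear")
    sqrt = scale_lr(base_lr, base_batch, new_batch, "sqrt")
    print(f"batch {new_batch:>5}: linear {linear:.2e}   sqrt {sqrt:.2e}")
```

Either result is a starting point for a sweep rather than a guarantee, since both rules stop applying beyond the critical batch size.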

    Linear rule:       lr ∝ batch_size
    Square-root rule:  lr ∝ √(batch_size)
    # Both break down past a 'critical batch size' beyond which
    # larger batches give diminishing returns (McCandlish et al., 2018).

Practical Defaults
| Setting | Small model (~100M) | Large model (~10B+) |
|---|---|---|
| Peak learning rate | 3e-4 to 6e-4 | 1e-4 to 3e-4 |
| Batch size (tokens) | ~0.5M | 2M to 16M |
| Warmup steps | ~1-2k | ~2-5k |
| Weight decay | 0.1 | 0.1 |
| Adam betas | (0.9, 0.95) | (0.9, 0.95) |
| Grad clip | 1.0 | 1.0 |
Note the inverse relationship: larger models use smaller peak learning rates. This reflects the empirical finding (and theoretical work on μP, Chapter 10) that the optimal learning rate decreases as width grows. Hyperparameter transfer techniques aim to tune on a small model and transfer the settings to a large one, saving enormous compute.
At scale, training runs span weeks and cost millions of dollars, and they do not always go smoothly. Loss spikes — sudden jumps in the loss — are common, and full divergence to NaN can destroy a run. Knowing how to diagnose and recover is essential practical knowledge.
The Loss-Spike Taxonomy
| Symptom | Likely cause | Response |
|---|---|---|
| Single spike, recovers | Hard batch / rare data | Usually fine; clipping handled it |
| Spike, doesn't recover | LR too high for current state | Restart from checkpoint, lower LR |
| Slow divergence | Accumulating instability | Lower LR, check data, more warmup |
| Immediate NaN | Bad init / fp16 overflow | Check init, switch to bf16 |
| Repeated spikes | Pathological data shards | Skip/clean the offending data |
| Loss plateau then NaN | Dead/saturated layers | Inspect activations, norms |
The Skip-and-Restart Recipe
The standard recovery for an unrecoverable spike, used in the training of models like PaLM and OPT: roll back to the last good checkpoint, skip the batches that caused the spike, and resume — sometimes with a temporarily lower learning rate. Because the spike is often triggered by specific pathological data, skipping it usually lets training continue cleanly.
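The bookkeeping behind this recipe can be sketched in a few lines. Everything here is hypothetical: the helper names, the checkpoint interval, and the size of the skip window are assumptions of this example rather than values taken from any published run. It also presumes the data order is deterministic, so that a batch can be identified by its step index.

```python
def plan_restart(spike_step, ckpt_every, skip_window):
    """Pick the checkpoint to roll back to and the batch range to skip.

    All three arguments are hypothetical knobs for this sketch:
      spike_step  - optimizer step at which the spike was observed
      ckpt_every  - checkpoint interval in steps
      skip_window - how many steps' worth of batches before the spike to drop
    """
    skip_from = max(0, spike_step - skip_window)
    # Last checkpoint at or before the start of the skipped range.
    resume_step = (skip_from // ckpt_every) * ckpt_every
    return resume_step, range(skip_from, spike_step + 1)

def should_train_on(step, skipped):
    """During the replay, drop batches that fall inside the skipped range."""
    return step not in skipped

resume_step, skipped = plan_restart(spike_step=12_340, ckpt_every=1000, skip_window=200)
print(resume_step, skipped.start, skipped.stop - 1)   # 12000 12140 12340
```

In this example training would reload the step-12000 checkpoint, replay normally up to step 12139, skip the batches for steps 12140 through 12340, and then continue as before.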

    import torch

    def save_checkpoint(path, model, opt, scheduler, step, rng_state):
        torch.save({
            'model': model.state_dict(),
            'optimizer': opt.state_dict(),       # Adam moments!
            # scheduler may be None when the LR is set by hand each step
            'scheduler': scheduler.state_dict() if scheduler is not None else None,
            'step': step,                        # resume position
            'rng': rng_state,                    # reproducibility
        }, path)

    # A checkpoint missing the optimizer state would restart Adam's moments
    # from zero -- effectively re-triggering the warmup fragility of Section 15.2.
    # Always save the FULL training state, not just the weights.

A training run is only as good as your ability to see inside it. Experienced practitioners watch a small set of signals that, together, reveal the health of the optimization. Logging these from the start turns debugging from guesswork into diagnosis.
| Signal | Healthy pattern | Warning sign |
|---|---|---|
| Training loss | Smooth decrease | Spikes, plateaus, increases |
| Validation loss | Tracks train loss down | Diverges up (overfitting) |
| Gradient norm | Stable, modest | Spikes or sustained climb |
| Learning rate | Follows the schedule | (sanity check the schedule) |
| Update / weight ratio | ~1e-3 per step | Too high: unstable; too low: stuck |
| Activation/grad stats | Stable distributions | Exploding or vanishing |
| Tokens/sec throughput | Stable | Drops signal hardware/IO issues |
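Most of these signals come straight from the training loop, but the ratio between update size and weight size has to be computed. A minimal sketch, taking the ratio as ‖Δθ‖ / ‖θ‖ for one flattened parameter tensor; the weight values are invented so that the step changes each weight by 0.1%.

```python
import math

def l2_norm(values):
    """Euclidean norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))

def update_to_weight_ratio(before, after):
    """||after - before|| / ||before|| for one flattened parameter tensor."""
    delta = [a - b for a, b in zip(after, before)]
    return l2_norm(delta) / l2_norm(before)

# Hypothetical weights before and after one optimizer step.
before = [0.30, -0.40, 1.20]
after = [0.3003, -0.4004, 1.2012]
ratio = update_to_weight_ratio(before, after)
print(f"update/weight ratio = {ratio:.1e}")
```

Logged per layer, this number makes it easy to spot a layer whose updates are far larger or smaller than the rest.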
Evaluation During Training
Beyond loss, periodically evaluate on held-out validation data and on downstream benchmarks. Validation perplexity confirms generalization; downstream task metrics (even simple ones) confirm that decreasing loss translates into useful capability. Chapter 21 covers evaluation during pretraining in depth.
We now assemble everything — the objective, schedule, optimizer, clipping, mixed precision, accumulation, and checkpointing — into a single training loop. This is the practical recipe for training a GPT-style model, and it is the loop that, scaled up, trains real LLMs.

    import math
    import torch
    import torch.nn as nn, torch.nn.functional as F

    def train(model, dataloader, *, total_steps, peak_lr=3e-4,
              warmup=2000, accum=8, clip=1.0, wd=0.1):
        model = model.cuda()
        # Assumes a model exposing Hugging Face-style gradient_checkpointing_enable()
        model.gradient_checkpointing_enable()
        opt = configure_optimizer(model, lr=peak_lr, wd=wd)
        step = 0
        opt.zero_grad()
        for i, batch in enumerate(dataloader):
            # 1. Forward + loss in bf16, scaled for accumulation
            with torch.autocast('cuda', dtype=torch.bfloat16):
                loss = lm_loss(model, batch.cuda()) / accum
            # 2. Backward (gradients accumulate)
            loss.backward()
            if (i + 1) % accum == 0:
                # 3. Set learning rate from schedule
                lr = lr_schedule(step, peak_lr, warmup, total_steps)
                for g in opt.param_groups:
                    g['lr'] = lr
                # 4. Clip, step, reset
                gnorm = nn.utils.clip_grad_norm_(model.parameters(), clip)
                opt.step()
                opt.zero_grad()
                step += 1
                # 5. Monitor (loss shown is the last micro-batch, rescaled)
                if step % 100 == 0:
                    print(f"step {step}: loss={loss.item()*accum:.3f} "
                          f"lr={lr:.2e} gnorm={gnorm:.2f} ppl={math.exp(loss.item()*accum):.1f}")
                # 6. Checkpoint (the LR is set by hand above, so there is no scheduler object)
                if step % 1000 == 0:
                    save_checkpoint(f'ckpt_{step}.pt', model, opt, None, step, torch.get_rng_state())
                if step >= total_steps:
                    break

    # This loop -- with the architecture of Ch.13 and data of Ch.14 -- is
    # a complete, working LLM pretraining pipeline. Scale it up and add
    # distributed training (Ch.18) and you have the real thing.

Training Recipe Quick-Reference
| Ingredient | Standard choice | Why |
|---|---|---|
| Objective | Next-token cross-entropy | Self-supervised, scales perfectly |
| Optimizer | AdamW (0.9, 0.95) | Adaptive + decoupled decay |
| LR schedule | Warmup + cosine decay | Stability then fine convergence |
| Grad clipping | Global norm 1.0 | Survives loss spikes |
| Precision | bf16 mixed | 2× speed, no loss scaling |
| Memory | Checkpointing + accumulation | Fit bigger models / batches |
| Weight decay | 0.1 on matrices only | Regularize without harming norms |
Exercises
Exercises 1–10 are pen-and-paper or derivations; 11–20 require code.
Further reading: “Decoupled Weight Decay Regularization” (Loshchilov & Hutter, 2017) for AdamW. “Attention Is All You Need” (Vaswani et al., 2017) for the original warmup schedule. “An Empirical Model of Large-Batch Training” (McCandlish et al., 2018) for the critical batch size. “OPT: Open Pre-trained Transformer Language Models” (Zhang et al., 2022) and its training logbook for a candid account of real large-model training. Andrej Karpathy's nanoGPT for a complete, readable reference implementation of this chapter's loop.
Next → Chapter 16: Scaling Laws
You can now train a Transformer — but how big should it be, how much data does it need, and how much compute will it take? Chapter 16 answers these with scaling laws: the remarkably precise power-law relationships between model size, dataset size, compute, and loss. We will derive the Chinchilla compute-optimal recipe, understand the 6N rule for training FLOPs, and see how scaling laws turn the art of choosing model size into a predictable science — the quantitative foundation that guides every frontier training run. This chapter closes Part III.