Part III: The Transformer
Chapter 12

The Transformer Architecture

Encoders, decoders, and the full architecture
20 Exercises
12.1

In Chapter 9 we met attention as the fix for the seq2seq bottleneck: instead of compressing a whole sentence into one vector, let the decoder look back at every encoder state and take a weighted average. This chapter develops that idea into the central mechanism of the Transformer. The key reframing: attention is a soft, differentiable dictionary lookup.

The Dictionary Analogy

A Python dictionary maps keys to values: you supply a query key, find the matching stored key, and retrieve its value. This is a hard lookup — exactly one key matches. Attention softens this: the query is compared against all keys, producing a similarity score for each, and the output is a weighted blend of all values, weighted by how well each key matched.

Pseudocode: Hard lookup vs soft attention
Hard (dict):   value = D[query]            # exact match, one value

Soft (attn):   scores  = [sim(query, k) for k in keys]
               weights = softmax(scores)
               output  = Σᵢ weightsᵢ · valueᵢ   # weighted blend of ALL values
Intuition: Why 'Soft' Lookup Is the Whole Point
A hard lookup is not differentiable — you cannot take a gradient through 'find the exact match.' A soft lookup, weighting all values by a softmax of similarities, is smooth and differentiable everywhere. The model can learn what to attend to via gradient descent.
This is the same trick we have seen repeatedly: replace a hard, non-differentiable decision (argmax, exact match) with a soft, differentiable approximation (softmax, weighted average). Differentiability is what makes learning possible.
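A minimal NumPy sketch of the two lookups. It assumes a toy three-entry "dictionary" whose keys are 2-D vectors and whose values are scalars. The `temperature` argument is our own addition, not part of the text: dividing the scores by a small temperature sharpens the softmax, and in that limit the soft lookup recovers the hard one.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())            # subtract max for numerical stability
    return e / e.sum()

keys   = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])   # one key vector per entry
values = np.array([[10.0], [20.0], [30.0]])                # one value per entry

def hard_lookup(query):
    """Return the value of the single best-matching key (argmax: no useful gradient)."""
    return values[np.argmax(keys @ query)]

def soft_lookup(query, temperature=1.0):
    """Blend ALL values, weighted by a softmax over query-key similarities."""
    weights = softmax(keys @ query / temperature)
    return weights @ values, weights

q = np.array([0.9, 0.3])
print(hard_lookup(q))                          # [10.]
out, w = soft_lookup(q)
print(np.round(w, 3), np.round(out, 2))        # weight spread over all three entries
out_sharp, _ = soft_lookup(q, temperature=0.01)
print(np.round(out_sharp, 2))                  # [10.]: a sharp softmax approximates the hard lookup
```

Only `soft_lookup` responds smoothly to small changes in `query`; `hard_lookup` returns the same value until the argmax flips, so its gradient is zero almost everywhere.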

Three Roles: Query, Key, Value

Query (Q)
What I am looking for. Derived from the current position; used to score all keys.
Key (K)
What I contain, for matching. Each position exposes a key that queries are compared against.
Value (V)
What I return if matched. The actual information retrieved, weighted by the attention scores.

The separation of keys from values is subtle but crucial. A position is matched on its key but contributes its value. This lets the model learn one representation for 'how relevant am I?' (the key) and a separate representation for 'what information do I carry?' (the value).
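To make the key/value split concrete, here is a small sketch with random toy matrices (sizes chosen arbitrarily, with d_v deliberately different from d_k). The weights are computed from queries and keys alone, so replacing the values changes what is returned without changing which positions are matched.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract row max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T, d_k, d_v = 3, 4, 2
Q = rng.normal(size=(T, d_k))
K = rng.normal(size=(T, d_k))
V1 = rng.normal(size=(T, d_v))   # one set of "contents"
V2 = rng.normal(size=(T, d_v))   # different contents, same keys

# Matching ("how relevant is each position?") uses only Q and K
W = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)

# Values decide what is returned; the same matching pattern is reused
out1, out2 = W @ V1, W @ V2
print(np.allclose(out1, out2))   # False: same matches, different content
```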

12.2

Chapter 9 introduced Bahdanau (additive) attention, which scores query-key compatibility with a small neural network. The Transformer replaced this with a simple scaled dot product. Tracing this evolution shows why the change mattered.

Additive (Bahdanau, 2014) | Scaled dot-product (Vaswani, 2017)
score = vᵀ tanh(W₁q + W₂k) | score = q·k / √d
A small MLP per query-key pair | A single dot product
More parameters (v, W₁, W₂) | No extra parameters in the score
Slower: cannot be computed as one matmul | One big matmul QKᵀ, GPU-friendly
Works with mismatched query/key dimensions | Requires equal query/key dimension
Outperforms unscaled dot product at large d | Scaling by √d counteracts this; much faster in practice

The dot product's decisive advantage is speed: scoring all queries against all keys is a single matrix multiplication QKᵀ, which GPUs execute extremely efficiently. Bahdanau's per-pair MLP cannot be expressed as one matmul, making it far slower at scale. When the goal is to train on billions of tokens, this efficiency is everything.
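The speed difference is visible in the shapes. In this sketch (random weights; the hidden size `d_a` of the additive scorer is our assumption), additive scoring can still be vectorized, but only by materializing a (T_q, T_k, d_a) tensor and applying tanh elementwise before reducing with v. The dot-product version is a single matmul.

```python
import numpy as np

rng = np.random.default_rng(0)
T_q, T_k, d, d_a = 5, 6, 8, 16     # d_a: hidden size of the additive MLP (our choice)
Q = rng.normal(size=(T_q, d))
K = rng.normal(size=(T_k, d))

# Scaled dot product: every query-key score in one matmul
dot_scores = Q @ K.T / np.sqrt(d)                        # (T_q, T_k)

# Additive (Bahdanau): score = v^T tanh(W1 q + W2 k) for every pair
W1 = rng.normal(size=(d_a, d))
W2 = rng.normal(size=(d_a, d))
v = rng.normal(size=d_a)
# Broadcasting builds a (T_q, T_k, d_a) intermediate before reducing with v
hidden = np.tanh((Q @ W1.T)[:, None, :] + (K @ W2.T)[None, :, :])
add_scores = hidden @ v                                  # (T_q, T_k)

print(dot_scores.shape, add_scores.shape, hidden.shape)  # (5, 6) (5, 6) (5, 6, 16)
```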

History: The Title Said It All
Vaswani et al. titled their 2017 paper 'Attention Is All You Need.' The provocative claim: discard recurrence and convolution entirely, keep only attention. The dot-product formulation made this practical — it was the efficiency that let attention replace, rather than merely augment, the RNN.
Within three years, this single architectural bet reshaped all of NLP and then spread to vision, audio, biology, and beyond. Few paper titles have aged so well.
12.3

Here is the equation that defines the Transformer. Everything else — multi-head, masking, positional encoding — is built around it. Take the time to understand each piece; it repays the investment many times over.

Attention(Q, K, V) = softmax(QKᵀ / √d_k) V

Reading the Equation Piece by Piece

Term | Shape | Meaning
Q | (T, d_k) | Queries: one per position, what each seeks
K | (T, d_k) | Keys: one per position, what each offers
V | (T, d_v) | Values: one per position, what each carries
QKᵀ | (T, T) | Raw scores: similarity of every query to every key
/ √d_k | (T, T) | Scaled scores: prevents softmax saturation
softmax | (T, T) | Attention weights: each row sums to 1
· V | (T, d_v) | Output: weighted blend of values per position

Why Divide by √d_k?

The scaling factor √d_k is not cosmetic; it is essential for stable training. Consider the dot product of two random d_k-dimensional vectors whose components are independent, with zero mean and unit variance. The dot product is a sum of d_k independent products, so its variance grows linearly with d_k. For large d_k the scores become large in magnitude, pushing the softmax into a saturated regime where it is nearly one-hot and its gradients vanish.

Pseudocode: The variance argument for √d_k
q, k have d_k components, each ~ N(0, 1), all independent
q·k = Σᵢ qᵢ kᵢ
Var(qᵢ kᵢ) = E[qᵢ²] E[kᵢ²] − (E[qᵢ] E[kᵢ])² = 1 · 1 − 0 = 1
Var(q·k) = Σᵢ Var(qᵢ kᵢ) = d_k        # independent terms; variance grows with d_k

Dividing by √d_k:  Var(q·k / √d_k) = 1   # restored to unit variance
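A quick Monte Carlo check of the argument, as a sketch under the same assumptions (independent standard-normal components). The sample sizes, the 16 keys per query, and the "mean of the largest weight" summary are arbitrary choices for illustration; a mean max weight near 1.0 indicates nearly one-hot (saturated) rows.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract row max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)

def dot_variance(d_k, n=2000):
    """Empirical Var(q·k) over n independent pairs of standard-normal vectors."""
    q = rng.normal(size=(n, d_k))
    k = rng.normal(size=(n, d_k))
    return (q * k).sum(axis=1).var()

def mean_max_weight(d_k, scaled, T=16, rows=200):
    """Mean of each query row's largest attention weight (1.0 means one-hot)."""
    Q = rng.normal(size=(rows, d_k))
    K = rng.normal(size=(T, d_k))
    scores = Q @ K.T
    if scaled:
        scores = scores / np.sqrt(d_k)
    return softmax(scores, axis=-1).max(axis=-1).mean()

for d_k in (16, 256, 1024):
    print(f"d_k={d_k:4d}  Var(q·k)≈{dot_variance(d_k):7.1f}  "
          f"mean max weight: unscaled={mean_max_weight(d_k, False):.2f}, "
          f"scaled={mean_max_weight(d_k, True):.2f}")
```

Expect the measured variance to track d_k, and the unscaled rows to drift toward one-hot as d_k grows while the scaled rows stay spread out.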


Python: Scaled dot-product attention from scratch
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    """Q,K: (T, d_k)   V: (T, d_v)   ->  output (T, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)       # (T, T) scaled scores
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # large negative stands in for -inf
    weights = softmax(scores, axis=-1)     # (T, T) each row sums to 1
    return weights @ V, weights           # (T, d_v), plus weights for viz

# Toy example: 4 tokens, d_k = d_v = 8
np.random.seed(0)
T, d = 4, 8
Q = np.random.randn(T, d); K = np.random.randn(T, d); V = np.random.randn(T, d)
out, w = attention(Q, K, V)
print(f"output shape: {out.shape}")   # (4, 8)
print(f"weights rows sum to: {w.sum(1)}")  # [1. 1. 1. 1.]

Shape Trace: Scaled dot-product attention (T=4, d_k=d_v=8)

Operation | Shape | Note
Q, K, V | (4, 8) | one row per token
Q @ K.T | (4, 4) | query×key score matrix
/ √d_k | (4, 4) | scaled, unit variance
softmax (rows) | (4, 4) | attention weights, rows sum to 1
weights @ V | (4, 8) | output: blended values
12.4

In Chapter 9's seq2seq attention, queries came from the decoder and keys/values from the encoder; this is cross-attention. The Transformer's signature move is to build on self-attention: queries, keys, and values are all linear projections of the same input sequence. Every position attends to every position in the same sequence (itself included), letting each token gather context from all the others.

The Three Projections

Self-attention learns three weight matrices that project the input X into queries, keys, and values:

Pseudocode: Self-attention projections
Q = X W_Q      # (T, d) × (d, d_k) = (T, d_k)
K = X W_K      # (T, d) × (d, d_k) = (T, d_k)
V = X W_V      # (T, d) × (d, d_v) = (T, d_v)

out = Attention(Q, K, V) = softmax(QKᵀ/√d_k) V

The same input X generates all three through different learned projections. W_Q learns 'what should each token look for', W_K learns 'what should each token advertise', and W_V learns 'what should each token contribute'. The model discovers these roles entirely through gradient descent on the language-modeling loss.

Python: Self-attention layer from scratch
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class SelfAttention:
    def __init__(self, d_model, d_k, seed=0):
        rng = np.random.default_rng(seed)
        s = 1 / np.sqrt(d_model)            # init scale
        self.W_Q = rng.normal(0, s, (d_model, d_k))
        self.W_K = rng.normal(0, s, (d_model, d_k))
        self.W_V = rng.normal(0, s, (d_model, d_k))  # here d_v = d_k
        self.d_k = d_k

    def forward(self, X, mask=None):  # X: (T, d_model)
        Q = X @ self.W_Q                 # (T, d_k)
        K = X @ self.W_K                 # (T, d_k)
        V = X @ self.W_V                 # (T, d_k)
        scores = Q @ K.T / np.sqrt(self.d_k)  # (T, T)
        if mask is not None:
            scores = np.where(mask, scores, -1e9)  # hide masked positions
        w = softmax(scores, axis=-1)     # (T, T), each row sums to 1
        return w @ V, w

# Each token's output is a context-aware blend of all tokens' values.
# This is how 'bank' near 'river' differs from 'bank' near 'money':
# self-attention mixes in the surrounding context.
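A small sketch of the "bank" effect described in the comments above. It assumes random vectors stand in for learned embeddings and projections; the words and the dictionary `emb` are purely illustrative. The input vector for "bank" is identical in both sequences, yet its self-attention output differs because its neighbors differ.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d_model, d_k = 8, 8
# Hypothetical random embeddings standing in for learned ones
emb = {w: rng.normal(size=d_model) for w in ["the", "river", "money", "bank"]}
W_Q, W_K, W_V = (rng.normal(0, d_model ** -0.5, (d_model, d_k)) for _ in range(3))

def self_attention(X):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V          # all three from the SAME X
    w = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)
    return w @ V

X1 = np.stack([emb["the"], emb["river"], emb["bank"]])
X2 = np.stack([emb["the"], emb["money"], emb["bank"]])
bank1 = self_attention(X1)[2]    # output at the 'bank' position, river context
bank2 = self_attention(X2)[2]    # output at the 'bank' position, money context
print(np.allclose(X1[2], X2[2]), np.allclose(bank1, bank2))   # True False
```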


12.5

For autoregressive language modeling (predicting the next token from previous ones), a position must not attend to future positions. If it could, the prediction at position t could simply copy token t+1 from the input instead of inferring it, and the model would learn nothing useful. Causal masking enforces this by setting future scores to −∞ before the softmax.

The Causal Mask

A causal (or 'look-ahead') mask is a lower-triangular matrix: position i may attend to positions j ≤ i only. Masked entries are set to −∞ so that softmax assigns them exactly zero weight.

Pseudocode: Causal masking
mask[i][j] = True  if j ≤ i  else False     # lower triangular

scores[i][j] = -∞   wherever mask[i][j] is False
⇒ softmax gives 0 weight to all future positions

Here is what a causal attention pattern looks like for a 5-token sequence. Each row is a query position; each column a key position. Notice the lower-triangular structure (diagonal included): position 1 attends only to itself, position 5 attends to all five:

[Figure: causal attention pattern for a 5-token sequence]

Python: Causal masking in practice
import numpy as np

def causal_mask(T):
    """Lower-triangular boolean mask: position i sees positions <= i."""
    return np.tril(np.ones((T, T), dtype=bool))

mask = causal_mask(5)
print(mask.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]

# Applied inside attention: scores = where(mask, scores, -1e9)
# After softmax, every masked (future) position gets exactly 0 weight.

# Within a self-attention layer, this mask is the only structural difference
# between a GPT-style decoder (causal) and a BERT-style encoder (no mask).
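To see the causal pattern numerically, the sketch below applies the mask to random queries and keys for T = 5 (toy sizes, our choice). It uses a true `-np.inf` rather than `-1e9`; that is safe here because the diagonal is never masked, so every row keeps at least one finite score. A fully masked row would produce NaN after the softmax, a failure mode that a large finite value such as -1e9 avoids.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # row max is finite: the diagonal is unmasked
    e = np.exp(x)                             # exp(-inf) is exactly 0
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T, d = 5, 8
Q, K = rng.normal(size=(T, d)), rng.normal(size=(T, d))

mask = np.tril(np.ones((T, T), dtype=bool))              # position i sees j <= i
scores = np.where(mask, Q @ K.T / np.sqrt(d), -np.inf)   # true -inf for future cells
weights = softmax(scores, axis=-1)
print(np.round(weights, 2))
```

Row 1 puts all of its weight on position 1, every entry above the diagonal is exactly zero, and each row still sums to 1.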
ML Connection: Encoder vs Decoder Attention
The presence or absence of the causal mask is the defining difference between the two great Transformer families. BERT-style encoders use UNMASKED (bidirectional) self-attention — every token sees every other, ideal for understanding tasks. GPT-style decoders use CAUSAL (masked) self-attention — each token sees only the past, required for generation.
Encoder-decoder models like the original Transformer and T5 use both: a bidirectional encoder, a causal decoder, plus cross-attention connecting them. The same attention equation serves all three; they differ only in the mask and in which sequence supplies the queries, keys, and values.
12.6

A single attention operation produces one weighted blend per position — one 'view' of the context. But a token may need to attend to different things for different reasons: syntactic dependencies, coreference, topical relatedness. Multi-head attention runs several attention operations in parallel, each in its own learned subspace, then concatenates the results.

The Multi-Head Mechanism

Pseudocode: Multi-head attention
For each head h = 1 ... H:
    Q_h = X W_Qʰ,   K_h = X W_Kʰ,   V_h = X W_Vʰ   # each d → d/H
    head_h = Attention(Q_h, K_h, V_h)              # (T, d/H)

MultiHead(X) = Concat(head₁, ..., head_H) W_O      # (T, d)

Crucially, each head operates in a lower-dimensional subspace of size d/H. With d = 512 and H = 8 heads, each head works in 64 dimensions. The total compute is roughly the same as one full-dimensional attention, but the model gains H independent 'perspectives' on the sequence. A final projection W_O mixes the concatenated heads back into the model dimension.

Intuition: Heads as Specialists
Think of the heads as a committee of specialists, each focusing on a different relationship. Interpretability research has found heads that track subject-verb agreement, heads that resolve pronouns to their antecedents, heads that attend to the previous token, and 'induction heads' that copy patterns — each emerging without explicit supervision.
No one tells a head what to specialize in; the division of labor emerges from training. This is one of the most striking examples of useful structure arising spontaneously from gradient descent on a simple objective.
Python: Multi-head attention from scratch
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class MultiHeadAttention:
    def __init__(self, d_model, n_heads, seed=0):
        assert d_model % n_heads == 0
        self.h   = n_heads
        self.d_k = d_model // n_heads       # per-head dimension
        rng = np.random.default_rng(seed)
        s = 1 / np.sqrt(d_model)            # init scale
        # One big matrix per projection; heads are slices of its columns
        self.W_Q = rng.normal(0, s, (d_model, d_model))
        self.W_K = rng.normal(0, s, (d_model, d_model))
        self.W_V = rng.normal(0, s, (d_model, d_model))
        self.W_O = rng.normal(0, s, (d_model, d_model))

    def _split(self, x, T):  # (T, d) -> (h, T, d_k)
        return x.reshape(T, self.h, self.d_k).transpose(1, 0, 2)

    def forward(self, X, mask=None):  # X: (T, d_model); mask: (T, T), shared by all heads
        T = X.shape[0]
        Q = self._split(X @ self.W_Q, T)    # (h, T, d_k)
        K = self._split(X @ self.W_K, T)
        V = self._split(X @ self.W_V, T)

        scores = Q @ K.transpose(0, 2, 1) / np.sqrt(self.d_k)  # (h, T, T)
        if mask is not None:
            scores = np.where(mask, scores, -1e9)  # (T, T) mask broadcasts over heads
        w = softmax(scores, axis=-1)        # (h, T, T)
        heads = w @ V                       # (h, T, d_k)

        # Concatenate heads back to (T, d_model), then mix them with W_O
        concat = heads.transpose(1, 0, 2).reshape(T, -1)
        return concat @ self.W_O          # (T, d_model)
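The comment "heads are slices" can be checked directly. This sketch (random weights, sizes matching the shape trace that follows) confirms that the reshape-and-transpose in `_split` gives head h exactly the projection through columns h·d_k to (h+1)·d_k of W_Q, and that the concatenation in `forward` undoes the split.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, n_heads = 10, 512, 8
d_k = d_model // n_heads
X = rng.normal(size=(T, d_model))
W_Q = rng.normal(0, d_model ** -0.5, (d_model, d_model))

# Reshape-based split, as in _split: (T, d) -> (h, T, d_k)
Q_heads = (X @ W_Q).reshape(T, n_heads, d_k).transpose(1, 0, 2)

# Same thing, one head at a time: head h uses columns h*d_k:(h+1)*d_k of W_Q
for h in range(n_heads):
    Q_h = X @ W_Q[:, h * d_k:(h + 1) * d_k]
    assert np.allclose(Q_heads[h], Q_h)

# Concatenating heads reverses the split
concat = Q_heads.transpose(1, 0, 2).reshape(T, -1)
print(Q_heads.shape, np.allclose(concat, X @ W_Q))   # (8, 10, 64) True
```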

Shape Trace: Multi-head attention (T=10, d=512, H=8)

Operation | Shape | Note
input X | (10, 512) | 10 tokens, model dim 512
X @ W_Q | (10, 512) | queries (all heads)
split into heads | (8, 10, 64) | 8 heads, d_k = 64 each
Q @ Kᵀ / √d_k | (8, 10, 10) | per-head score matrices
softmax · V | (8, 10, 64) | per-head outputs
concat heads | (10, 512) | back to model dim
@ W_O | (10, 512) | output projection
12.7

Self-attention draws Q, K, and V from one sequence. Cross-attention draws queries from one sequence and keys/values from another. It is how a decoder attends to an encoder's output — the mechanism in the original Transformer, in T5, and in modern multimodal models where text queries attend to image features.

Pseudocode: Cross-attention
Q = X_target W_Q      # queries from the target/decoder sequence
K = X_source W_K      # keys from the source/encoder sequence
V = X_source W_V      # values from the source/encoder sequence

out = Attention(Q, K, V)   # target attends to source
Self-attention | Cross-attention
Q, K, V from the SAME sequence | Q from target, K/V from source
Each token attends within its sequence | Target tokens attend to source tokens
Score matrix is (T, T) | Score matrix is (T_target, T_source)
Used in encoders and decoders | Connects decoder to encoder
Models intra-sequence structure | Models cross-sequence alignment
Example: GPT, BERT layers | Example: translation, image captioning
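A minimal cross-attention sketch with random toy inputs: a 4-token target attends to a 6-token source (sizes, dimensions, and variable names are our choices). The only change from self-attention is that K and V are computed from `X_source`, which makes the weight matrix rectangular.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T_tgt, T_src, d_model, d_k = 4, 6, 16, 16
X_target = rng.normal(size=(T_tgt, d_model))   # e.g. decoder states
X_source = rng.normal(size=(T_src, d_model))   # e.g. encoder outputs
W_Q, W_K, W_V = (rng.normal(0, d_model ** -0.5, (d_model, d_k)) for _ in range(3))

Q = X_target @ W_Q                    # (4, 16)  queries from the target
K = X_source @ W_K                    # (6, 16)  keys from the source
V = X_source @ W_V                    # (6, 16)  values from the source
weights = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)   # (4, 6): target x source
out = weights @ V                                   # (4, 16): one vector per target token
print(weights.shape, out.shape)       # (4, 6) (4, 16)
```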
ML Connection: Cross-Attention in Multimodal Models
Cross-attention is the standard way to fuse modalities. In a vision-language model, text tokens (queries) attend to image patch embeddings (keys/values), letting the language model 'look at' relevant regions of the image while generating a description.
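Flamingo inserts gated cross-attention layers into a frozen language model so that text tokens can attend to visual features; BLIP-2 uses cross-attention inside its Q-Former to compress image features into a small set of query embeddings that the language model then reads as input. Many video-language models use similar cross-attention designs. You will meet this again in Chapter 30 on multimodal LLMs.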
12.8

Because attention weights are an explicit (T, T) matrix, they are tempting to interpret as the model's 'reasoning'. Interpretability research has uncovered genuinely meaningful patterns — but also cautions against reading too much into raw attention weights.

Documented Head Specializations

Head type | Pattern
Previous-token head | Attends to the immediately preceding token (i → i-1)
Positional head | Attends to a fixed relative offset
Syntactic head | Attends along dependency-parse edges (verb → subject)
Coreference head | Resolves pronouns to their antecedents
Induction head | Completes patterns: ...[A][B]...[A] → predicts [B]
Rare-token / delimiter | Attends to delimiters, sentence boundaries, or rare tokens

Induction Heads: The Engine of In-Context Learning

Induction heads (Olsson et al., 2022) are among the most important discoveries in mechanistic interpretability. They implement a copying rule: if the sequence contains '...[A][B]...[A]', the induction head attends from the second [A] back to the [B] that followed the first [A], predicting [B] again. This simple pattern-completion mechanism is believed to underlie much of the in-context learning ability of LLMs.
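The rule itself fits in a few lines of ordinary Python. This is a hand-coded, hard-attention caricature, not a trained head: a real induction head implements the match softly, typically relying on a previous-token head in an earlier layer to tag each position with the token before it. Choosing the most recent match when there are several is our assumption.

```python
def induction_predict(tokens):
    """Idealized induction rule. For each position i, find earlier positions j
    whose PREVIOUS token equals tokens[i], 'attend' to the most recent one, and
    copy tokens[j] as the prediction. Returns None where nothing matches."""
    preds = []
    for i, cur in enumerate(tokens):
        # j < i, and the token just before j equals the current token
        matches = [j for j in range(1, i) if tokens[j - 1] == cur]
        preds.append(tokens[matches[-1]] if matches else None)
    return preds

seq = ["Harry", "Potter", "went", "home", ".", "Harry"]
print(induction_predict(seq)[-1])   # Potter
```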


⚠️ Attention Is Not Explanation (Necessarily)
It is tempting to treat high attention weight as 'the model used this token.' But Jain & Wallace (2019) showed that attention weights can often be altered substantially without changing the output — so they are not always a faithful explanation of the model's computation.
Attention weights are one signal among many. Faithful interpretability requires causal interventions (ablations, patching), not just reading off the softmax. Treat attention maps as suggestive, not definitive.
12.9

Attention's power comes at a price: computing QKᵀ produces a (T, T) matrix, so both compute and memory scale as O(T²) in the sequence length. For T = 1,000 this is a million entries; for T = 100,000 it is ten billion. This quadratic scaling is the central bottleneck of long-context Transformers.

Pseudocode: Attention complexity
QKᵀ:        O(T² · d)   compute,   O(T²)   memory
softmax·V:   O(T² · d)   compute

Total:      O(T² d)   — quadratic in sequence length T
Approach | Complexity | Idea
Vanilla attention | O(T²) | Exact; the baseline
FlashAttention | O(T²) compute, O(T) memory | Tiling + recompute; no approximation
Sparse attention | O(T√T) or O(T log T) | Attend to a structured subset
Linear attention | O(T) | Kernel feature maps; approximate softmax
Sliding window | O(T·w) | Each token attends to w neighbors
Multi-query / GQA | O(T²), less memory | Share keys/values across heads
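The entry counts quoted above translate directly into memory. This back-of-the-envelope sketch assumes 4-byte fp32 scores, one head, and batch size 1; real training multiplies the figure by heads, batch, and layers whenever the scores are stored for the backward pass.

```python
def score_matrix_bytes(T, bytes_per_entry=4, heads=1, batch=1):
    """Memory needed to materialize the (T, T) score matrix (fp32 by default)."""
    return batch * heads * T * T * bytes_per_entry

for T in (1_000, 8_000, 100_000):
    gb = score_matrix_bytes(T) / 1e9
    print(f"T={T:>7,}: {T * T:>15,} entries, {gb:10.3f} GB per head (fp32)")
```

At T = 100,000 a single fp32 score matrix already needs 40 GB, which is why the memory column of the table matters as much as the compute column.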
ML Connection: FlashAttention Changed Everything
FlashAttention (Dao et al., 2022) does not reduce the O(T²) compute, but it reduces memory to O(T) by never materializing the full (T, T) score matrix: it computes attention in tiles that fit in fast on-chip SRAM and recomputes intermediate values during the backward pass instead of storing them. This made training on much longer sequences practical, and it is now widely adopted; PyTorch's torch.nn.functional.scaled_dot_product_attention, for example, can dispatch to a FlashAttention kernel.
The lesson echoes Chapter 11's gradient checkpointing: sometimes the win is not fewer FLOPs but better use of the memory hierarchy. Chapter 27 covers inference optimization, including the KV-cache that exploits causal masking to avoid recomputing keys and values during generation.
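The core trick that makes tiling possible is an "online" softmax: process keys in blocks and keep only a running maximum, a running denominator, and a running weighted sum per query. The NumPy sketch below illustrates that recurrence only; it is not FlashAttention itself, which also tiles the queries, runs inside on-chip SRAM, and recomputes in the backward pass. The block size is an arbitrary parameter.

```python
import numpy as np

def naive_attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[-1])              # full (T_q, T_k) score matrix
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=4):
    """Online-softmax attention over key/value blocks. Per query we keep:
    m = running max score, l = running softmax denominator,
    acc = running unnormalized output. Scores exist only one block at a time."""
    T_q, d = Q.shape
    m = np.full((T_q, 1), -np.inf)
    l = np.zeros((T_q, 1))
    acc = np.zeros((T_q, V.shape[-1]))
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Q @ Kb.T / np.sqrt(d)                   # (T_q, block) scores only
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)              # rescale old stats to the new max
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ Vb
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(10, 8)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V, block=3), naive_attention(Q, K, V)))  # True
```

Peak score storage here is (T_q, block) rather than (T_q, T_k); the rescaling by `exp(m - m_new)` is what lets earlier blocks be corrected after a larger score appears.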
12.10

In Chapter 10 you built every part of the Transformer block except one: the box labeled 'Multi-Head Attention.' You have now built that box. Here is the complete Pre-LN block again, with the attention sublayer fully specified.

Arch Stack: Pre-LN Transformer Block (attention now complete)

+ residual add | x + FFN_out
Feed-Forward (SwiGLU) | d → 4d → d
LayerNorm / RMSNorm | normalize
+ residual add | x + Attn_out
Multi-Head Self-Attention | H heads, causal mask
LayerNorm / RMSNorm | normalize
input x | (B, T, d)

The block computes two residual updates: x ← x + MHA(LN(x)) then x ← x + FFN(LN(x)). Attention lets each token gather context from the whole sequence; the feed-forward network then processes each position independently. Stacking dozens of these blocks gives the full Transformer.
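A forward-only NumPy sketch of the two residual updates. To keep it short it makes several simplifying assumptions that differ from the stack above: single-head causal attention instead of H heads, a ReLU MLP instead of SwiGLU, LayerNorm without learnable gain and bias, random weights, and no batch dimension.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Normalize each position's features (no learnable gain/bias in this sketch)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
T, d = 6, 16
W_Q, W_K, W_V, W_O = (rng.normal(0, d ** -0.5, (d, d)) for _ in range(4))
W_1 = rng.normal(0, d ** -0.5, (d, 4 * d))
W_2 = rng.normal(0, (4 * d) ** -0.5, (4 * d, d))

def attn(x):
    # Single-head causal self-attention followed by the output projection W_O
    n = x.shape[0]
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    mask = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(mask, Q @ K.T / np.sqrt(d), -1e9)
    return softmax(scores, axis=-1) @ V @ W_O

def ffn(x):
    return np.maximum(0, x @ W_1) @ W_2    # ReLU MLP, d -> 4d -> d

def pre_ln_block(x):
    x = x + attn(layer_norm(x))   # residual update 1: mix across positions
    x = x + ffn(layer_norm(x))    # residual update 2: process each position
    return x

x = rng.normal(size=(T, d))
print(pre_ln_block(x).shape)      # (6, 16)
```

Because attention is causal and everything else acts per position, changing the last input token leaves the outputs at earlier positions untouched.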

Attention Mixes, FFN Processes
A useful mental model: attention is the only operation in the block that moves information BETWEEN positions — it mixes tokens. The feed-forward network operates on each position independently, processing the mixed information but never crossing positions.
This alternation — mix across positions (attention), then process each position (FFN), repeat — is the fundamental rhythm of the Transformer. Everything else (normalization, residuals, masking) is in service of making this two-step dance trainable at depth.
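The division of labor is easy to check empirically. In this sketch (random weights, an unmasked single-head attention, and a ReLU MLP standing in for the FFN, all our choices), only token 0 of the input is perturbed: the attention output changes at every position, while the FFN output changes only at position 0.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1)
T, d = 5, 8
W_Q, W_K, W_V = (rng.normal(0, d ** -0.5, (d, d)) for _ in range(3))
W_1 = rng.normal(0, d ** -0.5, (d, 4 * d))
W_2 = rng.normal(0, (4 * d) ** -0.5, (4 * d, d))

def attention(x):                 # unmasked self-attention: mixes positions
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    return softmax(Q @ K.T / np.sqrt(d), axis=-1) @ V

def ffn(x):                       # same MLP applied to every row independently
    return np.maximum(0, x @ W_1) @ W_2

x = rng.normal(size=(T, d))
x_pert = x.copy()
x_pert[0] += 1.0                  # change token 0 only

def rows_changed(f):
    """Which output rows change when only token 0 of the input is perturbed."""
    return np.abs(f(x_pert) - f(x)).max(axis=-1) > 1e-12

print("attention changes rows:", rows_changed(attention))   # every row changes
print("FFN changes rows:      ", rows_changed(ffn))         # only row 0 changes
```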
12.11

Attention Quick-Reference

Concept | Formula / shape | Key point
Scaled dot-product | softmax(QKᵀ/√d)V | √d keeps softmax unsaturated
Self-attention | Q,K,V = XW_Q, XW_K, XW_V | Contextual mixing within a sequence
Cross-attention | Q from target, K/V from source | Connects two sequences
Causal mask | tril; future → -∞ | Required for autoregressive LMs
Multi-head | H heads of dim d/H | Parallel views; specialized heads
Complexity | O(T² d) | Quadratic in sequence length
Output projection | Concat(heads) W_O | Mixes head outputs

Exercises

Exercises 1–10 are pen-and-paper or derivations; 11–20 require code.

Exercise 1: Pen & Paper
Explain attention as a soft dictionary lookup. What makes the soft version differentiable where a hard lookup is not?
Exercise 2: Derive
Show that for q, k with d independent, zero-mean, unit-variance components, Var(q·k) = d. Use this to justify the √d scaling in scaled dot-product attention.
Exercise 3: Pen & Paper
Without the √d scaling, what happens to the softmax for large d? Describe the effect on the attention distribution and on gradients.
Exercise 4: Pen & Paper
Why does attention separate keys from values? Give a concrete example where the optimal 'matching' representation differs from the optimal 'content' representation.
Exercise 5: Pen & Paper
Write the causal mask for T=4 explicitly. Show how setting masked scores to -∞ makes their softmax weight exactly zero.
Exercise 6: Pen & Paper
Compare the score-matrix shapes for self-attention vs cross-attention with a 10-token target and 15-token source. Explain the difference.
Exercise 7: Derive
For multi-head attention with model dim d and H heads, show that splitting into H heads of dim d/H keeps the total compute roughly equal to single-head attention at dim d.
Exercise 8: Pen & Paper
Describe what an induction head computes and why it is thought to underlie in-context learning. Sketch the attention pattern for the sequence A B C A.
Exercise 9: Pen & Paper
Attention is O(T²). For T = 2,048 and T = 32,768, compute the ratio of attention FLOPs. Why does this make long context expensive?
Exercise 10: Pen & Paper
Explain the difference between an encoder (unmasked) and decoder (causal) in terms of the single line of code that differs. Why does generation require the mask?
Exercise 11: Code
Implement scaled dot-product attention from scratch. Verify the output shape and that each attention row sums to 1. Compare with and without √d scaling on d=64 random inputs — measure the entropy of the attention distribution.
Exercise 12: Code
Implement self-attention as a class. Feed it a sequence where two tokens are identical except for context; show their output representations differ.
Exercise 13: Code
Implement causal masking and visualize the resulting attention matrix as a heatmap for a 10-token sequence. Confirm the lower-triangular structure.
Exercise 14: Code Lab
Implement multi-head attention from scratch with the head-splitting reshape. Verify shapes at every step against the shape trace in Section 12.6. Test with d=512, H=8, T=16.
Exercise 15: Code
Implement cross-attention. Build a toy translation setup where a 5-token target attends to a 7-token source; visualize the (5,7) alignment matrix.
Exercise 16: Code
Compare your from-scratch attention against PyTorch's torch.nn.functional.scaled_dot_product_attention. Verify that the outputs match to within 1e-5.
Exercise 17: Code Lab
Load a small pretrained Transformer (e.g., distilgpt2 via transformers). Extract and visualize attention maps from several heads for a sentence. Identify a previous-token head and a delimiter head.
Exercise 18: Code
Implement the KV-cache optimization for causal self-attention: during autoregressive generation, cache keys and values so each new token costs O(T) rather than O(T²). Measure the speedup.
Exercise 19: Code
Measure attention's quadratic scaling empirically: time a forward pass for T = 128, 256, 512, 1024, 2048 and plot time vs T. Confirm the O(T²) trend.
Exercise 20: Code (Challenge)
Implement a complete multi-head self-attention layer with backward pass using your Chapter 11 autograd engine. Verify gradients against PyTorch. Then assemble it with LayerNorm, residual connections, and an FFN into the full Pre-LN block from Section 12.10.

Further reading: “Attention Is All You Need” (Vaswani et al., 2017) — the original Transformer paper. “Neural Machine Translation by Jointly Learning to Align and Translate” (Bahdanau et al., 2014) for additive attention. “A Mathematical Framework for Transformer Circuits” (Elhage et al., 2021) and “In-context Learning and Induction Heads” (Olsson et al., 2022) for interpretability. “FlashAttention” (Dao et al., 2022) for the IO-aware implementation. Jay Alammar's “The Illustrated Transformer” for visual intuition.


Next → Chapter 13: The Transformer Architecture

You have now built every component of the Transformer: linear layers, activations, normalization, dropout, and residual connections (Chapter 10), the autograd to train them (Chapter 11), and multi-head attention (this chapter). Chapter 13 assembles them into the complete architecture — positional encodings, the full encoder and decoder stacks, the embedding and unembedding layers, and the precise data flow from input tokens to output logits. You will build a working Transformer from scratch and understand every line.
