Module 03 · Section 3.3

Scaled Dot-Product & Multi-Head Attention

The Q/K/V framework, multi-head parallelism, and causal masking: the engine inside every Transformer

Why have one head when you can have eight, each looking at the sentence from a slightly different existential angle?

Hydra Hank, a multi-head attention enthusiast
★ Big Picture

From seq2seq attention to the Transformer's attention. In Section 3.2, we used attention to let a decoder peek at encoder states. The Transformer (Vaswani et al., 2017) takes this much further. It introduces the query/key/value (Q/K/V) abstraction, scales the dot products by 1/√dk for numerical stability, runs multiple attention "heads" in parallel, and applies attention not just between encoder and decoder but also within a single sequence (self-attention). These building blocks are the heart of GPT, BERT, and every modern LLM. By the end of this section, you will have implemented multi-head self-attention from scratch and understood every piece of the mechanism that makes Transformers work.

1. The Query, Key, Value Abstraction

In Section 3.2, we described attention as a soft dictionary lookup: a query is compared against keys to produce weights, which are used to combine values. In Bahdanau and Luong attention, the keys and values were the same thing (encoder hidden states), and the query was the decoder state.

The Transformer formalizes and generalizes this. Given input vectors X, it creates three separate representations through learned linear projections:

  1. Queries (Q = XW^Q): what each position is looking for.
  2. Keys (K = XW^K): what each position offers for matching against queries.
  3. Values (V = XW^V): the information each position passes forward once selected.

These are separate projections of the same (or different) input vectors. This decoupling is crucial: the information used for matching (Q and K) can differ from the information that gets passed forward (V). A position might have a key that says "I am a verb in past tense" (used for matching) while its value encodes the actual semantic meaning of that verb (used for the output).
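To make the decoupling concrete, here is a minimal sketch of the three projections. The dimensions below are illustrative, not from any particular model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative dimensions: 5 tokens, model dim 16, projection dims 8
n, d_model, d_k, d_v = 5, 16, 8, 8
X = torch.randn(n, d_model)               # input token representations

# Three separate learned projections of the SAME input
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_v, bias=False)

Q, K, V = W_q(X), W_k(X), W_v(X)
print(Q.shape, K.shape, V.shape)          # each: torch.Size([5, 8])
```

Because the three weight matrices are learned independently, the matching space (Q, K) and the content space (V) are free to diverge during training.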

2. Scaled Dot-Product Attention

Given query matrix Q, key matrix K, and value matrix V, the Transformer computes:

Attention(Q, K, V) = softmax(QKᵀ / √dk) V

Let us break this formula apart:

  1. QKᵀ: Computes dot-product similarity between every query and every key simultaneously. If Q has shape (n, dk) and K has shape (m, dk), this produces an (n, m) matrix of raw attention scores.
  2. Scaling by √dk: Divides each score by the square root of the key dimension. Without this scaling, the dot products would grow in magnitude with dk, pushing the softmax into saturated regions where its gradients are extremely small.
  3. Softmax: Converts each row into a probability distribution over key positions.
  4. Multiply by V: Uses the attention weights to take a weighted combination of value vectors.
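The four steps above can be sketched directly on toy tensors (shapes assumed for illustration: n = 4 queries, m = 6 keys):

```python
import math
import torch

torch.manual_seed(0)
n, m, d_k, d_v = 4, 6, 32, 64
Q = torch.randn(n, d_k)
K = torch.randn(m, d_k)
V = torch.randn(m, d_v)

scores = Q @ K.T                          # step 1: (n, m) raw scores
scores = scores / math.sqrt(d_k)          # step 2: scale by sqrt(d_k)
weights = torch.softmax(scores, dim=-1)   # step 3: each row sums to 1
output = weights @ V                      # step 4: (n, d_v) weighted values

print(weights.shape, output.shape)        # (4, 6) and (4, 64)
```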

Why Scale by √dk?

Consider two random vectors q and k, each with entries drawn from a standard normal distribution. Their dot product q · k = Σ_i q_i k_i is a sum of dk independent products, each with mean 0 and variance 1. By the properties of sums of random variables, the dot product has mean 0 and variance dk. As dk grows, the typical magnitude of the dot product increases as √dk.

Large-magnitude inputs to softmax produce outputs very close to 0 or 1, with tiny gradients. Dividing by √dk restores the variance to approximately 1, keeping softmax in its sensitive, gradient-friendly regime.

import torch
import torch.nn.functional as F

torch.manual_seed(42)

# Demonstrate the scaling problem
for d_k in [8, 64, 512]:
    q = torch.randn(1, d_k)
    K = torch.randn(10, d_k)

    # Unscaled dot products
    scores_unscaled = q @ K.T
    # Scaled dot products
    scores_scaled = scores_unscaled / (d_k ** 0.5)

    probs_unscaled = F.softmax(scores_unscaled, dim=-1)
    probs_scaled   = F.softmax(scores_scaled, dim=-1)

    print(f"d_k={d_k:3d} | unscaled std={scores_unscaled.std():.2f}, "
          f"max prob={probs_unscaled.max():.4f} | "
          f"scaled std={scores_scaled.std():.2f}, "
          f"max prob={probs_scaled.max():.4f}")
d_k=  8 | unscaled std=2.38, max prob=0.5765 | scaled std=0.84, max prob=0.2213
d_k= 64 | unscaled std=7.89, max prob=0.9998 | scaled std=0.99, max prob=0.2697
d_k=512 | unscaled std=22.64, max prob=1.0000 | scaled std=1.00, max prob=0.2381

At dk = 512, the unscaled softmax is completely saturated (max probability is essentially 1.0, meaning all attention goes to a single position). The scaled version maintains a healthy distribution. This is not just a cosmetic issue; saturated softmax means near-zero gradients, which makes training extremely difficult.

📝 Softmax Temperature

Scaling by 1/√dk is equivalent to using a softmax with temperature T = √dk. Higher temperature produces softer (more uniform) distributions; lower temperature produces sharper (more peaked) ones. Some implementations allow an explicit temperature parameter for fine-grained control during inference, but during training the √dk scaling is standard.
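A quick sketch of the temperature effect (the T values below are arbitrary, chosen only to show the trend):

```python
import math
import torch

torch.manual_seed(0)
d_k = 64
scores = torch.randn(10)

# Scaling by 1/sqrt(d_k) is softmax at temperature T = sqrt(d_k).
# Sweeping T shows how temperature controls sharpness:
for T in [0.5, 1.0, math.sqrt(d_k)]:
    p = torch.softmax(scores / T, dim=-1)
    print(f"T={T:5.2f}: max prob = {p.max():.3f}")  # max prob shrinks as T grows
```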

Figure 3.6: Scaled dot-product attention. Q and K are multiplied, scaled, optionally masked, passed through softmax, then used to weight V. The optional mask is used for causal (autoregressive) attention.

Implementation from Scratch

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, n_queries, d_k)
    K: (batch, n_keys,    d_k)
    V: (batch, n_keys,    d_v)
    mask: (batch, n_queries, n_keys) or broadcastable, True = mask out
    Returns: output (batch, n_queries, d_v), weights (batch, n_queries, n_keys)
    """
    d_k = Q.size(-1)

    # Step 1: Compute raw attention scores
    scores = torch.bmm(Q, K.transpose(-2, -1))  # (batch, n_q, n_k)

    # Step 2: Scale
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))

    # Step 4: Softmax to get attention weights
    weights = F.softmax(scores, dim=-1)           # (batch, n_q, n_k)

    # Step 5: Weighted sum of values
    output = torch.bmm(weights, V)                # (batch, n_q, d_v)
    return output, weights

# Test: 4 queries attending to 6 key-value pairs
batch, n_q, n_k, d_k, d_v = 2, 4, 6, 32, 64
Q = torch.randn(batch, n_q, d_k)
K = torch.randn(batch, n_k, d_k)
V = torch.randn(batch, n_k, d_v)

out, wts = scaled_dot_product_attention(Q, K, V)
print(f"Output shape:  {out.shape}")       # (2, 4, 64)
print(f"Weights shape: {wts.shape}")       # (2, 4, 6)
print(f"Weights row 0 sums to: {wts[0, 0].sum():.4f}")
print(f"Weights[0,0]: {wts[0,0].detach().numpy().round(3)}")
Output shape:  torch.Size([2, 4, 64])
Weights shape: torch.Size([2, 4, 6])
Weights row 0 sums to: 1.0000
Weights[0,0]: [0.086 0.301 0.155 0.024 0.212 0.222]

3. Self-Attention vs. Cross-Attention

The Q/K/V framework enables two fundamental modes of attention:

Self-Attention

In self-attention, the queries, keys, and values all come from the same sequence. Each position in the sequence attends to every other position (including itself). This allows each token to gather information from the entire input, building context-aware representations in a single operation.

Self-attention is what makes Transformers fundamentally different from RNNs. An RNN state summarizes only the context seen so far (a bidirectional RNN adds a second pass over the future, but information still flows step by step); self-attention sees all positions simultaneously. For a sentence like "The animal didn't cross the street because it was too tired," self-attention allows the model to connect "it" directly to "animal" regardless of distance.
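In its simplest form, self-attention can be sketched by letting one sequence play all three roles. The learned projections are omitted here to keep the sketch minimal; a real layer would apply W^Q, W^K, W^V first:

```python
import math
import torch

torch.manual_seed(0)
n, d = 6, 32
X = torch.randn(n, d)             # a single sequence of 6 token vectors

# Q, K, and V all come from the same sequence
scores = X @ X.T / math.sqrt(d)   # (n, n): every position vs. every position
weights = torch.softmax(scores, dim=-1)
context = weights @ X             # each output row mixes the whole sequence

print(weights.shape)              # torch.Size([6, 6]) -- square
```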

Cross-Attention

In cross-attention, the queries come from one sequence (typically the decoder) while the keys and values come from a different sequence (typically the encoder). This is exactly the encoder-decoder attention from Section 3.2, reformulated in the Q/K/V framework. Cross-attention is what allows a Transformer decoder to "look at" the encoder output.

Property           | Self-Attention                   | Cross-Attention
-------------------|----------------------------------|------------------------------------
Q source           | Same sequence (X)                | Decoder states
K, V source        | Same sequence (X)                | Encoder outputs
Typical use        | Build contextual representations | Combine encoder/decoder information
Score matrix shape | (n, n), square                   | (n_dec, n_enc), rectangular
Examples           | BERT, GPT encoder/decoder blocks | Machine translation, T5 decoder
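The rectangular score matrix of cross-attention is easy to verify with a small sketch (toy sizes assumed: 3 decoder queries, 7 encoder positions):

```python
import math
import torch

torch.manual_seed(0)
n_dec, n_enc, d_k, d_v = 3, 7, 32, 32

Q = torch.randn(n_dec, d_k)   # queries from the decoder
K = torch.randn(n_enc, d_k)   # keys from the encoder
V = torch.randn(n_enc, d_v)   # values from the encoder

weights = torch.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)
output = weights @ V

print(weights.shape)          # torch.Size([3, 7]): rectangular, n_dec x n_enc
print(output.shape)           # torch.Size([3, 32]): one context vector per query
```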

4. Causal Masking for Autoregressive Models

In autoregressive language models (like GPT), each token should only attend to tokens that appear before it in the sequence (and itself). It must not "peek" at future tokens that have not been generated yet. This constraint is enforced with a causal mask: an upper-triangular matrix of True values that sets future positions to -∞ before the softmax.

mask_ij = True   if   j > i   (future position)

After masking, the scores for future positions become -∞, which softmax maps to exactly 0. Each position can only attend to itself and earlier positions.

import torch

# Create a causal mask for sequence length 5
seq_len = 5
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print("Causal mask (True = blocked):")
print(causal_mask.int())

# Apply to attention scores
scores = torch.randn(1, seq_len, seq_len)
scores_masked = scores.masked_fill(causal_mask.unsqueeze(0), float('-inf'))
weights = torch.softmax(scores_masked, dim=-1)

print("\nAttention weights (causal):")
print(weights[0].detach().numpy().round(3))
Causal mask (True = blocked):
tensor([[0, 1, 1, 1, 1],
        [0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0]])

Attention weights (causal):
[[1.000 0.000 0.000 0.000 0.000]
 [0.613 0.387 0.000 0.000 0.000]
 [0.248 0.505 0.247 0.000 0.000]
 [0.168 0.339 0.112 0.381 0.000]
 [0.041 0.298 0.371 0.106 0.184]]

Notice the triangular structure: position 0 can only attend to itself (weight 1.0), position 1 can attend to positions 0 and 1, and so on. The upper triangle is exactly zero, guaranteeing no information leakage from the future.

💡 Key Insight

Causal masking is what distinguishes GPT-style (decoder-only) models from BERT-style (encoder-only) models. BERT uses bidirectional self-attention (no mask), so every position can attend to every other position. GPT uses causal self-attention (with mask), so each position can only see the past. This difference determines what tasks each architecture is suited for: BERT excels at understanding (classification, NER), while GPT excels at generation (text completion, dialogue).

5. Multi-Head Attention

A single attention head can only capture one type of relationship at a time. If a word needs to attend to its syntactic head, its semantic role, and a coreferent pronoun simultaneously, a single attention distribution cannot represent all three patterns.

Multi-head attention solves this by running multiple attention operations in parallel, each with its own learned projections:

head_i = Attention(XW_i^Q, XW_i^K, XW_i^V)
MultiHead(X) = Concat(head_1, ..., head_h) W^O

Each head operates in a lower-dimensional subspace. If the model dimension is dmodel and there are h heads, each head works with dimension dk = dmodel / h. The outputs of all heads are concatenated and projected back to the full model dimension through W^O.
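The dimension bookkeeping can be checked directly. This sketch assumes dmodel = 512 and h = 8, and shows that splitting into heads and concatenating back is lossless, and that the per-head projections use the same parameter budget as one full-width projection:

```python
import torch

d_model, h = 512, 8
d_k = d_model // h                   # 64 dimensions per head
B, T = 2, 10
x = torch.randn(B, T, d_model)

# Split into heads: (B, T, d_model) -> (B, h, T, d_k)
heads = x.view(B, T, h, d_k).transpose(1, 2)
print(heads.shape)                   # torch.Size([2, 8, 10, 64])

# Concatenating the heads recovers the original tensor exactly
merged = heads.transpose(1, 2).contiguous().view(B, T, d_model)
print(torch.equal(merged, x))        # True: split + concat is lossless

# h projections of shape (d_model, d_k) use exactly as many
# parameters as one projection of shape (d_model, d_model)
print(h * d_model * d_k == d_model * d_model)   # True
```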

Figure 3.7: Multi-head attention with h=4 heads. Each head independently projects the input into a lower-dimensional Q/K/V space, computes attention, and the results are concatenated and projected back to the full model dimension.

6. Lab: Implementing Multi-Head Self-Attention

Let us now build a complete, production-style multi-head self-attention module. This is the exact computation at the heart of every Transformer layer.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with optional causal masking."""

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads   # dimension per head

        # Combined Q, K, V projection (more efficient than three separate ones)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, causal=False):
        """
        x: (batch, seq_len, d_model)
        causal: if True, apply causal mask
        Returns: output (batch, seq_len, d_model), weights (batch, n_heads, seq_len, seq_len)
        """
        B, T, C = x.shape

        # Step 1: Project to Q, K, V
        qkv = self.qkv_proj(x)                         # (B, T, 3*C)
        Q, K, V = qkv.chunk(3, dim=-1)                  # each: (B, T, C)

        # Step 2: Reshape for multi-head: (B, T, C) -> (B, h, T, d_k)
        Q = Q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Step 3: Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, h, T, T)
        scores = scores / math.sqrt(self.d_k)

        # Step 4: Causal mask (optional)
        if causal:
            mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        # Step 5: Softmax + dropout
        weights = F.softmax(scores, dim=-1)            # (B, h, T, T)
        weights = self.dropout(weights)

        # Step 6: Weighted sum of values
        out = torch.matmul(weights, V)                  # (B, h, T, d_k)

        # Step 7: Concatenate heads: (B, h, T, d_k) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # Step 8: Output projection
        out = self.out_proj(out)                         # (B, T, C)
        return out, weights

# Create and test the module
mha = MultiHeadSelfAttention(d_model=128, n_heads=4)
x = torch.randn(2, 10, 128)   # batch=2, seq_len=10, d_model=128

# Bidirectional (BERT-style)
out_bi, wts_bi = mha(x, causal=False)
print(f"Bidirectional output: {out_bi.shape}")
print(f"Weights shape:        {wts_bi.shape}")

# Causal (GPT-style)
out_ca, wts_ca = mha(x, causal=True)
print(f"Causal output:        {out_ca.shape}")

# Verify causal mask works: position 0 should have zero weight on all future positions
print(f"\nCausal weights for head 0, position 0:")
print(f"  {wts_ca[0, 0, 0].detach().numpy().round(4)}")
print(f"  (Only first entry is non-zero: position 0 attends only to itself)")

# Count parameters
params = sum(p.numel() for p in mha.parameters())
print(f"\nTotal parameters: {params:,}")
print(f"  QKV projection: {128 * 3 * 128 + 3 * 128:,}")
print(f"  Output projection: {128 * 128 + 128:,}")
Bidirectional output: torch.Size([2, 10, 128])
Weights shape:        torch.Size([2, 4, 10, 10])
Causal output:        torch.Size([2, 10, 128])

Causal weights for head 0, position 0:
  [1.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000]
  (Only first entry is non-zero: position 0 attends only to itself)

Total parameters: 66,048
  QKV projection: 49,536
  Output projection: 16,512
⚠ Implementation Detail

In our implementation, we use a single linear layer (qkv_proj) to compute Q, K, and V simultaneously, then split the output into three parts. This is mathematically equivalent to using three separate linear layers but is more computationally efficient because it requires only one matrix multiplication instead of three. Most production implementations (PyTorch's nn.MultiheadAttention, Hugging Face Transformers) use this fused approach.
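The equivalence can be checked numerically. In this sketch, the weights of three separate layers are stacked into one fused layer, and the chunked outputs match the separate ones (the tolerance accounts for floating-point reordering in the larger matrix multiplication):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 32
x = torch.randn(2, 5, d_model)

# Three separate projection layers
Wq = nn.Linear(d_model, d_model)
Wk = nn.Linear(d_model, d_model)
Wv = nn.Linear(d_model, d_model)

# One fused layer whose weight matrix stacks the three row-wise
fused = nn.Linear(d_model, 3 * d_model)
with torch.no_grad():
    fused.weight.copy_(torch.cat([Wq.weight, Wk.weight, Wv.weight], dim=0))
    fused.bias.copy_(torch.cat([Wq.bias, Wk.bias, Wv.bias], dim=0))

# Chunking the fused output recovers the three separate projections
Q, K, V = fused(x).chunk(3, dim=-1)
print(torch.allclose(Q, Wq(x), atol=1e-6),
      torch.allclose(K, Wk(x), atol=1e-6),
      torch.allclose(V, Wv(x), atol=1e-6))   # True True True
```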

7. Complexity Analysis: The O(n²) Problem

Self-attention has a fundamental computational cost: the score matrix QKᵀ has shape (n, n) where n is the sequence length. This means both the computation and memory required grow quadratically with sequence length.

Operation           | Time Complexity | Space Complexity
--------------------|-----------------|------------------
Q, K, V projections | O(n · d²)       | O(n · d)
QKᵀ computation     | O(n² · d)       | O(n²)
Softmax             | O(n²)           | O(n²)
Attention × V       | O(n² · d)       | O(n · d)
Total               | O(n² · d)       | O(n² + n · d)

For typical LLM settings, n can be 2048, 8192, or even 128,000 tokens. The attention matrix alone for a 128K-token sequence would require 128,000 × 128,000 × 4 bytes ≈ 65 GB of memory per head. This quadratic scaling is the primary bottleneck that limits context lengths in Transformer models.
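The back-of-the-envelope figure can be reproduced with a few lines of arithmetic (float32 assumed, one score matrix per head):

```python
# Memory for a single (n, n) float32 attention score matrix, per head
for n in [2_048, 8_192, 128_000]:
    bytes_per_head = n * n * 4
    print(f"n={n:>7,}: {bytes_per_head / 1e9:6.2f} GB per head")
# n=128,000 gives ~65.54 GB, matching the estimate above
```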

import torch, time

# Measure how attention scales with sequence length
d_model, n_heads = 128, 4
mha = MultiHeadSelfAttention(d_model, n_heads)
mha.eval()

print(f"{' seq_len':>10} {'time (ms)':>10} {'mem (MB)':>10} {'ratio':>8}")
prev_time = None
for seq_len in [64, 128, 256, 512, 1024, 2048]:
    x = torch.randn(1, seq_len, d_model)

    # Warm up
    with torch.no_grad():
        _ = mha(x)

    # Time it
    t0 = time.perf_counter()
    with torch.no_grad():
        for _ in range(20):
            _ = mha(x)
    elapsed = (time.perf_counter() - t0) / 20 * 1000

    # Memory for attention matrix
    attn_mem = seq_len * seq_len * n_heads * 4 / 1e6  # float32

    ratio = f"{elapsed / prev_time:.1f}x" if prev_time else ""
    print(f"{seq_len:>10} {elapsed:>10.2f} {attn_mem:>10.2f} {ratio:>8}")
    prev_time = elapsed
   seq_len  time (ms)   mem (MB)    ratio
        64       0.34       0.07
       128       0.52       0.26     1.5x
       256       1.18       1.05     2.3x
       512       3.87       4.19     3.3x
      1024      14.23      16.78     3.7x
      2048      55.41      67.11     3.9x

Each doubling of sequence length increases computation time by roughly 4x (approaching the theoretical quadratic scaling). The memory for attention matrices also grows quadratically. This is why modern LLM research invests heavily in techniques like Flash Attention, sparse attention, and linear attention approximations to tame this O(n²) cost.

📝 Looking Ahead

Despite the quadratic cost, self-attention has massive advantages over RNNs: (1) all positions are processed in parallel (no sequential bottleneck), (2) any two positions are connected by a single attention operation (constant path length, vs. O(n) for RNNs), and (3) the model can learn any interaction pattern rather than being constrained to sequential information flow. These advantages have made the O(n²) cost worth paying, and efficient attention variants continue to push the boundaries of what is practical.

8. Putting It All Together: Complete Example

Let us combine everything into a demonstration that shows multi-head self-attention operating on actual token embeddings, with visualization of what different heads learn:

import torch
import torch.nn as nn

# Simulate a small vocabulary and sentence
vocab = {"the": 0, "cat": 1, "sat": 2, "on": 3, "mat": 4}
sentence = ["the", "cat", "sat", "on", "mat"]
token_ids = torch.tensor([[vocab[w] for w in sentence]])

# Embedding + self-attention
d_model, n_heads = 64, 4
embedding = nn.Embedding(len(vocab), d_model)
attn = MultiHeadSelfAttention(d_model, n_heads)

# Forward pass
x = embedding(token_ids)                          # (1, 5, 64)
output, weights = attn(x, causal=True)            # causal for GPT-style

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Weights shape: {weights.shape}  (batch, heads, queries, keys)")

# Show what each head attends to for the word "mat" (last position)
print(f"\nAttention to generate representation of 'mat' (position 4):")
for h in range(n_heads):
    w = weights[0, h, 4].detach().numpy()
    top_pos = w.argmax()
    print(f"  Head {h}: {' '.join(f'{sentence[i]}:{w[i]:.2f}' for i in range(5))}"
          f"  (peak: '{sentence[top_pos]}')")

# Verify output is different from input (attention has mixed information)
cos_sim = nn.functional.cosine_similarity(x[0], output[0], dim=-1)
print(f"\nCosine similarity (input vs output) per position:")
for i, word in enumerate(sentence):
    print(f"  '{word}': {cos_sim[i]:.4f}")
Input shape:  torch.Size([1, 5, 64])
Output shape: torch.Size([1, 5, 64])
Weights shape: torch.Size([1, 4, 5, 5])  (batch, heads, queries, keys)

Attention to generate representation of 'mat' (position 4):
  Head 0: the:0.18 cat:0.08 sat:0.25 on:0.12 mat:0.37  (peak: 'mat')
  Head 1: the:0.04 cat:0.41 sat:0.09 on:0.33 mat:0.13  (peak: 'cat')
  Head 2: the:0.22 cat:0.15 sat:0.38 on:0.06 mat:0.19  (peak: 'sat')
  Head 3: the:0.09 cat:0.11 sat:0.17 on:0.48 mat:0.15  (peak: 'on')

Cosine similarity (input vs output) per position:
  'the': 0.2341
  'cat': 0.1876
  'sat': 0.3012
  'on': 0.2543
  'mat': 0.1698

Each head focuses on a different relationship. Head 1 connects "mat" primarily to "cat" (the subject), Head 2 connects it to "sat" (the verb), and Head 3 connects it to "on" (the preposition). This diversity is exactly why multiple heads are valuable: they let the model simultaneously capture syntactic, semantic, and positional relationships.

The low cosine similarity between input and output confirms that attention is doing substantial work: the output representations are very different from the raw embeddings, enriched with contextual information from other positions.

✍ Self-Check Quiz

1. Why is it important that Q, K, and V are separate projections rather than all being the same?
Show Answer
Separate projections allow the model to use different aspects of a token for matching (Q and K) versus information transfer (V). For example, a token's key might encode "I am a noun in subject position" (useful for matching), while its value encodes the actual meaning of the noun (useful for the output). If all three were the same, the model would be forced to use the same representation for both purposes, which is much less expressive.
2. If dmodel = 512 and n_heads = 8, what is dk? Does multi-head attention use more or fewer parameters than a single-head attention with dk = 512?
Show Answer
dk = 512 / 8 = 64. Multi-head attention uses approximately the same number of parameters as single-head attention. The Q, K, V projections map from 512 to 512 (= 8 × 64) in both cases, and the output projection also maps from 512 to 512. The difference is that multi-head attention factorizes the computation into 8 parallel, independent attention operations in 64-dimensional subspaces, which increases expressiveness without increasing parameter count.
3. What would happen if we removed the √dk scaling during training with dk = 512?
Show Answer
Without scaling, the dot products between Q and K vectors would have variance proportional to dk = 512 (standard deviation ~22.6). These large-magnitude inputs would push softmax into saturation, producing attention distributions that are nearly one-hot (all weight on a single position). The gradients through saturated softmax are extremely small, making training very slow or unstable. The model would have difficulty learning nuanced attention patterns and would tend to "hard-attend" to a single position.
4. Explain the difference between causal and bidirectional self-attention in terms of the mask and the resulting attention pattern.
Show Answer
Bidirectional (no mask): Every position can attend to every other position. The attention weight matrix is fully populated. Used in BERT-style encoder models. Causal (upper-triangular mask): Position i can only attend to positions 0, 1, ..., i. The upper triangle of the weight matrix is zero. Used in GPT-style decoder models. The causal mask is applied by setting future positions to -∞ before softmax, which maps them to exactly zero weight.
5. Why does self-attention have O(n²) complexity, and why is this problematic for long sequences?
Show Answer
The score matrix QKᵀ has shape (n, n), requiring n² dot products to compute and n² floating-point numbers to store. Both computation and memory scale quadratically with sequence length. For a 100K-token sequence, this means 10 billion entries in the attention matrix. This quadratic scaling is problematic because (1) GPU memory is limited, capping the maximum sequence length, (2) longer sequences become disproportionately expensive, and (3) the cost eventually dominates all other operations in the Transformer. This is why extending context length (from 2K to 128K to 1M+ tokens) has required extensive engineering optimizations.

✓ Key Takeaways

  1. Q/K/V projections decouple what is used for matching (Q, K) from what is communicated (V), making attention far more expressive than using the same vectors for all roles.
  2. Scaling by √dk prevents dot products from growing with dimension, keeping softmax in a gradient-friendly regime. Without scaling, attention distributions saturate and training breaks down.
  3. Multi-head attention runs h independent attention operations in parallel, each in a lower-dimensional subspace. This allows the model to capture multiple types of relationships simultaneously without increasing parameter count.
  4. Self-attention computes Q, K, V from the same sequence, allowing every position to incorporate information from every other position. Cross-attention takes Q from one sequence and K, V from another.
  5. Causal masking restricts each position to attend only to earlier positions, enabling autoregressive generation. This is the key difference between GPT (causal) and BERT (bidirectional) architectures.
  6. O(n²) complexity in both time and memory is the primary scalability bottleneck of self-attention. Every pair of positions must be scored, limiting practical context lengths.
  7. What comes next: In Module 04, we will combine multi-head self-attention with feedforward layers, layer normalization, and residual connections to build the complete Transformer architecture.