Module 04, Section 4.2

Build a Transformer from Scratch

A complete, annotated decoder-only Transformer implementation in ~300 lines of PyTorch.

I built a Transformer from scratch and it predicted "the the the the." Honestly, some meetings feel the same way.

From Scratch Francesca, a humbled implementer
Lab: Hands-On Implementation

This section is a coding lab. By the end you will have a working character-level language model built on a decoder-only Transformer. Every line of code is explained. We encourage you to type the code yourself rather than copy-pasting; the act of typing builds muscle memory for these patterns.

1. What We Are Building

We will implement a decoder-only Transformer (the GPT architecture) that performs character-level language modeling. Given a sequence of characters, the model predicts the next character at every position. We choose character-level modeling because it eliminates the need for a tokenizer, letting us focus entirely on the architecture.

Our model will have these hyperparameters:

| Hyperparameter | Value | Notes |
|---|---|---|
| d_model | 128 | Embedding and residual stream dimension |
| n_heads | 4 | Number of attention heads (d_k = 32) |
| n_layers | 4 | Number of Transformer blocks |
| d_ff | 512 | Feed-forward inner dimension (4 × d_model) |
| block_size | 128 | Maximum context length |
| vocab_size | ~65 | Unique characters in the dataset |
| dropout | 0.1 | Dropout rate |

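This is a small model (about 0.8M parameters with a 65-character vocabulary) that trains in a few minutes on a single GPU, or more slowly on a CPU. The overall design is the same decoder-only stack as GPT; the main differences are scale and a few small details, such as the ReLU activation (GPT uses GELU) and bias-free linear layers.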
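The parameter count follows directly from the hyperparameters. Here is a minimal sketch of the arithmetic, assuming the configuration used in this lab: the ReLU feed-forward network, bias=False in every Linear layer, LayerNorm with its default weight and bias, a learned position embedding, and tied token-embedding/output weights. count_params is a hypothetical helper, not part of mini_transformer.py.

```python
def count_params(vocab_size=65, block_size=128, n_layers=4, d_model=128,
                 d_ff=512, tie_weights=True):
    """Hypothetical helper: parameter count for the lab's model (no Linear biases)."""
    layernorm = 2 * d_model                              # LayerNorm weight + bias
    attn = d_model * 3 * d_model + d_model * d_model     # qkv_proj + out_proj
    ffn = d_model * d_ff + d_ff * d_model                # fc1 + fc2
    per_block = 2 * layernorm + attn + ffn               # ln1, ln2, attention, FFN

    total = vocab_size * d_model                         # token embedding
    total += block_size * d_model                        # learned position embedding
    total += n_layers * per_block                        # Transformer blocks
    total += layernorm                                   # ln_final
    if not tie_weights:
        total += d_model * vocab_size                    # separate lm_head weight
    return total


print(f"{count_params():,}")                   # 813,440
print(f"{count_params(tie_weights=False):,}")  # 821,760
```

Because model.parameters() yields a tied weight only once, the tied figure is what the n_params line in the training script should report for a 65-character vocabulary.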

[Figure: Token + Position Embedding → Transformer Block (× N): LayerNorm → Causal Multi-Head Self-Attention → + residual; LayerNorm → Feed-Forward (SwiGLU or ReLU) → + residual → Final LayerNorm → Linear (d_model → vocab_size) → Softmax → Probabilities]
Figure 4.4: Architecture of our decoder-only Transformer. N blocks of self-attention + FFN with Pre-LN ordering, followed by a final normalization and linear projection.

2. The Complete Implementation

Below is the full model in a single file. We break it into logical pieces and explain each one. The complete code (all pieces assembled) is approximately 300 lines including comments.

2.1 Imports and Configuration

"""
mini_transformer.py
A minimal decoder-only Transformer for character-level language modeling.
~300 lines of annotated PyTorch.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class TransformerConfig:
    """All hyperparameters in one place."""
    vocab_size: int = 65        # number of unique characters
    block_size: int = 128       # maximum context length
    n_layers: int = 4           # number of Transformer blocks
    n_heads: int = 4            # number of attention heads
    d_model: int = 128          # embedding / residual stream dimension
    d_ff: int = 512             # feed-forward inner dimension
    dropout: float = 0.1        # dropout probability
    bias: bool = False          # use bias in Linear layers?

We use a dataclass so that every hyperparameter is explicit, documented, and easy to modify. Setting bias=False removes the bias terms from the Linear layers (LayerNorm keeps its own), following the LLaMA convention and marginally reducing the parameter count.

2.2 Causal Self-Attention

class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0

        # Key, Query, Value projections combined into one matrix
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
        # Output projection
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.d_k = config.d_model // config.n_heads

        # Causal mask: lower-triangular matrix of ones (1 = may attend, 0 = future position)
        # Registered as a buffer so it moves with the model on .to(device)
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.shape  # batch, sequence length, d_model

        # Compute Q, K, V in one matrix multiply, then split
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.d_model, dim=2)

        # Reshape for multi-head: (B, T, C) -> (B, n_heads, T, d_k)
        q = q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        # (B, n_heads, T, d_k) @ (B, n_heads, d_k, T) -> (B, n_heads, T, T)
        scores = (q @ k.transpose(-2, -1)) * (self.d_k ** -0.5)

        # Apply causal mask: positions beyond current token get -inf
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of values
        # (B, n_heads, T, T) @ (B, n_heads, T, d_k) -> (B, n_heads, T, d_k)
        out = attn_weights @ v

        # Concatenate heads: (B, n_heads, T, d_k) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # Final linear projection + dropout
        return self.resid_dropout(self.out_proj(out))
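The masked_fill line is where causality comes from. This standalone sketch (random scores with T = 4, not taken from a trained model) repeats the mask-then-softmax steps from forward and then checks the property the mask exists for: changing the value vector at the last position must not change the outputs at earlier positions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_k = 4, 8
scores = torch.randn(1, 1, T, T)                       # (B, n_heads, T, T) raw scores
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)  # 1 = may attend, 0 = future

# Same two steps as CausalSelfAttention.forward
masked = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(masked, dim=-1)
print(weights[0, 0])  # upper triangle is exactly 0; each row sums to 1

# Causality check: perturb only the last position's value vector
v = torch.randn(1, 1, T, d_k)
v_changed = v.clone()
v_changed[:, :, -1, :] += 100.0
out, out_changed = weights @ v, weights @ v_changed
print(torch.allclose(out[:, :, :-1], out_changed[:, :, :-1]))  # True: earlier positions unaffected
print(torch.allclose(out[:, :, -1], out_changed[:, :, -1]))    # False: last position sees the change
```

Since exp(-inf) is exactly 0, future positions contribute nothing, not merely a small amount.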
Key Insight: Fused QKV Projection

We compute Q, K, and V with a single linear layer (qkv_proj) of size d_model → 3 * d_model and then split the output into three equal parts. This is mathematically identical to three separate linear layers but is more efficient because it performs one large matrix multiply instead of three smaller ones. The GPU utilizes its parallelism more effectively with larger matrices.
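The "mathematically identical" claim is easy to check. In this sketch (random input and default-initialized weights, chosen only for illustration), three separate projections are built from the row blocks of the fused weight; nn.Linear stores its weight as (out_features, in_features), so the split is along dim 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 128
x = torch.randn(2, 10, d_model)                      # (B, T, d_model)

# Fused projection, as in CausalSelfAttention
qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
q, k, v = qkv_proj(x).split(d_model, dim=2)

# Three separate projections using the fused weight's row blocks
w_q, w_k, w_v = qkv_proj.weight.chunk(3, dim=0)      # each (d_model, d_model)
q_sep, k_sep, v_sep = x @ w_q.T, x @ w_k.T, x @ w_v.T

print(torch.allclose(q, q_sep, atol=1e-5),
      torch.allclose(k, k_sep, atol=1e-5),
      torch.allclose(v, v_sep, atol=1e-5))           # True True True
```

Any difference is floating-point rounding only; the speed benefit comes from issuing one larger matmul.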

2.3 Feed-Forward Network

class FeedForward(nn.Module):
    """Position-wise feed-forward network with ReLU activation."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
        self.fc2 = nn.Linear(config.d_ff, config.d_model, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

This is the simplest version. For a more advanced variant, you can swap in SwiGLU:

class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward (used in LLaMA, PaLM)."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        # SwiGLU uses 3 weight matrices instead of 2
        # To keep param count comparable, the hidden dim is often 2/3 of d_ff
        hidden = int(2 * config.d_ff / 3)
        self.w1 = nn.Linear(config.d_model, hidden, bias=config.bias)
        self.w2 = nn.Linear(hidden, config.d_model, bias=config.bias)
        self.w3 = nn.Linear(config.d_model, hidden, bias=config.bias)  # gate
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # SiLU(x W1) times (x W3), elementwise, then project back with W2
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
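The comment about keeping the parameter count comparable can be checked with a little arithmetic, assuming this lab's sizes (d_model = 128, d_ff = 512) and no biases:

```python
d_model, d_ff = 128, 512
hidden = int(2 * d_ff / 3)              # 341, as in SwiGLUFeedForward

relu_ffn_params = 2 * d_model * d_ff    # fc1 + fc2
swiglu_params = 3 * d_model * hidden    # w1, w2, w3

print(hidden, relu_ffn_params, swiglu_params)  # 341 131072 130944
```

Three matrices of width 341 hold almost exactly as many weights as two matrices of width 512, so the two FFN variants can be compared at roughly equal size.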

2.4 Transformer Block

class TransformerBlock(nn.Module):
    """A single Transformer block with Pre-LN ordering."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.ffn = FeedForward(config)

    def forward(self, x):
        # Pre-LN: normalize before each sub-layer
        x = x + self.attn(self.ln1(x))    # residual + attention
        x = x + self.ffn(self.ln2(x))     # residual + FFN
        return x

This is remarkably simple. Two lines of actual computation, each following the pattern: x = x + SubLayer(LayerNorm(x)). The residual connection is the x + at the beginning; the Pre-LN ordering means we normalize the input to each sub-layer, not the output.
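To make the Pre-LN ordering concrete, this sketch contrasts it with the Post-LN ordering of the original 2017 Transformer, LayerNorm(x + SubLayer(x)). A plain nn.Linear stands in for attention or the FFN (an assumption for illustration only), and its weights are zeroed so that only the residual path remains.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
x = torch.randn(2, 5, d_model) * 3 + 1   # deliberately not normalized
ln = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)   # stand-in for attention or FFN


def pre_ln(x):
    # Pre-LN (this lab): normalize the sub-layer's input only
    return x + sublayer(ln(x))


def post_ln(x):
    # Post-LN (original Transformer): normalize after the residual add
    return ln(x + sublayer(x))


# Zero the sub-layer so its output is exactly 0 and only the residual path is left
with torch.no_grad():
    sublayer.weight.zero_()
    sublayer.bias.zero_()

print(torch.equal(pre_ln(x), x))          # True: the stream passes through untouched
print(torch.allclose(post_ln(x), ln(x)))  # True: Post-LN rescales the whole stream
```

The untouched identity path through Pre-LN blocks is commonly cited as the reason this ordering trains more stably in deep stacks. It is also why the stream is still unnormalized at the end and needs ln_final.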

2.5 The Complete Model

class MiniTransformer(nn.Module):
    """Decoder-only Transformer for character-level language modeling."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # Token and position embeddings
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.block_size, config.d_model)
        self.drop = nn.Dropout(config.dropout)

        # Stack of Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # Final layer norm (needed with Pre-LN)
        self.ln_final = nn.LayerNorm(config.d_model)

        # Output head: project from d_model to vocab_size
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: share embedding and output weights
        self.token_emb.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)
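        # Scale residual output projections (GPT-2 style): std = 0.02 / sqrt(2 * n_layers)
        resid_std = 0.02 / math.sqrt(2 * config.n_layers)
        for block in self.blocks:
            nn.init.normal_(block.attn.out_proj.weight, mean=0.0, std=resid_std)
            # FeedForward names its output layer fc2; SwiGLUFeedForward names it w2
            ffn_out = block.ffn.fc2 if hasattr(block.ffn, "fc2") else block.ffn.w2
            nn.init.normal_(ffn_out.weight, mean=0.0, std=resid_std)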

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """
        Args:
            idx: (B, T) tensor of token indices
            targets: (B, T) tensor of target token indices (optional)
        Returns:
            logits: (B, T, vocab_size)
            loss: scalar cross-entropy loss (only if targets provided)
        """
        B, T = idx.shape
        assert T <= self.config.block_size, \
            f"Sequence length {T} exceeds block_size {self.config.block_size}"

        # Token embeddings + positional embeddings
        positions = torch.arange(0, T, device=idx.device)  # (T,)
        x = self.token_emb(idx) + self.pos_emb(positions)  # (B, T, d_model)
        x = self.drop(x)

        # Pass through all Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization
        x = self.ln_final(x)

        # Project to vocabulary
        logits = self.lm_head(x)  # (B, T, vocab_size)

        # Compute loss if targets are provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Auto-regressive generation.
        Args:
            idx: (B, T) conditioning sequence
            max_new_tokens: number of tokens to generate
            temperature: softmax temperature (lower = more deterministic)
            top_k: if set, only sample from top-k most likely tokens
        """
        for _ in range(max_new_tokens):
            # Crop context to block_size if needed
            idx_cond = idx[:, -self.config.block_size:]

            # Forward pass
            logits, _ = self(idx_cond)

            # Take logits at the last position and apply temperature
            logits = logits[:, -1, :] / temperature

            # Optional top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Sample from the distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            idx = torch.cat([idx, next_token], dim=1)

        return idx
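To see what temperature and top_k do to the distribution before torch.multinomial samples from it, this sketch mirrors those two steps from generate on a made-up 4-token logit vector. sampling_probs is a hypothetical helper, not part of the model.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # one position, toy 4-token vocabulary


def sampling_probs(logits, temperature=1.0, top_k=None):
    """Hypothetical helper: apply generate()'s temperature and top-k steps, return probs."""
    logits = logits / temperature
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = logits.masked_fill(logits < v[:, [-1]], float("-inf"))
    return F.softmax(logits, dim=-1)


print(sampling_probs(logits, temperature=1.0))
print(sampling_probs(logits, temperature=0.5))  # sharper: more mass on token 0
print(sampling_probs(logits, top_k=2))          # only tokens 0 and 1 keep nonzero mass
```

Lower temperatures and smaller top_k both concentrate probability on the most likely tokens. Pushed far enough, this produces the "all generated tokens are the same" symptom described in Section 7.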
Note: Weight Tying

The line self.token_emb.weight = self.lm_head.weight shares the embedding matrix with the output projection. This is common practice in language models, especially smaller ones. It means the model uses the same representation for "what does this token mean?" (embedding) and "what token should come next?" (output logits). This reduces the parameter count by vocab_size × d_model (65 × 128 = 8,320 parameters here) and provides a regularization effect.
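Mechanically, tying works because the two layers end up holding the same nn.Parameter object. nn.Embedding(vocab_size, d_model) and nn.Linear(d_model, vocab_size) both store a (vocab_size, d_model) weight, so the shapes match. TiedHeads below is a hypothetical two-layer container used only to show this.

```python
import torch.nn as nn

vocab_size, d_model = 65, 128


class TiedHeads(nn.Module):
    """Hypothetical container holding just an embedding and an output head."""

    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)          # weight: (65, 128)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)   # weight: (65, 128)
        self.token_emb.weight = self.lm_head.weight                 # same Parameter object


m = TiedHeads()
print(m.token_emb.weight is m.lm_head.weight)   # True
print(sum(p.numel() for p in m.parameters()))   # 8320: the shared tensor is counted once
```

Because both layers hold one tensor, gradients from the input side and the output side accumulate into the same .grad during backpropagation.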

Paper Spotlight: Weight Tying (Press and Wolf, 2017)

Press and Wolf showed that tying the input embedding and output projection weights is not just a memory optimization; it acts as a regularizer that improves perplexity. The intuition: by forcing the model to use a single vector space for both input and output, it learns embeddings where tokens that should be predicted in similar contexts also have similar input representations. For a 50K vocabulary with d = 512, weight tying saves 50,000 × 512 = 25.6 million parameters. The technique is widely used (GPT-2 ties its weights, for example), although some recent large models, including LLaMA and Mistral, keep separate input and output matrices.

Press, O. & Wolf, L. (2017). "Using the Output Embedding to Improve Language Models." EACL 2017.

3. Data Preparation

For training, we use a simple character-level dataset. Any plain-text file will work; for quick experimentation we use a small corpus of about 1 MB (the tiny Shakespeare file downloaded in Section 6.1).

class CharDataset:
    """Character-level dataset that produces (input, target) pairs."""

    def __init__(self, text, block_size):
        self.block_size = block_size
        # Build character vocabulary
        chars = sorted(set(text))
        self.vocab_size = len(chars)
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        # Encode entire text as integers
        self.data = torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.block_size + 1]
        x = chunk[:-1]   # input:  characters 0..block_size-1
        y = chunk[1:]     # target: characters 1..block_size
        return x, y

    def decode(self, indices):
        """Convert list of integer indices back to string."""
        return ''.join(self.itos[i] for i in indices)

    def encode(self, text):
        """Convert string to list of integer indices."""
        return [self.stoi[c] for c in text]
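The key detail in __getitem__ is that x and y come from the same window of block_size + 1 characters, offset by one. This sketch repeats that slicing on a toy string. The string "hello world" and block_size = 4 are arbitrary choices for illustration.

```python
import torch

text = "hello world"
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)


def decode(t):
    return "".join(itos[i] for i in t.tolist())


block_size, idx = 4, 0
chunk = data[idx : idx + block_size + 1]   # block_size + 1 characters
x, y = chunk[:-1], chunk[1:]               # target is the input shifted left by one

print(repr(decode(x)), repr(decode(y)))    # 'hell' 'ello'
print(len(data) - block_size)              # 7 valid starting offsets (0..6)
```

Every position in x gets its own target: after "h" predict "e", after "he" predict "l", and so on. One example therefore yields block_size training signals.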

4. The Training Loop

📝 Connecting the Pieces: Next-Token Prediction Is Classification

Next-token prediction is classification. At each position in the sequence, the model performs a V-way classification over the entire vocabulary, where V is the vocabulary size. The cross-entropy loss from Section 0.1 applies directly here: we compare the model's predicted probability distribution over all possible next tokens against the one-hot target (the actual next token in the training data). This is why the code below uses F.cross_entropy to compute the loss, treating every position as an independent classification problem.
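The sketch below checks this equivalence on random tensors (shapes and seed are arbitrary). Flattening (B, T, V) logits to (B·T, V) and calling F.cross_entropy gives the same value as averaging the negative log-probability of the correct next token at every position. With uniform logits, the loss equals ln(V), which is why the training log in Section 6.2 starts near 4.17 for a 65-character vocabulary.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 3, 65
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

# What the model does: B*T independent V-way classifications
loss = F.cross_entropy(logits.view(-1, V), targets.view(-1))

# Same thing by hand: mean negative log-probability of each correct next token
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs.gather(-1, targets.unsqueeze(-1)).mean()
print(torch.allclose(loss, manual))       # True

# A model with uniform predictions scores ln(V)
uniform = F.cross_entropy(torch.zeros(B * T, V), targets.view(-1))
print(uniform.item(), math.log(V))        # both about 4.174
```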

import time
from torch.utils.data import DataLoader


def train(config=None, text_path='input.txt', max_steps=5000,
          batch_size=64, learning_rate=3e-4, eval_interval=500):
    """Complete training procedure."""

    # ---- Setup ----
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load text data
    with open(text_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Create config first so the dataset uses the model's block_size
    if config is None:
        config = TransformerConfig()

    # Create dataset
    dataset = CharDataset(text, block_size=config.block_size)
    print(f"Vocabulary size: {dataset.vocab_size}")
    print(f"Dataset size: {len(dataset):,} examples")

    # Set vocab size from the data
    config.vocab_size = dataset.vocab_size

    # Create model
    model = MiniTransformer(config).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.95),
        weight_decay=0.1
    )

    # Data loader
    # pin_memory only helps (and only avoids a warning) when copying to a CUDA device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=0, pin_memory=(device == 'cuda'))
    data_iter = iter(loader)

    # ---- Training ----
    model.train()
    t0 = time.time()

    for step in range(max_steps):
        # Get batch (cycle through data)
        try:
            xb, yb = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            xb, yb = next(data_iter)

        xb, yb = xb.to(device), yb.to(device)

        # Forward pass
        logits, loss = model(xb, yb)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Gradient clipping (standard practice)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Logging
        if step % eval_interval == 0 or step == max_steps - 1:
            dt = time.time() - t0
            print(f"step {step:5d} | loss {loss.item():.4f} | "
                  f"time {dt:.1f}s")

    # ---- Generation ----
    model.eval()
    prompt = "\n"
    context = torch.tensor(
        [dataset.encode(prompt)], dtype=torch.long, device=device
    )
    generated = model.generate(context, max_new_tokens=500, temperature=0.8)
    print("\n" + "=" * 50)
    print("Generated text:")
    print("=" * 50)
    print(dataset.decode(generated[0].tolist()))

    return model, dataset


if __name__ == '__main__':
    train()
Practical Tips

Gradient clipping (clip_grad_norm_ with max_norm=1.0) prevents training instability from occasional large gradient spikes. This is standard practice in most Transformer training pipelines.
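clip_grad_norm_ computes one global L2 norm over all the gradients it is given. If that norm exceeds max_norm, it rescales every gradient by the same factor, so the update direction is preserved. This tiny sketch uses a single hand-made gradient, [6, 8], with norm 10:

```python
import torch

p = torch.nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([6.0, 8.0])       # gradient with L2 norm 10

total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

print(total_norm.item())  # 10.0: the norm *before* clipping
print(p.grad)             # about [0.6, 0.8]: rescaled to norm ~1, same direction
```

The returned pre-clipping norm is worth logging alongside the loss, since sudden jumps in it often precede the NaN losses described in Section 7.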

AdamW (Adam with decoupled weight decay) is the optimizer of choice. The betas (0.9, 0.95) and weight_decay (0.1) follow common LLM training conventions. The learning rate 3e-4 works well for small models; larger models typically use lower rates with warmup schedules.
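The lab uses a constant learning rate. One common warmup recipe for larger runs is linear warmup followed by cosine decay, which can be expressed as a multiplier for torch.optim.lr_scheduler.LambdaLR. This is a sketch of that recipe, not part of the lab code. warmup_cosine is a hypothetical helper, and warmup_steps = 100 and min_ratio = 0.1 are illustrative values.

```python
import math
import torch


def warmup_cosine(step, warmup_steps=100, max_steps=5000, min_ratio=0.1):
    """Hypothetical LR multiplier: linear warmup to 1.0, then cosine decay to min_ratio."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Hooking it up (a throwaway Linear layer stands in for the model)
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)
# In a training loop, call scheduler.step() right after optimizer.step().

print(optimizer.param_groups[0]["lr"])          # 3e-4 * 1/100 at step 0
print(warmup_cosine(100), warmup_cosine(5000))  # peak at end of warmup, 0.1 at the end
```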

5. Understanding the Shapes

Tracking tensor shapes is one of the most valuable debugging skills when working with Transformers. Here is a shape trace through the forward pass:

| Variable | Shape | Description |
|---|---|---|
| idx | (B, T) | Input token indices |
| token_emb(idx) | (B, T, d_model) | Token embeddings |
| pos_emb(positions) | (T, d_model) | Positional embeddings (broadcast over B) |
| x after embedding | (B, T, d_model) | Sum of token + position embeddings |
| qkv | (B, T, 3*d_model) | Fused QKV projection output |
| q, k, v after reshape | (B, n_heads, T, d_k) | Per-head queries, keys, values |
| scores | (B, n_heads, T, T) | Attention scores (before masking) |
| attn_weights | (B, n_heads, T, T) | Attention probabilities (after softmax) |
| out from attention | (B, T, d_model) | Concatenated head outputs after out_proj |
| ffn output | (B, T, d_model) | Feed-forward output |
| logits | (B, T, vocab_size) | Raw prediction scores for each position |
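A quick way to reproduce this trace is to run the same operations on freshly initialized layers and print each shape. This sketch uses the lab's sizes with a small batch (B = 2, T = 8, both arbitrary), omits the causal mask and dropout because only shapes matter here, and skips the FFN, which preserves (B, T, d_model).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_model, n_heads, vocab_size, block_size = 2, 8, 128, 4, 65, 128
d_k = d_model // n_heads

idx = torch.randint(0, vocab_size, (B, T))                    # (B, T)
tok = nn.Embedding(vocab_size, d_model)(idx)                  # (B, T, d_model)
pos = nn.Embedding(block_size, d_model)(torch.arange(T))     # (T, d_model)
x = tok + pos                                                 # broadcast -> (B, T, d_model)

qkv = nn.Linear(d_model, 3 * d_model, bias=False)(x)         # (B, T, 3*d_model)
q, k, v = (t.view(B, T, n_heads, d_k).transpose(1, 2)        # each (B, n_heads, T, d_k)
           for t in qkv.split(d_model, dim=2))
scores = (q @ k.transpose(-2, -1)) * d_k ** -0.5             # (B, n_heads, T, T)
attn_weights = torch.softmax(scores, dim=-1)                  # mask omitted: shapes only
out = (attn_weights @ v).transpose(1, 2).contiguous().view(B, T, d_model)
logits = nn.Linear(d_model, vocab_size, bias=False)(out)      # (B, T, vocab_size)

for name, t in [("idx", idx), ("tok", tok), ("pos", pos), ("x", x), ("qkv", qkv),
                ("q", q), ("scores", scores), ("out", out), ("logits", logits)]:
    print(f"{name:7s} {tuple(t.shape)}")
```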
Key Insight: The T x T Attention Matrix

The attention scores have shape (B, n_heads, T, T). This is where the quadratic cost of attention lives. For T = 128, this is 128 × 128 = 16,384 entries per head per example. For T = 4096 (a moderate context window), that grows to 4096 × 4096 ≈ 16.8 million. Section 4.3 covers techniques to reduce this cost.
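To put the quadratic growth in memory terms, this small calculation counts only the score tensor of a single layer, assuming fp32 (4 bytes per entry). The softmax output, dropout masks, and activations saved for the backward pass add more on top of it. score_tensor_size is a hypothetical helper.

```python
def score_tensor_size(B, n_heads, T, bytes_per_elem=4):
    """Hypothetical helper: entries and bytes in one layer's (B, n_heads, T, T) scores."""
    entries = B * n_heads * T * T
    return entries, entries * bytes_per_elem


print(score_tensor_size(1, 1, 128))   # (16384, 65536)
print(score_tensor_size(1, 1, 4096))  # (16777216, 67108864): 64 MiB for one head
print(score_tensor_size(64, 4, 128))  # (4194304, 16777216): 16 MiB at this lab's batch size
```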

6. Running the Lab

6.1 Getting Data

Download a small text file for training. The tiny Shakespeare dataset (about 1.1 MB of text drawn from Shakespeare's plays) is the classic choice:

# Download the tiny Shakespeare dataset
import urllib.request
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
urllib.request.urlretrieve(url, "input.txt")

6.2 Training

# Train with default settings
model, dataset = train(max_steps=5000)

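# Illustrative output (loss around 1.4-1.5 after 5000 steps; exact numbers
# and timings depend on hardware and random seed). The last logged step is
# 4999 because range(max_steps) stops at max_steps - 1:
# step     0 | loss 4.1742 | time 0.0s
# step   500 | loss 1.9831 | time 12.3s
# step  1000 | loss 1.6524 | time 24.7s
# ...
# step  4999 | loss 1.4208 | time 123.5s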

6.3 Evaluating the Output

After training, the model will generate text that resembles the style of the training data. At ~5000 steps with our small model, you should see recognizable words, approximate sentence structure, and character-level patterns that match the training corpus. The text will not be coherent, but it should clearly be "trying" to produce English in the style of the training data.

6.4 Experiments to Try

7. Common Bugs and Debugging

When implementing Transformers from scratch, certain bugs appear repeatedly. Here are the most common ones and how to detect them:

| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss stays flat at ~ln(vocab_size) | Gradients are not flowing; possible shape mismatch or detached computation | Check that no .detach() calls break the computation graph. Verify the loss computation. |
| Loss drops fast, then becomes NaN | Learning rate too high or no gradient clipping | Add gradient clipping (max_norm=1.0). Reduce the learning rate. Check for a missing LayerNorm. |
| Training loss is suspiciously low, but generated text is gibberish | Missing or incorrect causal mask | Verify the mask is lower-triangular and applied before softmax. |
| Generated text is random characters | Insufficient training or broken positional encoding | Train longer. Verify pos_emb is added, not concatenated. |
| All generated tokens are the same | Temperature too low or top_k=1 | Increase the temperature. Use top_k > 1 or remove top_k filtering. |
Debugging Tip: Overfit a Tiny Dataset First

Before training on the full dataset, verify your model can overfit a single batch. Take one batch of data and train for 100 steps. The loss should drop to near zero. If it does not, there is a bug in your model or training loop. This simple sanity check saves hours of debugging.
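Here is a minimal sketch of that sanity check. overfit_one_batch is a hypothetical helper that works with any model returning (logits, loss), as MiniTransformer does. To keep the block standalone, it is demonstrated on TinyBigram, a hypothetical stand-in model, with a batch whose targets follow a learnable rule (y = (x + 1) mod V). The stand-in, the rule, and the learning rate are all assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyBigram(nn.Module):
    """Hypothetical stand-in with the same (logits, loss) interface as MiniTransformer."""

    def __init__(self, vocab_size, d=32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d)
        self.head = nn.Linear(d, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.head(self.emb(idx))                       # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


def overfit_one_batch(model, xb, yb, steps=100, lr=1e-3):
    """Train repeatedly on one fixed batch; return (first_loss, last_loss)."""
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)
    model.train()
    losses = []
    for _ in range(steps):
        _, loss = model(xb, yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses[0], losses[-1]


torch.manual_seed(0)
V = 16
xb = torch.randint(0, V, (4, 32))
yb = (xb + 1) % V                 # a rule the stand-in model can represent exactly
first, last = overfit_one_batch(TinyBigram(V), xb, yb, steps=300, lr=1e-2)
print(f"loss {first:.3f} -> {last:.4f}")
```

With the lab's model you would pass MiniTransformer(config) and one (xb, yb) batch from the DataLoader instead. Setting dropout=0.0 in the config for this check makes it easier for the loss to approach zero.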

Key Takeaways

Check Your Understanding

1. Why do we combine Q, K, V into a single linear projection rather than using three separate layers?

Answer:
A single large matrix multiply (d_model to 3*d_model) is more GPU-efficient than three smaller ones (d_model to d_model each). The GPU better utilizes its parallelism with larger matrices. The result is mathematically identical.

2. What does weight tying do and why is it beneficial?

Answer:
Weight tying shares the token embedding matrix with the output projection matrix. Both map between d_model-dimensional space and vocabulary space. Sharing them reduces parameter count by vocab_size * d_model parameters and provides a useful inductive bias: tokens with similar embeddings will have similar output logits.

3. Why is the final LayerNorm necessary in a Pre-LN Transformer?

Answer:
In Pre-LN ordering, each sub-layer normalizes its input but not its output. The residual connection adds the (unnormalized) sub-layer output back to the stream. After the last block, the residual stream has not been normalized. The final LayerNorm ensures the representations have stable statistics before the output projection to vocabulary logits.

4. What would happen if you removed the causal mask during training?

Answer:
Without the causal mask, each position can attend to future tokens. The training loss would drop quickly because the model can "cheat" by looking ahead. However, at generation time future tokens do not exist yet, so the model would produce poor output. The causal mask ensures training conditions match inference conditions.

5. The attention scores tensor has shape (B, n_heads, T, T). What does each element represent?

Answer:
Element [b, h, i, j] is the (unnormalized) attention score from position i (the query) to position j (the key) in head h of example b. After softmax over the last dimension, it becomes the weight that position j's value contributes to the output at position i. The causal mask sets entries where j > i to negative infinity.