Module 04, Section 4.3

Transformer Variants & Efficiency

From BERT to Mamba: the family tree of sequence architectures and the quest for efficient attention.

BERT reads both ways, GPT only looks left, and T5 just converts everything into a text-to-text problem. Family dinners are awkward.

Model Megan, a variant cataloger

1. The Three Architectural Families

The original Transformer is an encoder-decoder model, but since 2017 three distinct families have emerged. Each family uses a different subset of the Transformer's components and targets different types of tasks. Understanding when to use which architecture is a fundamental skill in applied NLP.

Encoder-Only — bidirectional (full) self-attention + feed-forward. Models: BERT, RoBERTa, DeBERTa, ELECTRA. Tasks: classification, NER, retrieval, embeddings.
Decoder-Only — causal (left-to-right) masked self-attention + feed-forward. Models: GPT, LLaMA, Claude, Mistral, Gemma, Qwen. Tasks: text generation, chat, code, reasoning.
Encoder-Decoder — encoder self-attention (full) + decoder self-attention (masked) + cross-attention + feed-forward. Models: T5, BART, UL2, Whisper, mBART. Tasks: translation, summarization, ASR.
Figure 4.5: The three Transformer architectural families. The decoder-only pattern dominates modern LLMs.

1.1 Encoder-Only (BERT Family)

Encoder-only models use bidirectional attention: every token can attend to every other token, including those to its right. This makes them excellent for understanding tasks (classification, token-level labeling, sentence similarity) but unsuitable for generation, since there is no causal structure.

BERT (Devlin et al., 2018) is trained with masked language modeling (MLM): 15% of input tokens are randomly masked, and the model must predict them. This pre-training objective forces the model to build rich bidirectional representations. Key descendants include RoBERTa (larger data, longer training), DeBERTa (disentangled attention with relative position), and ELECTRA (replaced-token detection, which is more sample-efficient than MLM).
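The MLM corruption scheme is easy to sketch in code. Note that BERT does not replace all selected tokens with [MASK]: of the 15% selected for prediction, 80% become [MASK], 10% become a random token, and 10% are left unchanged. The function below is a minimal, simplified sketch of that 80/10/10 recipe (names and the `-100` ignore-index convention are ours, though the latter matches common PyTorch practice), not BERT's actual preprocessing code:

```python
import torch

def mlm_corrupt(input_ids, mask_token_id, vocab_size, mask_prob=0.15):
    """Sketch of BERT-style MLM corruption with the 80/10/10 rule."""
    labels = input_ids.clone()
    # Select ~15% of positions as prediction targets
    selected = torch.rand(input_ids.shape) < mask_prob
    labels[~selected] = -100  # positions ignored by the loss

    corrupted = input_ids.clone()
    # 80% of selected positions -> [MASK]
    to_mask = selected & (torch.rand(input_ids.shape) < 0.8)
    corrupted[to_mask] = mask_token_id
    # Half of the remaining 20% -> a random token
    to_random = selected & ~to_mask & (torch.rand(input_ids.shape) < 0.5)
    corrupted[to_random] = torch.randint(vocab_size, input_ids.shape)[to_random]
    # The last 10% stay unchanged (but are still predicted)
    return corrupted, labels
```

The model receives `corrupted` and is trained to reproduce the original tokens at the selected positions only.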

1.2 Decoder-Only (GPT Family)

Decoder-only models use causal (left-to-right) attention with the auto-regressive language modeling objective. They are the dominant architecture for modern LLMs because: (a) the training objective is simple, scalable, and naturally aligns with text generation; (b) they can be prompted to perform virtually any task (classification, translation, reasoning) through in-context learning; and (c) they are straightforward to scale.

GPT-2 (Radford et al., 2019) demonstrated that a sufficiently large language model can perform many tasks zero-shot, with no task-specific training. GPT-3 (Brown et al., 2020) showed that these abilities improve reliably with scale and that models can learn tasks in-context from a few examples in the prompt. Today, virtually all frontier LLMs (GPT-4, Claude, LLaMA, Gemini, Mistral) are decoder-only.

1.3 Encoder-Decoder (T5 Family)

Encoder-decoder models process an input sequence with bidirectional attention (encoder) and generate an output sequence with causal attention (decoder), using cross-attention to condition the decoder on the encoder's output. This is the original Transformer architecture and remains the best choice for tasks with a clear input/output structure where the input benefits from bidirectional processing: translation, summarization, and speech recognition (Whisper).

T5 (Raffel et al., 2020) reframed all NLP tasks as text-to-text problems, demonstrating that a single encoder-decoder model could handle classification, translation, summarization, and question answering with the same architecture, just different input/output text formats.
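To make the text-to-text framing concrete, here are a few input/output pairs in the style of the T5 paper's task prefixes (the prefixes follow the paper's conventions; the example sentences are made up):

```python
# Every task -- translation, classification, even regression -- becomes
# plain text in, plain text out, distinguished only by a task prefix.
t5_tasks = {
    "translate English to German: That is good.": "Das ist gut.",
    "cola sentence: The car drove fast.": "acceptable",
    "stsb sentence1: The rhino grazed. sentence2: A rhino is grazing.": "4.8",
    "summarize: authorities dispatched emergency crews after the storm.": "crews respond to storm.",
}
```

Even the regression task (STS-B similarity) is handled by generating the score as a text string.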

2. Positional Encoding Variants

2.1 Rotary Position Embedding (RoPE)

RoPE (Su et al., 2021) has become the dominant positional encoding in modern LLMs (used in LLaMA, Mistral, Qwen, Gemma). Instead of adding positional information to the input embeddings, RoPE applies a rotation to the query and key vectors in each attention head. The rotation angle depends on the position and the dimension index.

The key insight: after applying RoPE to queries and keys, their dot product depends only on their relative position, not their absolute positions. This is achieved by rotating pairs of dimensions by pos · θ_i:

RoPE(x, pos) applies a 2D rotation of angle pos · θ_i to each consecutive pair (x_2i, x_2i+1)
import torch

def apply_rope(x, freqs_cos, freqs_sin):
    """Apply Rotary Position Embedding to queries or keys.
    x: (B, n_heads, T, d_k)
    freqs_cos, freqs_sin: (T, d_k//2) precomputed cos/sin of rotation angles
    """
    # Split into pairs and rotate
    x_r = x.float().reshape(*x.shape[:-1], -1, 2)  # (..., d_k//2, 2)
    x0, x1 = x_r[..., 0], x_r[..., 1]

    cos = freqs_cos.unsqueeze(0).unsqueeze(0)  # broadcast over B, n_heads
    sin = freqs_sin.unsqueeze(0).unsqueeze(0)

    # 2D rotation: [cos -sin; sin cos] @ [x0; x1]
    out0 = x0 * cos - x1 * sin
    out1 = x0 * sin + x1 * cos

    out = torch.stack([out0, out1], dim=-1).flatten(-2)
    return out.type_as(x)
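The `freqs_cos` / `freqs_sin` tables can be precomputed once per model. A sketch using the standard frequency schedule θ_i = base^(−2i/d_k) with base 10000 (the convention from the RoPE paper; variable names are ours):

```python
import torch

def precompute_rope_freqs(d_k, max_seq_len, base=10000.0):
    """Precompute the (T, d_k//2) cos/sin tables expected by apply_rope."""
    # One frequency per dimension pair: theta_i = base^(-2i / d_k)
    theta = base ** (-torch.arange(0, d_k, 2).float() / d_k)  # (d_k//2,)
    pos = torch.arange(max_seq_len).float()                   # (T,)
    angles = torch.outer(pos, theta)                          # (T, d_k//2)
    return torch.cos(angles), torch.sin(angles)
```

Position 0 gets zero rotation (cos = 1, sin = 0), so the first token's queries and keys pass through unchanged.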
[Figure: RoPE — rotation encodes position. Each pair of dimensions is rotated by an angle proportional to position (pos=0: no rotation; pos=3: rotated by 3θ). Key property: q(pos=5) · k(pos=3) depends only on the relative distance 2.]
RoPE rotates each pair of embedding dimensions by an angle proportional to the token's position. Because the dot product of two rotated vectors depends only on the angle between them (the relative position), RoPE naturally captures relative position without explicit position IDs.

RoPE advantages: (1) it naturally encodes relative positions, (2) it requires no additional parameters, (3) it can be extended to longer sequences through frequency scaling (NTK-aware scaling, YaRN), and (4) it has strong empirical performance.

2.2 ALiBi (Attention with Linear Biases)

ALiBi (Press et al., 2022) takes a minimalist approach: it adds a linear bias to the attention scores that penalizes distant positions. No positional encoding is added to the embeddings at all. For head h, a bias of -mh · |i - j| is added to the attention score between positions i and j, where mh is a head-specific slope. ALiBi provides strong length extrapolation and is used in BLOOM and some Falcon variants.
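The bias matrix is cheap to construct. A minimal sketch (not the paper's code) using the slope recipe from Press et al. for a power-of-two head count, m_h = 2^(−8h/n_heads):

```python
import torch

def alibi_bias(n_heads, T):
    """Causal ALiBi bias: -m_h * (i - j), added to attention scores."""
    # Head-specific slopes: geometric sequence 2^(-8/n), 2^(-16/n), ...
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])
    pos = torch.arange(T)
    # Distance i - j for j <= i; positions above the diagonal are clamped to 0
    # (they are removed by the causal mask anyway)
    dist = (pos[:, None] - pos[None, :]).clamp(min=0)
    return -slopes[:, None, None] * dist  # (n_heads, T, T)
```

The result is simply added to the raw attention scores before the softmax; no positional embedding is ever added to the token embeddings.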

3. Efficient Attention Mechanisms

Standard attention has O(T²) time and memory complexity, where T is the sequence length. For T = 128K (a common context window in 2024+ models), the naive attention matrix would be 131,072 × 131,072 ≈ 17 billion entries per head. This section surveys the main approaches to making attention tractable at long sequences.

3.1 Sparse Attention

Instead of attending to all positions, sparse attention restricts each token to attend to a carefully chosen subset of positions. The challenge is choosing which positions to attend to while preserving the model's ability to capture long-range dependencies.

Note: Mistral's Sliding Window Attention

Mistral 7B uses a sliding window of 4096 tokens. Because information propagates through residual connections across layers, a model with N layers and window size w can theoretically propagate information across N × w positions. With 32 layers and w=4096, that is 131,072 positions of effective reach.
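The sliding window itself is just a banded causal mask. A minimal sketch of the mask construction (illustrative, not Mistral's actual implementation):

```python
import torch

def sliding_window_mask(T, window):
    """Boolean (T, T) mask: token i may attend to j iff j <= i and i - j < window."""
    i = torch.arange(T)[:, None]
    j = torch.arange(T)[None, :]
    return (j <= i) & (i - j < window)  # True = attend
```

Each row has at most `window` True entries, so the per-layer attention cost is O(T · w) instead of O(T²).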

3.2 Linear Attention

Linear attention replaces the softmax kernel with a decomposable kernel function, allowing the attention computation to be rewritten in O(T) time:

Standard: softmax(QKᵀ) V   [O(T²)]
Linear: φ(Q) (φ(K)ᵀ V)   [O(T)]

The trick is to compute φ(K)TV first (a d × d matrix, independent of T), then multiply by φ(Q). The feature map φ can be the identity (giving a simple outer-product formulation), an exponential, or a random feature approximation. Linear attention has seen renewed interest through architectures like RWKV and RetNet (discussed below).
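A minimal non-causal sketch with the feature map φ(x) = elu(x) + 1 (one common choice, used in Katharopoulos et al.'s linear transformer; other maps work too). The point is the reassociation: φ(K)ᵀV is computed once, independent of T:

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v, eps=1e-6):
    """Non-causal linear attention. q, k, v: (B, H, T, d)."""
    phi_q = F.elu(q) + 1
    phi_k = F.elu(k) + 1
    # (d, d) summary of all keys and values -- cost independent of T
    kv = phi_k.transpose(-2, -1) @ v                       # (B, H, d, d)
    # Normalizer: phi(q_t) . sum_j phi(k_j)
    z = phi_q @ phi_k.sum(dim=-2, keepdim=True).transpose(-2, -1)  # (B, H, T, 1)
    return (phi_q @ kv) / (z + eps)
```

By associativity this equals the explicit quadratic form (φ(Q)φ(K)ᵀ)V with row-normalization, but never materializes the T × T matrix. A causal version instead accumulates `kv` and `z` as running sums over positions.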

3.3 FlashAttention

FlashAttention (Dao et al., 2022) is not an approximation; it computes exact standard attention but with dramatically better hardware utilization. The key insight is that the standard attention implementation is memory-bound: it writes the full T × T attention matrix to GPU global memory (HBM), reads it back for the softmax, writes it again, and reads it for the value multiplication. FlashAttention fuses these operations into a single kernel that keeps the attention matrix in fast on-chip SRAM, never materializing the full T × T matrix in HBM.

The result: 2 to 4x wall-clock speedup and dramatically reduced memory usage (from O(T2) to O(T) HBM). FlashAttention-2 further optimizes the kernel with better work partitioning across GPU thread blocks. FlashAttention-3 (2024) leverages Hopper GPU features (warp specialization, FP8 tensor cores) for additional gains. We cover the algorithm in detail in Section 4.4.

3.4 Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)

During auto-regressive inference, the keys and values from all previous positions are cached (the "KV cache"). With standard multi-head attention and many heads, this cache becomes enormous.
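A back-of-the-envelope calculation shows the scale of the problem (the config below is hypothetical but representative of a 7B-class model: 32 layers, 32 heads of dimension 128, fp16 cache):

```python
def kv_cache_bytes(n_layers, T, n_kv_heads, d_k, bytes_per_el=2):
    """KV cache size: K and V tensors (hence the factor 2) per layer, fp16."""
    return 2 * n_layers * T * n_kv_heads * d_k * bytes_per_el

# Full multi-head attention (32 KV heads) at 32K context: 16 GiB per sequence
full_mha = kv_cache_bytes(n_layers=32, T=32768, n_kv_heads=32, d_k=128)
# The same model with 8 KV heads (GQA): 4 GiB
gqa = kv_cache_bytes(n_layers=32, T=32768, n_kv_heads=8, d_k=128)
```

Reducing the number of distinct KV heads shrinks the cache proportionally, which is exactly what MQA (one KV head) and GQA (a few KV heads shared across groups of query heads) do.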

import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    """GQA: groups of query heads share KV heads."""

    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # how many Q heads per KV head
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, _ = x.shape

        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)

        # Repeat KV heads to match the number of query heads
        # (B, n_kv_heads, T, d_k) -> (B, n_heads, T, d_k)
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        scores = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)

4. Beyond Attention: State Space Models

State Space Models (SSMs) represent a fundamentally different approach to sequence modeling. Instead of computing pairwise attention between all tokens, SSMs process sequences through a linear recurrence with continuous-time dynamics. The key innovation is that this recurrence can be computed in parallel during training — as a long convolution (O(T log T) via the FFT) or with a parallel scan — while costing O(1) time and memory per step during inference.

4.1 The S4 Foundation

S4 (Gu et al., 2022) models the mapping from input u(t) to output y(t) through a linear state-space equation:

h'(t) = Ah(t) + Bu(t) ,    y(t) = Ch(t) + Du(t)

After discretization with step size Δ, this becomes a linear recurrence: h_t = Ā·h_{t−1} + B̄·u_t. The discrete recurrence can also be unrolled as a convolution (over the full sequence), enabling parallel training. S4's breakthrough was showing how to parameterize the matrix A (using the HiPPO initialization) to capture long-range dependencies that RNNs struggle with.
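The discretized recurrence is simple to state in code. Below is a sketch with a diagonal Ā for simplicity (diagonal parameterizations are what S4's successors and Mamba use in practice), together with the equivalent convolutional view y_t = Σ_s K_{t−s} u_s with kernel K_t = C Āᵗ B̄:

```python
import torch

def ssm_recurrence(A_bar, B_bar, C, u):
    """Run h_t = A_bar * h_{t-1} + B_bar * u_t, y_t = C . h_t sequentially.
    A_bar, B_bar, C: (d_state,) with diagonal A_bar; u: (T,) scalar input."""
    h = torch.zeros_like(A_bar)
    ys = []
    for t in range(u.shape[0]):
        h = A_bar * h + B_bar * u[t]   # linear state update
        ys.append((C * h).sum())       # scalar readout
    return torch.stack(ys)
```

During training, the same map is computed in parallel; during generation, only the d_state-dimensional state h is carried forward, which is the source of the O(1) per-step cost.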

4.2 Mamba: Selective State Spaces

Mamba (Gu and Dao, 2023) introduced selective state spaces, where the parameters B, C, and Δ are input-dependent (functions of the current token). This breaks the linear time-invariance that allows convolution-based training, but the authors developed a hardware-aware parallel scan algorithm that achieves efficient training on GPUs.

Mamba's advantages over Transformers:

- O(T) training cost and O(1) time and memory per generated token at inference: there is no KV cache, only a fixed-size recurrent state.
- Strong throughput on very long sequences, where quadratic attention becomes prohibitive.

Mamba's disadvantages:

- The fixed-size state is a lossy summary of the context, so precise retrieval and copying from far back in the sequence (needle-in-a-haystack lookups, exact in-context recall) is harder than with attention.
- The surrounding ecosystem (kernels, tooling, fine-tuning recipes) is younger and less mature than the Transformer stack.

Big Picture: Attention vs. SSMs

Attention gives every token direct access to every other token (complete lookback) but at quadratic cost. SSMs compress the history into a fixed-size state (lossy compression) but at linear cost. The emerging trend in 2024/2025 is hybrid architectures that interleave SSM layers with attention layers (Jamba, Zamba, Samba), getting the best of both worlds: efficient processing for most of the context with selective exact retrieval where needed.

4.3 Mamba-2 and the Connection to Attention

Mamba-2 (Dao and Gu, 2024) revealed a deep connection between SSMs and attention. The structured state space duality (SSD) framework shows that selective SSMs are equivalent to a form of structured masked attention, where the mask has a specific semiseparable structure. This unification opens the door to transferring optimization techniques between the two paradigms.

5. RWKV: Linear Attention as an RNN

RWKV (Peng et al., 2023) combines the training parallelism of Transformers with the inference efficiency of RNNs. It uses a variant of linear attention with time-dependent decay, formulated so that it can be computed either as a parallel attention-like operation (for training) or as a sequential RNN (for inference).

The core idea: replace softmax attention with a weighted sum using exponentially decaying weights:

wkv_t = ( Σ_{i=1..t−1} e^{−(t−1−i)w + k_i} · v_i  +  e^{u + k_t} · v_t ) / ( Σ_{i=1..t−1} e^{−(t−1−i)w + k_i}  +  e^{u + k_t} )

Here w is a learned decay factor (how quickly the model "forgets" past tokens) and u is a bonus for the current token. This can be computed in O(T) during both training and inference. RWKV-6 and later versions add more sophisticated mechanisms (data-dependent linear interpolation, multi-scale decay) while maintaining linear complexity.
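The recurrent (inference-time) form maintains a running numerator and denominator, decaying both by e^{−w} at each step. A numerically naive per-channel sketch (real RWKV implementations track a running maximum for stability; variable names are ours):

```python
import torch

def rwkv_wkv(k, v, w, u):
    """Sequential RWKV-4-style wkv. k, v: (T, d); w, u: (d,), w > 0 is the decay."""
    T, d = k.shape
    a = torch.zeros(d)  # running numerator:   sum_i e^{-(t-1-i)w + k_i} v_i
    b = torch.zeros(d)  # running denominator: sum_i e^{-(t-1-i)w + k_i}
    out = []
    for t in range(T):
        # Current token gets the bonus u instead of the decay
        num = a + torch.exp(u + k[t]) * v[t]
        den = b + torch.exp(u + k[t])
        out.append(num / den)
        # Age the history by one decay step and absorb token t into it
        a = torch.exp(-w) * a + torch.exp(k[t]) * v[t]
        b = torch.exp(-w) * b + torch.exp(k[t])
    return torch.stack(out)
```

At t = 0 the history is empty, so the output is exactly v_0; afterwards older tokens contribute with exponentially decaying weight.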

6. Mixture-of-Experts (MoE)

Mixture-of-Experts is an orthogonal scaling strategy: instead of making every layer wider, you create multiple "expert" sub-networks and route each token to only a few of them. This dramatically increases the total parameter count (and thus model capacity) while keeping the computation per token roughly constant.

6.1 Architecture

In a typical MoE Transformer, the FFN in each block is replaced by a set of E expert FFNs plus a router (gating network). For each token, the router selects the top-k experts (typically k=1 or k=2), and the token is processed only by those selected experts. The output is a weighted combination of the expert outputs.

import torch
import torch.nn as nn

class MoELayer(nn.Module):
    """Mixture-of-Experts feed-forward layer."""

    def __init__(self, d_model, d_ff, n_experts, top_k=2):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k

        # Router: maps d_model -> n_experts (logits for each expert)
        self.router = nn.Linear(d_model, n_experts, bias=False)

        # Expert FFNs
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff, bias=False),
                nn.SiLU(),
                nn.Linear(d_ff, d_model, bias=False),
            )
            for _ in range(n_experts)
        ])

    def forward(self, x):
        B, T, C = x.shape
        x_flat = x.view(-1, C)  # (B*T, C)

        # Compute routing weights
        router_logits = self.router(x_flat)              # (B*T, n_experts)
        weights, indices = torch.topk(router_logits, self.top_k, dim=-1)
        weights = torch.softmax(weights, dim=-1)          # normalize top-k

        # Dispatch tokens to experts and combine results
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            expert_idx = indices[:, k]                    # (B*T,)
            weight = weights[:, k].unsqueeze(-1)          # (B*T, 1)
            for e in range(self.n_experts):
                mask = (expert_idx == e)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[e](expert_input)
                    output[mask] += weight[mask] * expert_output

        return output.view(B, T, C)
Load Balancing

A naive router tends to collapse: it learns to send all tokens to the same few experts, leaving most experts unused. To prevent this, MoE models add an auxiliary load-balancing loss that penalizes imbalanced expert utilization. The balancing loss encourages the router to distribute tokens roughly equally across experts. Mixtral, Switch Transformer, and GShard all use variants of this technique.

6.2 Notable MoE Models

Model                Total Params   Active Params   Experts   Top-k
Switch Transformer   1.6T           ~100B           128       1
Mixtral 8x7B         46.7B          12.9B           8         2
Mixtral 8x22B        176B           39B             8         2
DeepSeek-V2          236B           21B             160       6
Qwen2-MoE            57B            14.3B           64        8 (shared + routed)

The key insight: a 46.7B parameter MoE model (Mixtral 8x7B) can match or exceed a dense 70B model in quality while using only 12.9B parameters of computation per token. This means it runs at roughly the speed of a 13B dense model despite having the knowledge capacity of a much larger one.

7. Gated Attention and Gated Linear Units

The concept of gating (element-wise multiplication of two parallel pathways) has become ubiquitous in modern Transformers, appearing in both the FFN and the attention mechanism.

7.1 Gated FFN Variants

The standard ReLU FFN computes ReLU(xW_1)W_2. Gated variants split the first projection into two branches and multiply them element-wise before the down-projection:

- GLU: (σ(xW_gate) ⊙ xW_up) W_down — the original gated linear unit, with a sigmoid gate
- GeGLU: (GELU(xW_gate) ⊙ xW_up) W_down — used in Gemma
- SwiGLU: (SiLU(xW_gate) ⊙ xW_up) W_down — used in LLaMA, Mistral, and PaLM

Why does gating help? The gate branch learns to selectively amplify or suppress features produced by the value branch. This provides a richer form of nonlinearity than a single activation function and consistently improves performance at a given compute budget.
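A SwiGLU FFN, the variant found in LLaMA-style models, is only a few lines (a minimal sketch; layer names are ours):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Gated FFN: down( SiLU(x W_gate) * (x W_up) )."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```

Because the first projection is split in two, implementations typically shrink d_ff (e.g., to 8/3 · d_model instead of 4 · d_model) to keep the parameter count comparable to the ungated FFN.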

7.2 Gated Attention Units

GAU (Hua et al., 2022) applies the gating principle to attention itself. Instead of the standard residual attention pattern, GAU computes:

output = (U ⊙ AttentionOutput(V)) W_o

where U is a gating signal derived from the input, and V is the value signal. This allows single-head attention to be competitive with multi-head attention, since the gate provides the diversity that multiple heads normally provide. GAU reduces both the number of attention heads needed and the overall parameter count.

8. Multi-Head Latent Attention (MLA)

Multi-Head Latent Attention (DeepSeek-V2, 2024) addresses the KV cache bottleneck through a different lens than GQA. Instead of sharing KV heads across groups, MLA compresses the keys and values into a low-dimensional latent space before caching.

8.1 How MLA Works

In standard attention, we cache the full K and V tensors of shape (T, n_heads, d_k). MLA instead caches a compressed representation c_KV of shape (T, d_c), where d_c << n_heads × d_k. The full K and V are reconstructed from the compressed representation when needed:

c_KV = x W_compress    (cached)
K = c_KV W_UK ,   V = c_KV W_UV    (reconstructed on the fly)

The compression ratio can be 4x to 16x, dramatically reducing the KV cache memory. The decompression matrices WUK and WUV are small and fast to apply. DeepSeek-V2 reports comparable quality to standard multi-head attention while reducing KV cache memory by 93%.
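The caching idea can be sketched in a few lines. This is a deliberately minimal illustration of the compress-then-reconstruct step only — real MLA also compresses the queries and carries a separate decoupled RoPE key path, which we omit here; layer names are ours:

```python
import torch
import torch.nn as nn

class LatentKVCompression(nn.Module):
    """Sketch of MLA's KV path: cache a small latent, up-project on the fly."""

    def __init__(self, d_model, n_heads, d_k, d_c):
        super().__init__()
        self.n_heads, self.d_k = n_heads, d_k
        self.w_compress = nn.Linear(d_model, d_c, bias=False)   # output is cached
        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)   # reconstruction
        self.w_uv = nn.Linear(d_c, n_heads * d_k, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        c_kv = self.w_compress(x)  # (B, T, d_c) -- the only tensor cached
        k = self.w_uk(c_kv).view(B, T, self.n_heads, self.d_k)
        v = self.w_uv(c_kv).view(B, T, self.n_heads, self.d_k)
        return c_kv, k, v
```

With d_c = 16 and 8 heads of dimension 8 as below, the cache holds 16 floats per token instead of 2 × 64 — an 8x reduction in this toy configuration.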

Key Insight: The KV Cache Hierarchy

There is a progression of techniques for reducing KV cache size: standard MHA (full cache) → GQA (share KV across groups, ~4x reduction) → MQA (single KV, ~32x reduction) → MLA (compressed latent, 4x to 16x with less quality loss than MQA). Each trades off differently between cache size, computation, and model quality.
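The hierarchy above is easy to quantify in cached floats per token per layer (the head counts below are hypothetical but typical of a 7B-class model; the MLA latent size is illustrative):

```python
def kv_floats_per_token(n_heads, d_k, n_kv_heads=None, d_latent=None):
    """Cached floats per token per layer under each scheme (sketch)."""
    if d_latent is not None:          # MLA: only the latent vector is cached
        return d_latent
    kvh = n_kv_heads if n_kv_heads is not None else n_heads
    return 2 * kvh * d_k              # K and V for each KV head

mha = kv_floats_per_token(32, 128)                   # full cache
gqa = kv_floats_per_token(32, 128, n_kv_heads=8)     # ~4x smaller
mqa = kv_floats_per_token(32, 128, n_kv_heads=1)     # ~32x smaller
mla = kv_floats_per_token(32, 128, d_latent=512)     # hypothetical d_c
```

The interesting comparison is MQA vs. MLA: both can reach double-digit compression, but MLA's latent preserves per-head diversity after up-projection, which is why it degrades quality less.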

9. Putting It All Together: Modern LLM Recipes

No production LLM uses a single technique in isolation. Here is how several prominent models combine the building blocks discussed in this section:

Model          Architecture        Position Enc.        Attention               FFN                      Special
LLaMA 3        Decoder-only        RoPE                 GQA                     SwiGLU                   Pre-LN (RMSNorm)
Mistral 7B     Decoder-only        RoPE                 GQA + sliding window    SwiGLU                   Pre-LN (RMSNorm)
Mixtral 8x7B   Decoder-only MoE    RoPE                 GQA + sliding window    SwiGLU MoE (8 experts)   Top-2 routing
DeepSeek-V2    Decoder-only MoE    RoPE (YaRN)          MLA                     MoE (160 experts)        Top-6 routing, shared experts
Jamba          Hybrid SSM+Attn     RoPE (Attn layers)   GQA (some layers)       SwiGLU MoE               Mamba + Attention interleaved
Gemma 2        Decoder-only        RoPE                 GQA + local/global      GeGLU                    Alternating local/global attention

Key Takeaways

- Three architectural families: encoder-only (bidirectional, understanding), decoder-only (causal, generation — dominant for modern LLMs), encoder-decoder (structured input-to-output tasks).
- RoPE encodes relative position by rotating query/key dimension pairs; ALiBi adds a linear distance penalty directly to attention scores.
- Long-context efficiency comes from several directions: sparsity (sliding windows), kernelization (linear attention), exact-but-hardware-aware computation (FlashAttention), and KV-cache reduction (MQA → GQA → MLA).
- SSMs (S4, Mamba) and RWKV trade attention's exact lookback for linear-time processing with a fixed-size state; hybrid models interleave both layer types.
- MoE grows model capacity without growing per-token compute by routing each token to a few expert FFNs.

Check Your Understanding

1. Why did decoder-only models become dominant over encoder-decoder models for general-purpose LLMs?

Answer:
Decoder-only models have a simpler training objective (next-token prediction), scale straightforwardly, and can handle any task through prompting (in-context learning). The auto-regressive structure naturally supports generation, and prompting eliminates the need for task-specific architectural modifications.

2. How does FlashAttention achieve speedup without approximating attention?

Answer:
FlashAttention fuses the attention computation into a single GPU kernel that keeps intermediate results (the attention matrix) in fast on-chip SRAM rather than writing them to slow HBM (GPU global memory). By tiling the computation and using online softmax (computing softmax incrementally), it avoids materializing the full T x T attention matrix in HBM.

3. What is the fundamental tradeoff between attention and SSMs?

Answer:
Attention gives each token direct access to every other token (lossless, but O(T^2) cost). SSMs compress the sequence history into a fixed-size state vector (lossy compression, but O(T) cost). Attention excels at tasks requiring precise retrieval from context; SSMs excel at tasks requiring efficient processing of very long sequences.

4. In Mixtral 8x7B, how many parameters are active per token, and why does this matter?

Answer:
Only about 12.9B of the 46.7B total parameters are active per token (top-2 of 8 experts per layer). This matters because inference speed depends on active parameters (FLOPs per token), not total parameters. So Mixtral runs at roughly the speed of a 13B dense model while having the knowledge capacity of a much larger model.

5. How does MLA differ from GQA in reducing KV cache size?

Answer:
GQA reduces the KV cache by sharing key-value heads across groups of query heads (fewer distinct KV vectors to store). MLA takes a different approach: it compresses all key-value information into a low-dimensional latent vector and stores only that compressed representation. The full K and V are reconstructed on the fly from the latent. MLA can achieve higher compression ratios with less quality degradation than aggressive GQA/MQA.