Module 04, Section 4.4

GPU Fundamentals & Systems

Understanding the hardware that makes Transformers possible, from memory hierarchies to Triton kernels.

My GPU utilization hit 100% and for one beautiful moment, memory bandwidth and compute were perfectly balanced. Then I loaded the next batch.

Thermal Throttle Theo, a system-level optimizer

1. Why GPU Architecture Matters

A Transformer's theoretical FLOP count tells only half the story. Two implementations with identical FLOP counts can differ by 10x in wall-clock time depending on how well they utilize the GPU's memory hierarchy. Understanding GPU architecture is not optional for anyone who wants to train or serve LLMs efficiently. This section provides the mental model you need to reason about performance bottlenecks and understand optimizations like FlashAttention.

Big Picture

Modern GPUs are fundamentally throughput machines designed to keep thousands of cores busy simultaneously. The main challenge is not compute; it is feeding data to those cores fast enough. Most Transformer operations are memory-bound, meaning they spend more time moving data than computing with it.

2. GPU Architecture Overview

2.1 Streaming Multiprocessors (SMs)

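A GPU is organized as an array of Streaming Multiprocessors (SMs); an H100 SXM has 132 of them. Each SM contains:

  1. CUDA cores for scalar FP32 and integer arithmetic (128 FP32 units per SM on H100).
  2. Tensor Cores, specialized units for small matrix multiply-accumulate operations that supply most of the GPU's matrix-multiplication throughput (4 per SM on H100).
  3. A register file (256 KB per SM) holding each thread's private values.
  4. Shared memory / L1 cache (up to 228 KB of shared memory per SM on H100), an on-chip scratchpad that threads in the same block use to share data.
  5. Warp schedulers (4 per SM) that issue instructions to warps, groups of 32 threads that execute the same instruction together.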

2.2 Memory Hierarchy

Level | Bandwidth | Capacity | Latency
Registers | ~20 TB/s | 256 KB per SM | ~1 cycle
Shared memory / L1 cache | ~19 TB/s | 228 KB per SM | ~30 cycles
L2 cache | ~5 TB/s | 50 MB (H100) | ~200 cycles
HBM (global memory) | 3.35 TB/s (H100) | 80 GB | ~400 cycles
CPU system RAM | ~50 GB/s | up to 2 TB | ~10,000 cycles
(Top rows are fastest and smallest; bottom rows are slowest and largest.)
Figure 4.6: GPU memory hierarchy. Each level is faster but smaller than the one below. The key optimization challenge is keeping data in the fast upper levels.

The bandwidth gap between shared memory (~19 TB/s) and HBM (3.35 TB/s on H100) is roughly 6x, and the gap between registers and HBM is larger still. This is why memory access patterns dominate GPU performance: a memory-bound operation that reads the same data from HBM three times (as naive attention does with its intermediate matrices) can take up to roughly 3x as long as one that reads it once and keeps it in shared memory (as FlashAttention does).

3. Compute-Bound vs. Memory-Bound Operations

3.1 The Roofline Model

The roofline model characterizes each operation by its arithmetic intensity: the ratio of FLOPs performed to bytes transferred from memory. If an operation performs many FLOPs per byte loaded, it is compute-bound (limited by the GPU's peak arithmetic throughput). If it performs few FLOPs per byte, it is memory-bound (limited by memory bandwidth). The crossover point, called the ridge point, is the peak FLOP rate divided by the memory bandwidth.

Arithmetic Intensity = FLOPs / Bytes transferred
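Operation | FLOPs | Memory traffic | Arithmetic intensity | Bound
Matrix multiply (large, FP16) | 2MNK | 2(MK + KN + MN) bytes | High (grows with the matrix dimensions) | Compute
LayerNorm | ~5 per element | Read + write all elements | ~2.5 per element moved | Memory
Softmax | ~5 per element | Read + write all elements | ~2.5 per element moved | Memory
Dropout | ~1 per element | Read + write all elements | ~0.5 per element moved | Memory
Element-wise add (residual) | 1 per element | Read 2 + write 1 | ~0.33 per element moved | Memory
Attention (QK^T, softmax, PV; unfused) | O(T^2 d) | O(T^2) for the attention matrix | Depends on T and d | Usually memory
(Intensities for the element-wise rows count FLOPs per element read or written; in FP16/BF16, at 2 bytes per element, the FLOPs-per-byte figure is half as large.)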
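The roofline comparison can be made concrete with a few lines of Python. This is an illustrative sketch only: it assumes FP16/BF16 operands (2 bytes per element), counts exactly the traffic listed in the table (no cache effects), and uses the H100 SXM figures quoted in this section (990 TFLOPS, 3.35 TB/s). The function names are hypothetical helpers, not part of any library.

```python
def matmul_intensity(m: int, n: int, k: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for an (m x k) @ (k x n) matmul that reads A and B and writes C once."""
    flops = 2 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


def residual_add_intensity(bytes_per_elem: int = 2) -> float:
    """1 FLOP per element; two inputs read and one output written per element."""
    return 1 / (3 * bytes_per_elem)


def attainable_tflops(intensity: float, peak_tflops: float, bandwidth_tb_s: float) -> float:
    """Roofline: throughput is capped by peak compute or by bandwidth * intensity."""
    # TB/s * FLOPs/byte = TFLOP/s, so the two terms share units
    return min(peak_tflops, bandwidth_tb_s * intensity)


PEAK_TFLOPS, BANDWIDTH_TB_S = 990.0, 3.35   # H100 SXM figures used in this section
ridge_point = PEAK_TFLOPS / BANDWIDTH_TB_S  # FLOPs/byte needed to become compute-bound

for name, ai in [("4096^3 matmul", matmul_intensity(4096, 4096, 4096)),
                 ("residual add", residual_add_intensity())]:
    bound = "compute" if ai > ridge_point else "memory"
    print(f"{name}: {ai:.2f} FLOPs/byte -> {bound}-bound, "
          f"at most {attainable_tflops(ai, PEAK_TFLOPS, BANDWIDTH_TB_S):.2f} TFLOPS")
```

The square matmul sits far above the ridge point (about 1365 FLOPs/byte), so it can approach peak compute, while the residual add is capped by bandwidth at well under 1 TFLOPS, which is why fusing such operations pays off.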
Key Insight: Most Operations Are Memory-Bound

In a typical Transformer forward pass, the large matrix multiplications (QKV projections, FFN layers, output projection) are compute-bound and keep the GPU busy. But everything else (LayerNorm, softmax, dropout, residual adds, attention score computation for moderate sequence lengths) is memory-bound. This is why kernel fusion (combining multiple memory-bound operations into a single kernel) is so effective.

4. The FlashAttention Algorithm

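FlashAttention is one of the most important GPU optimizations for Transformers. It computes exact standard attention (not an approximation) while reducing HBM reads and writes from O(Td + T^2) to O(T^2 d^2 / M), where M is the size of on-chip SRAM. Because d^2 is many times smaller than M for typical head dimensions (64 to 128) and SRAM sizes (around 100 KB), this is roughly a 5 to 10x reduction in memory traffic in practice.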

4.1 The Problem with Naive Attention

The standard attention implementation performs these steps, each reading from and writing to HBM:

  1. Compute S = QK^T / √d, write S to HBM. Size: O(T^2).
  2. Read S from HBM, apply mask, compute P = softmax(S), write P to HBM. Size: O(T^2).
  3. Apply dropout to P, write back to HBM.
  4. Read P from HBM, compute O = PV, write O to HBM.

The T × T matrices S and P are the bottleneck. For T = 4096, each matrix takes 4096 × 4096 × 4 bytes = 64 MiB in FP32 per head per batch element, regardless of the head dimension d. With 32 heads and batch size 4, that is 8 GiB for each intermediate matrix, or 16 GiB for S and P together.
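The sizing arithmetic is easy to check. This sketch uses a hypothetical helper that assumes FP32 storage (4 bytes per element) and sizes one materialized T × T matrix across all heads and batch elements:

```python
def attention_matrix_bytes(seq_len: int, n_heads: int, batch: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to materialize one T x T score matrix for every head and batch element."""
    # The head dimension d does not appear: S and P are T x T regardless of d
    return seq_len * seq_len * n_heads * batch * bytes_per_elem


MiB, GiB = 2**20, 2**30
per_head = attention_matrix_bytes(4096, n_heads=1, batch=1)
total_s = attention_matrix_bytes(4096, n_heads=32, batch=4)
print(per_head / MiB, "MiB per head per batch element")
print(total_s / GiB, "GiB for S alone;", 2 * total_s / GiB, "GiB for S and P together")
```

Because the size grows with T^2, doubling the sequence length to 8192 would quadruple every figure.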

4.2 The Tiling Approach

FlashAttention processes the attention computation in tiles. Instead of computing the full T × T attention matrix at once, it processes blocks of size B_r × B_c (B_r query rows by B_c key columns) that fit in SRAM. The challenge is computing the softmax correctly when only part of each row is visible at a time.

Online Softmax

The critical algorithmic trick is online softmax: computing the softmax incrementally as new blocks arrive. For a row of attention scores being computed in blocks, the algorithm maintains running statistics (the current maximum and the running sum of exponentials) and rescales previous partial results as new maxima are discovered:

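# Online softmax for FlashAttention (runnable NumPy version of the algorithm)
# Processes one query row of the attention matrix, one key/value block at a time
import numpy as np


def online_softmax_attention_row(query, keys, values, block_size):
    d_k = query.shape[-1]
    max_so_far = -np.inf                             # running max of all scores seen so far
    sum_exp = 0.0                                    # running sum of exp(score - max_so_far)
    output_accumulator = np.zeros(values.shape[-1])  # unnormalized output, length d_v

    for start in range(0, keys.shape[0], block_size):   # for each block j of keys/values
        keys_block_j = keys[start:start + block_size]
        values_block_j = values[start:start + block_size]

        # Compute attention scores for this block
        scores_j = query @ keys_block_j.T / np.sqrt(d_k)

        # Update running max
        new_max = max(max_so_far, scores_j.max())

        # Rescale previous statistics with the correction factor (0 on the first block)
        correction = np.exp(max_so_far - new_max)
        exp_scores = np.exp(scores_j - new_max)
        sum_exp = sum_exp * correction + exp_scores.sum()
        output_accumulator = output_accumulator * correction + exp_scores @ values_block_j

        max_so_far = new_max

    # Final normalization
    return output_accumulator / sum_exp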

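This is mathematically equivalent to computing the full softmax (up to floating-point rounding), but it needs only O(B + d_v) working memory per row at any point, where B is the key block size, rather than the full length-T row of scores. The correction factor exp(max_so_far - new_max) ensures that whenever a larger score appears, all previously accumulated exponentials and outputs are rescaled consistently to the new maximum.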

[Figure: Q is split into row blocks (Block 1 to Block 4) and K^T into column blocks (Block 1 to Block 4); each pair produces one tile of the attention matrix, which is never fully materialized in HBM. Processing order, all in SRAM: (1) load a Q block and a K block, (2) compute a tile of scores, (3) update the running softmax, (4) accumulate the output tile, (5) move to the next K block.]
Figure 4.7: FlashAttention tiles the attention computation into blocks that fit in on-chip SRAM. The full T x T attention matrix is never materialized in HBM.
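To see that tiling plus online softmax yields exact attention, here is a small pure-Python sketch. It is a CPU illustration under simplifying assumptions: plain lists stand in for SRAM tiles, there is no masking, dropout, or backward pass, and the loop order follows the figure (one Q block at a time, streaming over K/V blocks, as in FlashAttention-2). The helper names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(Q, K, V):
    """Reference: materialize each full score row, softmax it, then weight V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [dot(q, k) * scale for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / z for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_rows=2, block_cols=2):
    """Exact attention computed tile by tile with online softmax; no full score row is kept."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    d_v = len(V[0])
    out = []
    for r0 in range(0, len(Q), block_rows):            # load one Q block
        q_block = Q[r0:r0 + block_rows]
        m = [-math.inf] * len(q_block)                  # running max per query row
        l = [0.0] * len(q_block)                        # running sum of exponentials
        acc = [[0.0] * d_v for _ in q_block]            # unnormalized output rows
        for c0 in range(0, len(K), block_cols):        # stream K/V blocks past it
            k_block, v_block = K[c0:c0 + block_cols], V[c0:c0 + block_cols]
            for i, q in enumerate(q_block):
                s = [dot(q, k) * scale for k in k_block]  # one row of this score tile
                new_m = max(m[i], max(s))
                corr = math.exp(m[i] - new_m)             # rescale old statistics
                p = [math.exp(x - new_m) for x in s]
                l[i] = l[i] * corr + sum(p)
                acc[i] = [a * corr + sum(pj * v[c] for pj, v in zip(p, v_block))
                          for c, a in enumerate(acc[i])]
                m[i] = new_m
        out.extend([a / l[i] for a in acc[i]] for i in range(len(q_block)))
    return out
```

A quick check with T = 5 (deliberately not a multiple of the block size) should show the two functions agree to rounding error:

```python
T, d = 5, 3
Q = [[math.sin(i + 0.3 * c) for c in range(d)] for i in range(T)]
K = [[math.cos(0.7 * i - c) for c in range(d)] for i in range(T)]
V = [[float(i * d + c) for c in range(d)] for i in range(T)]
print(max(abs(a - b) for ra, rb in zip(naive_attention(Q, K, V), tiled_attention(Q, K, V))
          for a, b in zip(ra, rb)))
```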

5. Introduction to Triton

Writing GPU kernels in CUDA requires managing threads, warps, shared memory, synchronization, and memory coalescing at a low level. Triton (developed at OpenAI) provides a higher-level abstraction: you write kernels in a Python-like language that operates on blocks of data rather than individual threads. Triton handles the complex details of thread mapping, shared memory management, and memory coalescing automatically.

5.1 A Simple Example: Vector Addition

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles BLOCK_SIZE elements
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Mask to handle the case where n_elements is not a multiple of BLOCK_SIZE
    mask = offsets < n_elements

    # Load, compute, store
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


def triton_add(x, y):
    # x and y must be contiguous CUDA tensors with the same shape and dtype
    assert x.is_cuda and y.is_cuda and x.shape == y.shape
    output = torch.empty_like(x)
    n_elements = x.numel()
    BLOCK_SIZE = 1024

    # Launch one program instance per BLOCK_SIZE-element chunk (ceiling division)
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output

Notice: no thread management, no shared memory allocation, no warp-level operations. Triton programs are specified in terms of program instances that each process a block of elements. The compiler handles the mapping to GPU hardware.
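To make the program-instance model concrete without a GPU, the sketch below emulates add_kernel in plain Python: each loop iteration plays the role of one program instance, and the mask decides which lanes load and store. It only illustrates the execution model; the hypothetical emulate_add_kernel runs instances one after another on the CPU, whereas Triton launches them in parallel.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same quantity triton.cdiv computes for the grid size."""
    return -(-a // b)


def emulate_add_kernel(x: list, y: list, block_size: int) -> list:
    n_elements = len(x)
    output = [None] * n_elements
    for pid in range(cdiv(n_elements, block_size)):   # one iteration = one program instance
        block_start = pid * block_size
        offsets = range(block_start, block_start + block_size)
        mask = [off < n_elements for off in offsets]   # the last block may run past the end
        for off, valid in zip(offsets, mask):
            if valid:                                  # masked-off lanes skip load and store
                output[off] = x[off] + y[off]
    return output


print(emulate_add_kernel(list(range(10)), [0.5] * 10, block_size=4))
```

With 10 elements and a block size of 4, three program instances run, and the last one has only two active lanes; without the mask it would read and write past the end of the arrays.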

5.2 Lab: Fused Softmax in Triton

Lab: Fused Softmax Kernel

Implementing softmax as a fused kernel eliminates repeated HBM round-trips. A naive implementation makes three passes over the data, each a separate kernel that reads from and writes to HBM: one to find the row max, one to compute the exponentials and their sum, and one to normalize. A fused kernel loads each row from HBM once and performs all three steps on the on-chip copy, which works as long as a row fits in registers and SRAM.

import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr, input_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused softmax: each row is read from HBM once and written back once."""
    # Each program handles one row
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # Load the row on-chip; padded columns get -inf so they contribute exp(-inf) = 0
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(row_start_ptr + col_offsets, mask=mask, other=float('-inf'))
    row = row.to(tl.float32)  # compute in FP32 even for FP16/BF16 inputs

    # Compute softmax entirely on-chip
    # Step 1: Subtract max for numerical stability
    row_max = tl.max(row, axis=0)
    row = row - row_max

    # Step 2: Exponentiate
    numerator = tl.exp(row)

    # Step 3: Normalize
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # Write back to HBM (tl.store casts to the output tensor's element type)
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    tl.store(output_row_start_ptr + col_offsets, softmax_output, mask=mask)


def fused_softmax(x):
    """Apply softmax to each row of a 2D CUDA tensor x using a fused Triton kernel."""
    # Rows must be contiguous (column stride 1), since the kernel offsets columns by 1
    assert x.ndim == 2 and x.is_cuda and x.stride(1) == 1
    n_rows, n_cols = x.shape
    # tl.arange needs a power-of-2 size, so BLOCK_SIZE is the next power of 2 >= n_cols
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    output = torch.empty_like(x)
    softmax_kernel[(n_rows,)](
        output, x,
        x.stride(0), output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output

This fused kernel reads each row from HBM exactly once, performs the entire softmax on-chip, and writes the result back once. A naive softmax built from separate PyTorch operations (max, subtract, exp, sum, divide) launches one kernel per step, and each kernel writes its intermediate result to HBM for the next one to read back. Against such an unfused version, a fused kernel can be several times faster for typical Transformer shapes. The built-in torch.softmax is already a fused CUDA kernel, so gains over it are much smaller and depend on the shape.
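The size of the gap can be estimated by counting HBM traffic for an M × N input. The sketch below (hypothetical helper names) tallies the elements read and written by each step of an unfused softmax and compares them with the fused kernel's single read and single write; it assumes every row fits on-chip and ignores caching.

```python
def naive_softmax_traffic(m: int, n: int) -> int:
    """Elements read + written by an unfused softmax over an m x n input, step by step."""
    steps = {
        "row max":      (m * n, m),          # read x, write m row maxima
        "subtract max": (m * n + m, m * n),  # read x and the maxima, write z
        "exp":          (m * n, m * n),      # read z, write the numerator
        "row sum":      (m * n, m),          # read the numerator, write m row sums
        "divide":       (m * n + m, m * n),  # read the numerator and sums, write y
    }
    return sum(reads + writes for reads, writes in steps.values())


def fused_softmax_traffic(m: int, n: int) -> int:
    """Fused kernel: read each row once, write each row once."""
    return 2 * m * n


M, N = 4096, 1024
ratio = naive_softmax_traffic(M, N) / fused_softmax_traffic(M, N)
print(f"unfused / fused HBM traffic: {ratio:.3f}x")
```

The unfused total is 8MN + 4M elements versus 2MN, so the ratio approaches 4x for wide rows. For a memory-bound operation that ratio is roughly the ideal speedup; real kernels land somewhat below it.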

6. Key GPU Metrics for LLM Practitioners

GPU | HBM capacity | HBM bandwidth | FP16 Tensor TFLOPS (dense) | TF32 Tensor TFLOPS (dense)
A100 (80GB) | 80 GB | 2.0 TB/s | 312 | 156
H100 SXM | 80 GB | 3.35 TB/s | 990 | 495
H200 | 141 GB | 4.8 TB/s | 990 | 495
B200 | 192 GB | 8.0 TB/s | 2250 | 1125
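Dividing peak compute by bandwidth gives each GPU's ridge point, the arithmetic intensity an operation needs before it stops being memory-bound. A quick sketch using the dense FP16 figures from the table above:

```python
# (peak dense FP16 Tensor Core TFLOPS, HBM bandwidth in TB/s), copied from the table above
GPUS = {
    "A100 (80GB)": (312, 2.0),
    "H100 SXM": (990, 3.35),
    "H200": (990, 4.8),
    "B200": (2250, 8.0),
}

# TFLOPS / (TB/s) = FLOPs per byte
ridge_points = {name: tflops / tb_s for name, (tflops, tb_s) in GPUS.items()}
for name, ridge in ridge_points.items():
    print(f"{name}: ~{ridge:.0f} FLOPs/byte needed to be compute-bound")
```

The H200 has the same compute as the H100 but more bandwidth, so its ridge point is lower (about 206 vs. 295 FLOPs/byte), and more operations escape the memory-bound regime.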
Note: Model Flops Utilization (MFU)

MFU measures what fraction of the GPU's theoretical peak FLOP/s your training run actually spends on the model's required computation. Well-tuned Transformer training typically reaches 40% to 60% MFU, and anything above 60% is excellent. Values below 30% usually indicate a bottleneck: memory-bound operations, communication overhead, or a poorly chosen batch size. The PaLM paper, which introduced the metric, reported 46.2% MFU for training PaLM 540B (57.8% hardware FLOPs utilization when rematerialization is counted).
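A common way to estimate MFU for training uses the approximation that a forward plus backward pass costs about 6 FLOPs per parameter per token (2N forward, 4N backward), ignoring attention's extra sequence-length-dependent FLOPs. The throughput in the example below is hypothetical, chosen only to show the arithmetic, and the helper name is illustrative:

```python
def training_mfu(n_params: float, tokens_per_sec: float, peak_flops: float) -> float:
    """Model FLOPs Utilization using the ~6 * N FLOPs-per-token training approximation."""
    achieved_flops = 6 * n_params * tokens_per_sec  # forward (2N) + backward (4N) per token
    return achieved_flops / peak_flops


# Hypothetical run: a 7B model processing 10,000 tokens/s on one H100 SXM (990 TFLOPS dense FP16)
mfu = training_mfu(7e9, 10_000, 990e12)
print(f"MFU: {mfu:.1%}")
```

The same function applies per GPU or per cluster, as long as tokens per second and peak FLOP/s refer to the same set of devices.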

7. Practical Considerations

7.1 Mixed Precision Training

Modern LLM training uses mixed precision: forward and backward passes run in FP16 or BF16, while the master weights and optimizer states are kept in FP32. BF16 is preferred because it has the same 8-bit exponent, and therefore the same dynamic range, as FP32 (avoiding overflow), at the cost of lower precision: 7 stored mantissa bits vs. FP32's 23. BF16 Tensor Core operations are natively supported on A100 and later GPUs.
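The range-versus-precision trade-off can be observed by rounding Python floats to each format. Python's struct module supports IEEE half precision through the 'e' format code but has no BF16 code, so this sketch emulates BF16 with a hypothetical helper, to_bf16, that keeps the top 16 bits of the FP32 encoding with round-to-nearest-even (NaN handling is omitted):

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE half precision (5 exponent bits, 10 mantissa bits); raises OverflowError if too large."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 (8 exponent bits, 7 mantissa bits) by truncating FP32 with nearest-even rounding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even bias on the dropped 16 bits
    bits = (bits >> 16) << 16             # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


try:
    to_fp16(1e38)
except OverflowError:
    print("1e38 overflows FP16, whose largest finite value is 65504")
print("1e38 in BF16:", to_bf16(1e38), "finite:", math.isfinite(to_bf16(1e38)))
print("1 + 2**-8 in FP16:", to_fp16(1 + 2**-8), "| in BF16:", to_bf16(1 + 2**-8))
```

FP16 represents 1 + 2^-8 exactly but cannot hold 1e38; BF16 holds 1e38 easily but rounds 1 + 2^-8 back to 1.0, because its spacing near 1.0 is 2^-7.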

7.2 Memory Budgeting

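Training a Transformer model requires storing: (1) model parameters, (2) optimizer states (Adam keeps a momentum and a variance value for every parameter, so parameters plus optimizer states take three times the FP32 parameter memory), (3) gradients, and (4) activations (for the backward pass). The dominant cost for large models is often the optimizer states. For a 7B parameter model in mixed precision, before activations:

  1. BF16 parameters: 7B × 2 bytes ≈ 14 GB
  2. BF16 gradients: 7B × 2 bytes ≈ 14 GB
  3. FP32 master weights: 7B × 4 bytes ≈ 28 GB
  4. FP32 Adam momentum and variance: 7B × 8 bytes ≈ 56 GB
  5. Total: ≈ 112 GB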

This is why training a 7B model requires multiple GPUs even when the model parameters alone would fit on a single 80 GB GPU. Techniques like ZeRO (distributed optimizer states), gradient checkpointing (recomputing activations instead of storing them), and offloading help manage this memory pressure.
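The static part of this budget can be tallied per parameter. The sketch below assumes one common accounting for BF16 mixed precision with AdamW, 16 bytes per parameter (the same breakdown used in the ZeRO paper), including the FP32 master weights described in 7.1; some setups drop the master copy, and activations, temporary buffers, and framework overhead are excluded. The helper name is hypothetical.

```python
def training_memory_gb(n_params: int) -> dict:
    """Static training memory in GB (1 GB = 1e9 bytes) for BF16 mixed precision with AdamW."""
    bytes_per_param = {
        "BF16 parameters": 2,
        "BF16 gradients": 2,
        "FP32 master weights": 4,
        "FP32 Adam momentum + variance": 8,
    }
    budget = {name: n_params * b / 1e9 for name, b in bytes_per_param.items()}
    budget["total (before activations)"] = sum(budget.values())
    return budget


for name, gb in training_memory_gb(7_000_000_000).items():
    print(f"{name:32s} {gb:6.1f} GB")
```

ZeRO attacks exactly these rows by sharding the optimizer states, gradients, and eventually the parameters across data-parallel GPUs, while gradient checkpointing targets the activations that this table leaves out.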

Key Takeaways

Check Your Understanding

1. Why is standard attention considered memory-bound rather than compute-bound?

The T x T attention matrix must be written to and read from HBM multiple times (for the QK^T product, masking, softmax, dropout, and the final multiplication with V). Moving these O(T^2) bytes through HBM takes longer than the arithmetic performed on them, so the compute units sit idle waiting for data. FlashAttention addresses this by computing the attention matrix tile by tile in fast on-chip SRAM, so the full matrix is never written to HBM.

2. What is the "online softmax" trick and why is it needed for FlashAttention?

Online softmax computes the softmax incrementally as new blocks of the attention row arrive. It maintains running statistics (current max, running sum of exponentials) and rescales previous partial results when a new maximum is found. This is needed because FlashAttention processes the attention matrix in tiles; it never has the complete row available at once, so it cannot compute the global max and sum in advance.

3. Why does mixed precision training use BF16 rather than FP16?

BF16 has the same number of exponent bits (8) as FP32, which largely prevents overflow and underflow during training. FP16 has only 5 exponent bits, giving a much smaller dynamic range that can cause training instabilities unless loss scaling is used. BF16 trades mantissa precision (7 stored bits vs. FP16's 10) for this exponent range, which is a good trade-off because gradient and activation values need wide range more than high precision.

4. A GPU has 3.35 TB/s HBM bandwidth and 990 TFLOPS of FP16 compute. What is the arithmetic intensity threshold for an operation to be compute-bound?

The threshold (the ridge point) is 990 TFLOPS / 3.35 TB/s ≈ 295 FLOPs per byte. Operations with arithmetic intensity above about 295 FLOPs/byte are compute-bound; those below it are memory-bound. For reference, large matrix multiplies easily exceed 1000 FLOPs/byte (a 4096 × 4096 × 4096 FP16 matmul reaches about 1365), while element-wise operations in FP16 manage only around 1 FLOP/byte or less.

5. Why does training a 7B parameter model require much more than 14 GB of GPU memory?

Beyond the 14 GB for parameters (in BF16), training also requires: gradients (~14 GB in BF16), an FP32 master copy of the weights (~28 GB), AdamW optimizer states (two FP32 values per parameter, ~56 GB), and activations (variable, often many GB depending on batch size and sequence length). The total is roughly 112 GB before activations, which is why multi-GPU training and memory optimization techniques like ZeRO and gradient checkpointing are essential.