Module 13 · Section 13.3

Supervised Fine-Tuning (SFT)

Running full fine-tuning with Hugging Face Trainer and TRL, selecting hyperparameters, and monitoring training progress with W&B and TensorBoard
★ Big Picture

Supervised fine-tuning (SFT) is the core technique for teaching a pre-trained model to follow instructions and produce specific outputs. In SFT, you train on input/output pairs, and the loss is typically computed only on the output tokens (the assistant's response), not on the input tokens (the user's prompt). This section walks through the complete SFT workflow using Hugging Face's TRL library: loading the model, selecting hyperparameters, configuring gradient accumulation, and monitoring training with Weights & Biases and TensorBoard.

1. The SFT Training Loop

At its core, SFT modifies the standard causal language modeling objective in one important way: the loss is masked so that gradient updates come only from predicting the assistant's response tokens. The model still sees the full conversation during the forward pass (for context), but only the response tokens contribute to the loss. This teaches the model what to generate rather than what to predict about the user's input.

[Figure: SFT Loss Masking: Only Compute Loss on Response Tokens. In the sequence "<|user|> What is photosynthesis? <|assistant|> Photosynthesis is the process by which plants convert light...", every prompt token gets the label -100 (ignored, no gradient from prompt tokens), while the response tokens (Photo, synth, esis, is, the, process, by, which, ...) keep their labels. The loss is computed only on those positions, so the model learns to generate these tokens.]
Figure 13.6: In SFT, the loss is computed only on response tokens. Prompt tokens are masked with label -100.
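A minimal sketch of the masking idea in plain Python, using made-up token IDs. The helpers build_sft_labels and masked_mean_loss are hypothetical names for illustration. In a real batch, the labels list holds the same IDs as input_ids with the prompt positions set to -100; Hugging Face causal LM models also apply a one-position shift internally, which this toy version ignores.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss by default


def build_sft_labels(prompt_ids: list[int], response_ids: list[int]) -> tuple[list[int], list[int]]:
    """Concatenate prompt and response IDs; mask the prompt positions in the labels."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


def masked_mean_loss(token_losses: list[float], labels: list[int]) -> float:
    """Average per-token losses over positions whose label is not IGNORE_INDEX."""
    kept = [loss for loss, label in zip(token_losses, labels) if label != IGNORE_INDEX]
    if not kept:
        raise ValueError("every label is masked, so the loss is undefined")
    return sum(kept) / len(kept)


# Hypothetical token IDs: 4 prompt tokens followed by 3 response tokens
input_ids, labels = build_sft_labels([1, 15, 27, 2], [301, 302, 303])
print(labels)  # [-100, -100, -100, -100, 301, 302, 303]

# High losses on the prompt positions do not affect the result
print(masked_mean_loss([9.0, 9.0, 9.0, 9.0, 1.0, 2.0, 3.0], labels))  # 2.0
```

The model still sees all seven tokens in the forward pass; only the averaging step skips the masked positions.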

1.1 Complete SFT Script with TRL

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig

# 1. Load model and tokenizer
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure pad token is set (required for batched training)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # Faster training
)

# 2. Load and prepare dataset (conversational "messages" format, one JSON object per line)
dataset = load_dataset("json", data_files={
    "train": "data/train.jsonl",
    "validation": "data/val.jsonl"
})

# 3. Configure SFT training
sft_config = SFTConfig(
    output_dir="./checkpoints/llama-sft",

    # Core hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,    # Effective batch = 4 * 8 = 32
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,                 # 10% of steps for warmup
    lr_scheduler_type="cosine",

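    # Sequence configuration
    max_length=2048,                  # called max_seq_length in older TRL releases
    packing=True,                     # Enable sequence packing

    # Loss masking: for "messages" datasets, TRL computes the loss on all tokens
    # by default. Recent TRL releases add assistant_only_loss=True to restrict it
    # to assistant turns (the chat template must mark those turns); check the
    # SFTConfig docs for your installed version.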

    # Precision and optimization
    bf16=True,                        # Use bfloat16 mixed precision
    gradient_checkpointing=True,      # Trade compute for memory
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # Logging and evaluation
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,               # Keep at most 3 checkpoints (the best one is retained)
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",

    # Monitoring
    report_to="wandb",                # or "tensorboard"
    run_name="llama-8b-sft-v1",

    # Reproducibility
    seed=42,
    data_seed=42,
)

# 4. Create trainer and start training
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    processing_class=tokenizer,
)

# 5. Train
trainer.train()

# 6. Save the final model
trainer.save_model("./models/llama-sft-final")
tokenizer.save_pretrained("./models/llama-sft-final")
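The script assumes each line of data/train.jsonl and data/val.jsonl is a JSON object with a "messages" list of role/content turns. A minimal sketch of one such record, plus a hypothetical validate_record helper for catching malformed lines before training (the example content is illustrative):

```python
import json

# One training example in the conversational "messages" format (illustrative content)
record = {
    "messages": [
        {"role": "system", "content": "You are a concise biology tutor."},
        {"role": "user", "content": "What is photosynthesis?"},
        {"role": "assistant", "content": "Photosynthesis is the process by which "
                                         "plants convert light energy into chemical energy."},
    ]
}

VALID_ROLES = {"system", "user", "assistant"}


def validate_record(rec: dict) -> list[str]:
    """Return the problems found in one JSONL record (an empty list means OK)."""
    messages = rec.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["missing or empty 'messages' list"]
    problems = []
    for i, msg in enumerate(messages):
        if msg.get("role") not in VALID_ROLES:
            problems.append(f"message {i}: unexpected role {msg.get('role')!r}")
        if not isinstance(msg.get("content"), str) or not msg["content"].strip():
            problems.append(f"message {i}: empty or non-string content")
    if messages[-1].get("role") != "assistant":
        problems.append("last message should be the assistant response to train on")
    return problems


line = json.dumps(record)                 # what one line of data/train.jsonl would hold
print(validate_record(json.loads(line)))  # []
```

Running such a check over every line before calling trainer.train() is cheap compared with discovering a formatting error mid-run.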
📝 Note

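Flash Attention 2 can substantially speed up training and reduce memory use, especially at long sequence lengths, because it computes attention without materializing the full attention matrix. Setting attn_implementation="flash_attention_2" is commonly reported to reduce memory usage by 2x to 4x and speed up training by 1.5x to 2x compared to standard attention, though the gains depend on sequence length and hardware. Install it with pip install flash-attn --no-build-isolation. It requires an NVIDIA GPU with Ampere architecture (A100, RTX 3090) or newer, and the model must be loaded in fp16 or bf16.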
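If flash-attn is not installed or the GPU predates Ampere, a reasonable fallback is PyTorch's scaled dot-product attention, selected with attn_implementation="sdpa". The helper below is a hypothetical sketch of that decision. It expects the GPU's compute capability as a (major, minor) tuple, such as the value returned by torch.cuda.get_device_capability(); Ampere GPUs report (8, 0) or (8, 6).

```python
import importlib.util
from typing import Optional, Tuple


def choose_attn_implementation(compute_capability: Optional[Tuple[int, int]] = None) -> str:
    """Heuristic choice of attn_implementation for from_pretrained.

    compute_capability is the GPU's (major, minor) version; None means
    unknown or no GPU.
    """
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    ampere_or_newer = compute_capability is not None and tuple(compute_capability) >= (8, 0)
    if has_flash_attn and ampere_or_newer:
        return "flash_attention_2"
    # PyTorch scaled dot-product attention: built in, no extra package needed
    return "sdpa"


print(choose_attn_implementation((7, 5)))  # sdpa (Turing GPUs such as the T4)
```

The returned string can be passed straight to AutoModelForCausalLM.from_pretrained(..., attn_implementation=...).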

2. Hyperparameter Selection

Choosing the right hyperparameters is critical for fine-tuning success. In pre-training, billions to trillions of tokens give the model room to recover from suboptimal choices; fine-tuning datasets are small and training runs are short, so the outcome is far more sensitive to hyperparameter values. The following table gives common starting points for several model sizes.

2.1 Recommended Starting Points

| Hyperparameter | 1B to 3B models | 7B to 13B models | 30B to 70B models |
|---|---|---|---|
| Learning rate | 5e-5 to 1e-4 | 1e-5 to 5e-5 | 5e-6 to 2e-5 |
| Effective batch size | 16 to 32 | 32 to 64 | 64 to 128 |
| Epochs | 3 to 5 | 2 to 3 | 1 to 2 |
| Warmup ratio | 0.05 to 0.1 | 0.03 to 0.1 | 0.03 to 0.05 |
| Weight decay | 0.01 to 0.1 | 0.01 | 0.01 |
| Max grad norm | 1.0 | 1.0 | 1.0 |
| LR scheduler | cosine | cosine | cosine |
🔑 Key Insight

Learning rate is the most important hyperparameter. If you only tune one thing, tune the learning rate. Too high and you get catastrophic forgetting; too low and the model barely changes. A good rule of thumb: start with 2e-5 for 7B+ models and 5e-5 for smaller models. If validation loss increases during training, reduce the learning rate by 2x to 3x. If the model is not learning (loss plateau after warmup), increase it by 2x.
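The table and the rule of thumb above can be turned into a small lookup for choosing a first configuration. This is a hypothetical sketch: suggest_starting_hparams and adjust_learning_rate are illustrative names, the size thresholds that bucket in-between models (such as 5B or 20B) into a row are an assumption, and the 2.5x reduction is simply the midpoint of the suggested 2x to 3x range. The default_lr entry follows the "2e-5 for 7B+, 5e-5 for smaller models" rule.

```python
def suggest_starting_hparams(num_params_billion: float) -> dict:
    """Look up the table's starting ranges by model size (hypothetical helper).

    Sizes between the table's rows are bucketed by the thresholds below,
    which is an assumption rather than part of the table.
    """
    if num_params_billion < 7:  # "1B to 3B" row
        return {"learning_rate": (5e-5, 1e-4), "effective_batch_size": (16, 32),
                "epochs": (3, 5), "warmup_ratio": (0.05, 0.1), "default_lr": 5e-5}
    if num_params_billion < 30:  # "7B to 13B" row
        return {"learning_rate": (1e-5, 5e-5), "effective_batch_size": (32, 64),
                "epochs": (2, 3), "warmup_ratio": (0.03, 0.1), "default_lr": 2e-5}
    # "30B to 70B" row
    return {"learning_rate": (5e-6, 2e-5), "effective_batch_size": (64, 128),
            "epochs": (1, 2), "warmup_ratio": (0.03, 0.05), "default_lr": 2e-5}


def adjust_learning_rate(lr: float, val_loss_rising: bool, loss_plateaued: bool) -> float:
    """Rule of thumb: cut the LR if validation loss rises, double it on a plateau."""
    if val_loss_rising:
        return lr / 2.5  # midpoint of the suggested 2x to 3x reduction (an assumption)
    if loss_plateaued:
        return lr * 2
    return lr


hparams = suggest_starting_hparams(8)
print(hparams["default_lr"], hparams["effective_batch_size"])  # 2e-05 (32, 64)
```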

2.2 Gradient Accumulation

Gradient accumulation lets you simulate large batch sizes on limited GPU memory. Instead of processing the entire batch at once, you process smaller micro-batches and accumulate the gradients before performing a single optimizer step. The effective batch size is the product of per-device batch size, gradient accumulation steps, and the number of GPUs.
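Why accumulation reproduces a large batch can be checked with a toy one-parameter model in plain Python. Scaling each micro-batch gradient by 1/accum_steps (the same role as dividing the loss by gradient_accumulation_steps before backward()) and summing gives the full-batch mean gradient. The data and model are hypothetical. The match is exact here because every micro-batch has the same size and the loss is a per-example mean; when micro-batches contain different numbers of tokens, as in LLM training, the accumulated result can differ slightly from a true large batch.

```python
def mse_grad(w: float, batch: list[tuple[float, float]]) -> float:
    """Gradient of the mean squared error for the one-parameter model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 7.0),
        (5.0, 10.5), (6.0, 11.0), (7.0, 14.5), (8.0, 15.0)]
w = 0.5
accum_steps = 4
micro_batch_size = len(data) // accum_steps  # 2 examples per micro-batch

# Gradient for one optimizer step over the full batch of 8 examples
full_grad = mse_grad(w, data)

# Accumulated gradient: each micro-batch contributes grad / accum_steps, like
# dividing the loss by gradient_accumulation_steps before calling backward()
accum_grad = 0.0
for i in range(accum_steps):
    micro_batch = data[i * micro_batch_size:(i + 1) * micro_batch_size]
    accum_grad += mse_grad(w, micro_batch) / accum_steps

print(f"full batch: {full_grad:.6f}  accumulated: {accum_grad:.6f}")
```

Only one micro-batch's activations need to be in memory at a time, which is where the memory saving comes from.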

# Calculating effective batch size
import math
from typing import Optional

def compute_effective_batch_size(
    per_device_batch_size: int,
    gradient_accumulation_steps: int,
    num_gpus: int = 1,
    num_examples: Optional[int] = None,
) -> dict:
    """Calculate the effective batch size and, if num_examples is given, optimizer steps per epoch."""
    effective_batch = per_device_batch_size * gradient_accumulation_steps * num_gpus

    result = {
        "per_device_batch_size": per_device_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "num_gpus": num_gpus,
        "effective_batch_size": effective_batch,
    }
    if num_examples is not None:
        # Approximate: ignores packing and the Trainer's exact rounding
        result["optimizer_steps_per_epoch"] = math.ceil(num_examples / effective_batch)
    return result

# Common configurations
configs = [
    (2, 16, 1),   # Single GPU, small memory
    (4, 8, 1),    # Single GPU, moderate memory
    (4, 4, 4),    # 4 GPUs, distributed
    (8, 2, 8),    # 8 GPUs, large cluster
]

print(f"{'Per-device':>12} {'Grad Accum':>12} {'GPUs':>6} {'Effective BS':>14}")
print("-" * 50)
for pd_bs, ga, gpus in configs:
    result = compute_effective_batch_size(pd_bs, ga, gpus)
    print(f"{pd_bs:>12} {ga:>12} {gpus:>6} {result['effective_batch_size']:>14}")
  Per-device   Grad Accum   GPUs   Effective BS
--------------------------------------------------
           2           16      1             32
           4            8      1             32
           4            4      4             64
           8            2      8            128

3. Learning Rate Schedulers

The learning rate scheduler controls how the learning rate changes over the course of training. For fine-tuning, cosine annealing with warmup is one of the most commonly used schedulers and a strong default. The warmup phase gradually increases the learning rate from zero to prevent early instability, while the cosine decay smoothly reduces it to help the model converge.
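To see how warmup_ratio turns into concrete step counts, here is a rough planning sketch. The plan_schedule helper is hypothetical. It rounds the warmup fraction of total optimizer steps up to a whole step, mirroring how the Hugging Face Trainer converts warmup_ratio; the steps-per-epoch figure is approximate because it ignores packing (which changes the number of training sequences) and the Trainer's exact dataloader rounding.

```python
import math


def plan_schedule(num_examples: int, effective_batch_size: int,
                  num_epochs: int, warmup_ratio: float) -> dict:
    """Estimate optimizer steps and warmup steps for a run (approximate)."""
    steps_per_epoch = math.ceil(num_examples / effective_batch_size)
    total_steps = steps_per_epoch * num_epochs
    # warmup_ratio becomes a step count by rounding up a fraction of the total
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {"steps_per_epoch": steps_per_epoch, "total_steps": total_steps,
            "warmup_steps": warmup_steps}


print(plan_schedule(10_000, 32, 3, 0.1))
```

For 10,000 examples at an effective batch size of 32 for 3 epochs with warmup_ratio=0.1, this prints {'steps_per_epoch': 313, 'total_steps': 939, 'warmup_steps': 94}.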

[Figure: Learning Rate Schedulers for Fine-Tuning. Learning rate (0 to 2e-5) vs. training steps for cosine (recommended), linear decay, and constant schedules, each preceded by a warmup phase.]
Figure 13.7: Cosine annealing with warmup is a common choice for SFT (the Hugging Face Trainer's own default, lr_scheduler_type="linear", is linear decay). Warmup prevents early instability.
# Visualizing different schedulers
from transformers import get_scheduler
import torch

def visualize_schedulers(
    total_steps: int = 1000,
    warmup_steps: int = 100,
    learning_rate: float = 2e-5
):
    """Compare learning rate schedules side by side."""
    schedules = {}

    for sched_type in ["cosine", "linear", "constant_with_warmup"]:
        # Create a dummy optimizer
        param = torch.nn.Parameter(torch.zeros(1))
        optimizer = torch.optim.AdamW([param], lr=learning_rate)

        scheduler = get_scheduler(
            name=sched_type,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        lrs = []
        for step in range(total_steps):
            lrs.append(optimizer.param_groups[0]["lr"])
            optimizer.step()
            scheduler.step()

        schedules[sched_type] = lrs

    # Print sample points
    checkpoints = [0, 50, 100, 250, 500, 750, 999]
    print(f"{'Step':>6} {'Cosine':>12} {'Linear':>12} {'Constant':>12}")
    print("-" * 44)
    for step in checkpoints:
        print(f"{step:>6} {schedules['cosine'][step]:>12.2e} "
              f"{schedules['linear'][step]:>12.2e} "
              f"{schedules['constant_with_warmup'][step]:>12.2e}")

    return schedules

schedules = visualize_schedulers()
  Step       Cosine       Linear     Constant
--------------------------------------------
     0     0.00e+00     0.00e+00     0.00e+00
    50     1.00e-05     1.00e-05     1.00e-05
   100     2.00e-05     2.00e-05     2.00e-05
   250     1.87e-05     1.67e-05     2.00e-05
   500     1.17e-05     1.11e-05     2.00e-05
   750     3.57e-06     5.56e-06     2.00e-05
   999     6.09e-11     2.22e-08     2.00e-05
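The cosine curve can also be computed by hand, which is useful for checking what a scheduler should report. A minimal sketch of linear warmup followed by a half-cosine decay; it follows the formula behind transformers' get_cosine_schedule_with_warmup with its default num_cycles=0.5, but cosine_with_warmup_lr itself is a hypothetical stand-alone helper.

```python
import math


def cosine_with_warmup_lr(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear warmup from 0, then half-cosine decay toward 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in [0, 50, 100, 250, 500, 750, 999]:
    print(f"{step:>6} {cosine_with_warmup_lr(step, 2e-5, 100, 1000):>12.2e}")
```

With a peak of 2e-5, 100 warmup steps, and 1,000 total steps, this gives 1.87e-05 at step 250, 1.17e-05 at step 500, 3.57e-06 at step 750, and about 6.09e-11 at step 999.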

4. Monitoring Training

Monitoring is not optional for fine-tuning. Without proper monitoring, you cannot detect overfitting, catastrophic forgetting, or training instability until after a costly training run completes. The two most widely used tools are Weights & Biases (W&B) and TensorBoard.

4.1 Key Metrics to Track

| Metric | What it tells you | Warning signs |
|---|---|---|
| Training loss | How well the model fits the training data | Loss spikes, NaN values, loss not decreasing |
| Validation loss | Generalization ability | Diverges from training loss (overfitting) |
| Learning rate | Current LR value from the scheduler | Curve does not match the expected schedule shape |
| Gradient norm | Magnitude of gradients | Exploding (>10) or vanishing (<1e-6) |
| GPU memory | Memory utilization | OOM errors, memory growing over time (leaks) |
| Tokens per second | Training throughput | Sudden drops indicate bottlenecks |
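Some of these warning signs can be checked automatically. A minimal sketch, assuming log dicts shaped like the ones the Trainer passes to a callback's on_log (recent transformers versions include a "grad_norm" key alongside "loss"); check_log_entry is a hypothetical helper, and the thresholds are the rough ones from the table.

```python
import math


def check_log_entry(logs: dict, grad_norm_high: float = 10.0, grad_norm_low: float = 1e-6) -> list[str]:
    """Return warning messages for one Trainer log dict (thresholds from the table)."""
    warnings = []
    loss = logs.get("loss")
    if loss is not None and not math.isfinite(loss):
        warnings.append(f"non-finite training loss: {loss}")
    grad_norm = logs.get("grad_norm")
    if grad_norm is not None:
        if not math.isfinite(grad_norm) or grad_norm > grad_norm_high:
            warnings.append(f"possible exploding gradients: grad_norm={grad_norm:.3g}")
        elif grad_norm < grad_norm_low:
            warnings.append(f"possible vanishing gradients: grad_norm={grad_norm:.3g}")
    return warnings


print(check_log_entry({"loss": 1.21, "grad_norm": 0.85}))  # []
print(check_log_entry({"loss": float("nan"), "grad_norm": 42.0}))
```

One option is to call it from a TrainerCallback's on_log method and print or log whatever it returns.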

4.2 W&B Integration

import wandb
from transformers import TrainerCallback

# Initialize W&B project
wandb.init(
    project="llm-fine-tuning",
    name="llama-8b-sft-medical-v1",
    config={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "dataset": "medical-qa-10k",
        "learning_rate": 2e-5,
        "epochs": 3,
        "effective_batch_size": 32,
    }
)

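class DetailedLoggingCallback(TrainerCallback):
    """Custom callback for detailed training metrics."""

    def __init__(self, gap_threshold: float = 0.1):
        self.gap_threshold = gap_threshold  # heuristic; tune for your task
        self.last_train_loss = None

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return

        # The Trainer emits training logs ("loss") and evaluation logs
        # ("eval_loss") in separate calls, so remember the latest training loss
        if "loss" in logs:
            self.last_train_loss = logs["loss"]

        # Log additional computed metrics when an evaluation result arrives
        if "eval_loss" in logs and self.last_train_loss is not None:
            gap = logs["eval_loss"] - self.last_train_loss
            wandb.log({
                "train_eval_gap": gap,
                "overfitting_signal": gap > self.gap_threshold,
            }, step=state.global_step)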

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is None:
            return

        # Log evaluation metrics with clear names
        eval_metrics = {
            f"eval/{k.replace('eval_', '')}": v
            for k, v in metrics.items()
        }
        wandb.log(eval_metrics, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        # Log final training summary
        wandb.log({
            "final/total_steps": state.global_step,
            "final/best_eval_loss": state.best_metric,
            "final/epochs_completed": state.epoch,
        })

# Add callback to trainer
# trainer = SFTTrainer(..., callbacks=[DetailedLoggingCallback()])

4.3 TensorBoard Alternative

# TensorBoard setup (alternative to W&B)
# In SFTConfig, set: report_to="tensorboard"

# Launch TensorBoard in a separate terminal:
# tensorboard --logdir ./checkpoints/llama-sft/runs

# Custom TensorBoard logging
from torch.utils.tensorboard import SummaryWriter
from transformers import TrainerCallback

class TensorBoardCallback(TrainerCallback):
    """Custom TensorBoard logging of training and evaluation scalars."""

    def __init__(self, log_dir="./tb_logs"):
        self.writer = SummaryWriter(log_dir)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        for key, value in logs.items():
            if isinstance(value, (int, float)):
                self.writer.add_scalar(f"training/{key}", value, state.global_step)

    def on_evaluate(self, args, state, control, metrics=None, model=None, **kwargs):
        if metrics:
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"eval/{key}", value, state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        self.writer.close()
⚠ Warning

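Watch the train/eval loss gap. If training loss keeps decreasing but validation loss starts increasing, you are overfitting. Stop training and use the best checkpoint (the one with the lowest validation loss). With load_best_model_at_end=True and metric_for_best_model="eval_loss" in your config, the trainer reloads the checkpoint with the lowest validation loss when training ends. Only saved checkpoints are candidates, which is why save_steps matches eval_steps in the config above. This setting does not stop the run early by itself; add an early-stopping callback if you want training to halt once validation loss stops improving.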
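For the built-in route, transformers provides EarlyStoppingCallback (for example callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]), which requires load_best_model_at_end=True and a metric_for_best_model. To inspect a finished or interrupted run instead, a plain-Python sketch over trainer.state.log_history can locate the best evaluation and detect a stall. The helpers best_eval_step and eval_loss_stalled are hypothetical, and the history values are illustrative.

```python
from typing import Optional, Tuple


def best_eval_step(log_history: list[dict]) -> Optional[Tuple[int, float]]:
    """Return (step, eval_loss) with the lowest eval_loss, or None if there are no evaluations."""
    evals = [(entry["step"], entry["eval_loss"]) for entry in log_history if "eval_loss" in entry]
    return min(evals, key=lambda pair: pair[1]) if evals else None


def eval_loss_stalled(log_history: list[dict], patience: int = 2) -> bool:
    """True if the last `patience` evaluations all failed to beat the earlier best eval_loss."""
    losses = [entry["eval_loss"] for entry in log_history if "eval_loss" in entry]
    if len(losses) <= patience:
        return False
    best_before = min(losses[:-patience])
    return all(loss >= best_before for loss in losses[-patience:])


# Shaped like trainer.state.log_history (values are illustrative)
history = [
    {"loss": 1.10, "step": 100}, {"eval_loss": 1.30, "step": 100},
    {"loss": 0.90, "step": 200}, {"eval_loss": 1.15, "step": 200},
    {"loss": 0.70, "step": 300}, {"eval_loss": 1.22, "step": 300},
    {"loss": 0.55, "step": 400}, {"eval_loss": 1.31, "step": 400},
]
print(best_eval_step(history))     # (200, 1.15)
print(eval_loss_stalled(history))  # True
```

Here training loss keeps falling while eval loss bottoms out at step 200, which is the overfitting pattern described in the warning.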

5. Common Issues and Debugging

5.1 Troubleshooting Guide

[Figure: SFT Troubleshooting Decision Tree]
Loss not decreasing? If the learning rate is too low, increase it 2x to 5x. If it is a data issue, check the labels and format.
Loss spikes or NaN? If the learning rate is too high, reduce it 3x to 10x, and reduce max_grad_norm to 0.3.
Eval loss diverging? This is overfitting: use fewer epochs or more data, and add weight_decay=0.1 and dropout=0.1.
Figure 13.8: Common SFT issues and their solutions
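The decision tree can be approximated as a function over logged losses. A rough sketch with hypothetical helper names (diagnose, has_spike); treating "spike" as any logged loss more than 2x the previous one is an assumption, not part of the figure.

```python
import math


def has_spike(losses: list[float], factor: float = 2.0) -> bool:
    """True if any loss is more than `factor` times the previous one (heuristic)."""
    return any(cur > factor * prev for prev, cur in zip(losses, losses[1:]))


def diagnose(train_losses: list[float], eval_losses: list[float]) -> list[str]:
    """Map loss curves onto the decision tree's suggested fixes (rough heuristics)."""
    advice = []
    if any(not math.isfinite(x) for x in train_losses) or has_spike(train_losses):
        advice.append("Loss spikes or NaN: reduce the LR by 3x to 10x; try max_grad_norm=0.3")
    elif len(train_losses) >= 2 and train_losses[-1] >= train_losses[0]:
        advice.append("Loss not decreasing: check labels/format, then try a 2x to 5x higher LR")
    train_improving = len(train_losses) >= 2 and train_losses[-1] < train_losses[0]
    if train_improving and len(eval_losses) >= 2 and eval_losses[-1] > min(eval_losses):
        advice.append("Eval loss diverging: train fewer epochs or add data; raise weight_decay")
    return advice


print(diagnose([2.0, 1.5, 1.2], [1.4, 1.3]))  # []
print(diagnose([2.0, 1.0, 0.5], [1.2, 1.1, 1.3]))
```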
🔑 Key Insight

Always run a sanity check before full training. Train for 10 to 20 steps on a tiny subset (50 to 100 examples) and verify that: (1) the loss decreases, (2) no OOM errors occur, (3) the tokenized examples look correct when decoded, and (4) the generated outputs are not garbled. This 5-minute check can save hours of wasted compute.

# Sanity check: verify training is working correctly
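def run_sanity_check(trainer, tokenizer, dataset, num_samples=3):
    """Quick sanity check before committing to a full training run.

    This overrides trainer.args.max_steps and logging_steps and updates the
    model weights, so build a throwaway trainer on a small subset (for example
    dataset.select(range(100))) instead of reusing the trainer for the full run.
    """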

    print("=" * 60)
    print("SANITY CHECK")
    print("=" * 60)

    # 1. Check a few tokenized examples
    print("\n1. Sample tokenized examples:")
    for i in range(min(num_samples, len(dataset))):
        example = dataset[i]
        messages = example["messages"]
        text = tokenizer.apply_chat_template(messages, tokenize=False)
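        # The chat template already inserts special tokens (e.g. BOS); don't add them twice
        tokens = tokenizer(text, add_special_tokens=False)["input_ids"]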
        print(f"   Example {i}: {len(tokens)} tokens")
        # Decode and check it looks reasonable
        decoded = tokenizer.decode(tokens[:50])
        print(f"   First 50 tokens: {decoded[:100]}...")

    # 2. Run a few training steps
    print("\n2. Running 10 training steps...")
    trainer.args.max_steps = 10
    trainer.args.logging_steps = 1
    result = trainer.train()

    # 3. Check loss trajectory
    logs = trainer.state.log_history
    losses = [l["loss"] for l in logs if "loss" in l]
    print(f"   Loss trajectory: {[f'{l:.4f}' for l in losses]}")

    if len(losses) >= 2:
        if losses[-1] < losses[0]:
            print("   [PASS] Loss is decreasing")
        else:
            print("   [WARN] Loss is not decreasing; check learning rate")

    # 4. Generate a sample response
    print("\n3. Sample generation:")
    model = trainer.model
    model.eval()
    test_messages = [{"role": "user", "content": "Hello, how are you?"}]
    inputs = tokenizer.apply_chat_template(
        test_messages, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)

    with torch.no_grad():
        output = model.generate(inputs, max_new_tokens=50, do_sample=True, temperature=0.7)

    response = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)
    print(f"   Response: {response[:200]}")

    print("\n" + "=" * 60)
    print("Sanity check complete. Review above before full training.")
    print("=" * 60)
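One thing the sanity check above does not inspect is the label mask itself. A minimal sketch of a hypothetical check_label_mask helper that counts how many positions are trained on versus masked with -100 in one tokenized example:

```python
IGNORE_INDEX = -100


def check_label_mask(labels: list[int]) -> dict:
    """Count trained vs. masked label positions for one tokenized example."""
    total = len(labels)
    masked = sum(1 for label in labels if label == IGNORE_INDEX)
    trained = total - masked
    if trained == 0:
        status = "no trainable tokens: the loss for this example is undefined"
    elif masked == 0:
        status = "nothing masked: the loss includes prompt tokens"
    else:
        status = "ok"
    return {"total": total, "masked": masked, "trained": trained, "status": status}


print(check_label_mask([-100, -100, -100, 311, 312, 2]))
# {'total': 6, 'masked': 3, 'trained': 3, 'status': 'ok'}
```

To apply it to real data, one option is batch = next(iter(trainer.get_train_dataloader())) followed by check_label_mask(batch["labels"][0].tolist()), assuming the collator returns a "labels" tensor. Padding positions are usually labeled -100 as well, so a nonzero masked count alone does not prove the prompt is masked; decoding the trained positions is the more direct check.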

Section 13.3 Quiz

1. In SFT, why is the loss computed only on the assistant's response tokens and not on the user's prompt tokens?
The goal of SFT is to teach the model what to generate in response to user inputs, not to predict user inputs. Computing loss on prompt tokens would waste gradient updates trying to improve the model's ability to predict user messages, which is irrelevant to the task. By masking prompt tokens (setting labels to -100), all learning signal comes from improving the quality of generated responses.
2. If you have a per-device batch size of 4 and 8 gradient accumulation steps on a single GPU, what is the effective batch size?
The effective batch size is 4 (per-device) x 8 (gradient accumulation) x 1 (GPU) = 32. The model processes 4 examples per forward pass, accumulates gradients over 8 such passes, and then performs one optimizer step using the averaged gradients from all 32 examples.
3. You observe that training loss is decreasing but validation loss starts increasing after 200 steps. What is happening and what should you do?
This is overfitting: the model is memorizing the training data rather than learning generalizable patterns. Solutions: (1) Use the checkpoint from around step 200. With load_best_model_at_end=True, the trainer reloads the lowest-eval-loss checkpoint when training ends, and an early-stopping callback can halt the run once eval loss stops improving. (2) Add regularization (increase weight_decay to 0.05 to 0.1). (3) Reduce the number of epochs. (4) Augment the training dataset with more diverse examples. (5) Try a lower learning rate to slow the rate of weight updates.
4. Why is cosine annealing preferred over a constant learning rate for fine-tuning?
Cosine annealing smoothly reduces the learning rate from its peak to near zero over the course of training. This allows the model to make large updates early (when it is far from the optimum) and progressively smaller updates later (for fine-grained convergence). A constant learning rate risks overshooting the optimum in later stages, leading to oscillation and worse final performance. The warmup phase at the start prevents instability from large early gradients on the new data distribution.
5. What is the purpose of gradient checkpointing, and what is its tradeoff?
Gradient checkpointing reduces GPU memory usage during training by storing only a subset of the intermediate activations from the forward pass and recomputing the rest during the backward pass as needed. The tradeoff is compute for memory: training is typically about 20% to 30% slower because of the recomputation, but activation memory drops substantially (the exact saving depends on model size, batch size, and sequence length). This enables fine-tuning larger models, or longer sequences, on GPUs with limited memory.

Key Takeaways