Supervised fine-tuning (SFT) is the core technique for teaching a pre-trained model to follow instructions and produce specific outputs. In SFT, you train on input/output pairs where the loss is computed only on the output tokens (the assistant's response), not the input tokens (the user's prompt). This section walks through the complete SFT workflow using Hugging Face's TRL library, from loading the model to selecting hyperparameters, configuring gradient accumulation, and monitoring training with Weights & Biases and TensorBoard.
1. The SFT Training Loop
At its core, SFT modifies the standard causal language modeling objective in one important way: the loss is masked so that gradient updates come only from predicting the assistant's response tokens. The model still sees the full conversation during the forward pass (for context), but only the response tokens contribute to the loss. This teaches the model what to generate rather than what to predict about the user's input.
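The masking idea can be sketched in a few lines of plain PyTorch. This is a toy illustration with made-up token ids, not TRL's actual collator: `-100` is the `ignore_index` that PyTorch's cross-entropy skips, so masked prompt positions contribute zero loss.

```python
import torch
import torch.nn.functional as F

# Hypothetical token ids: the first 4 are the prompt, the last 3 the response.
input_ids = torch.tensor([[101, 7592, 2129, 102, 3000, 2003, 999]])

# Labels start as a copy of the inputs...
labels = input_ids.clone()
# ...then prompt positions are set to -100, the ignore_index of
# F.cross_entropy, so they contribute nothing to the loss.
prompt_len = 4
labels[:, :prompt_len] = -100

# Random logits standing in for a model's output (toy vocab of 30000).
logits = torch.randn(1, 7, 30000)

loss = F.cross_entropy(
    logits.view(-1, 30000),
    labels.view(-1),
    ignore_index=-100,
)
# Only the 3 response positions contribute:
print(int((labels != -100).sum()))
```

In practice TRL's data collators build these `-100` labels for you from the chat template; the point here is only that the forward pass sees all 7 tokens while the loss sees 3.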
1.1 Complete SFT Script with TRL
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig

# 1. Load model and tokenizer
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure pad token is set (required for batched training)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # Faster training
)

# 2. Load and prepare dataset (ChatML/messages format)
dataset = load_dataset("json", data_files={
    "train": "data/train.jsonl",
    "validation": "data/val.jsonl",
})

# 3. Configure SFT training
sft_config = SFTConfig(
    output_dir="./checkpoints/llama-sft",
    # Core hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,  # Effective batch = 4 * 8 = 32
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,  # 10% of steps for warmup
    lr_scheduler_type="cosine",
    # Sequence configuration
    max_seq_length=2048,
    packing=True,  # Enable sequence packing
    # Precision and optimization
    bf16=True,  # Use bfloat16 mixed precision
    gradient_checkpointing=True,  # Trade compute for memory
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Logging and evaluation
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,  # Keep only the 3 most recent checkpoints
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # Monitoring
    report_to="wandb",  # or "tensorboard"
    run_name="llama-8b-sft-v1",
    # Reproducibility
    seed=42,
    data_seed=42,
)

# 4. Create trainer and start training
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    processing_class=tokenizer,
)

# 5. Train
trainer.train()

# 6. Save the final model
trainer.save_model("./models/llama-sft-final")
tokenizer.save_pretrained("./models/llama-sft-final")
Flash Attention 2 is one of the easiest efficiency wins for training. Setting attn_implementation="flash_attention_2" typically reduces attention memory usage by 2x to 4x and speeds up training by 1.5x to 2x compared to standard attention. Install it with pip install flash-attn. It requires an NVIDIA GPU with Ampere architecture (A100, RTX 3090) or newer.
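Because FlashAttention 2 needs Ampere-class hardware (CUDA compute capability 8.0 or higher), it can help to select the attention backend at runtime. A small sketch, using PyTorch's device-capability query; the fallback to `"sdpa"` (PyTorch's built-in scaled-dot-product attention) is a choice made here, not something TRL requires:

```python
import torch

def supports_flash_attention_2() -> bool:
    """FlashAttention 2 needs an NVIDIA GPU with compute capability >= 8.0
    (Ampere or newer, e.g. A100, RTX 3090)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

# Fall back to PyTorch's native SDPA kernel on older or absent GPUs.
attn_impl = "flash_attention_2" if supports_flash_attention_2() else "sdpa"
print(attn_impl)
```

The resulting string can be passed directly as `attn_implementation` in `from_pretrained`.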
2. Hyperparameter Selection
Choosing the right hyperparameters is critical for fine-tuning success. Unlike pre-training where you have billions of tokens to recover from suboptimal choices, fine-tuning datasets are small and training runs are short, making the model sensitive to hyperparameter values. The following table provides tested starting points for common model sizes.
2.1 Recommended Starting Points
| Hyperparameter | 1B to 3B Models | 7B to 13B Models | 30B to 70B Models |
|---|---|---|---|
| Learning rate | 5e-5 to 1e-4 | 1e-5 to 5e-5 | 5e-6 to 2e-5 |
| Effective batch size | 16 to 32 | 32 to 64 | 64 to 128 |
| Epochs | 3 to 5 | 2 to 3 | 1 to 2 |
| Warmup ratio | 0.05 to 0.1 | 0.03 to 0.1 | 0.03 to 0.05 |
| Weight decay | 0.01 to 0.1 | 0.01 | 0.01 |
| Max grad norm | 1.0 | 1.0 | 1.0 |
| LR scheduler | cosine | cosine | cosine |
Learning rate is the most important hyperparameter. If you only tune one thing, tune the learning rate. Too high and you get catastrophic forgetting; too low and the model barely changes. A good rule of thumb: start with 2e-5 for 7B+ models and 5e-5 for smaller models. If validation loss increases during training, reduce the learning rate by 2x to 3x. If the model is not learning (loss plateau after warmup), increase it by 2x.
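The rule of thumb above can be written down as a tiny helper. This is a hypothetical convenience function encoding the table's starting points, not a tuning algorithm; the thresholds and return values come straight from the recommendations in this section.

```python
def suggest_learning_rate(num_params_billions: float) -> float:
    """Starting learning rate by model size, per the table above.
    A rough heuristic only; always validate against your loss curves."""
    if num_params_billions < 7:
        return 5e-5   # 1B-3B models
    if num_params_billions < 30:
        return 2e-5   # 7B-13B models
    return 1e-5       # 30B-70B models

print(suggest_learning_rate(8))  # 2e-05 for an 8B model
```

Treat the result as the center of a search range, then adjust per the divergence/plateau rules above.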
2.2 Gradient Accumulation
Gradient accumulation lets you simulate large batch sizes on limited GPU memory. Instead of processing the entire batch at once, you process smaller micro-batches and accumulate the gradients before performing a single optimizer step. The effective batch size is the product of per-device batch size, gradient accumulation steps, and the number of GPUs.
# Calculating effective batch size
def compute_effective_batch_size(
    per_device_batch_size: int,
    gradient_accumulation_steps: int,
    num_gpus: int = 1,
) -> dict:
    """Calculate the effective batch size for a training configuration."""
    effective_batch = per_device_batch_size * gradient_accumulation_steps * num_gpus
    return {
        "per_device_batch_size": per_device_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "num_gpus": num_gpus,
        "effective_batch_size": effective_batch,
        "optimizer_steps_per_epoch": "num_examples / effective_batch_size",
    }

# Common configurations
configs = [
    (2, 16, 1),  # Single GPU, small memory
    (4, 8, 1),   # Single GPU, moderate memory
    (4, 4, 4),   # 4 GPUs, distributed
    (8, 2, 8),   # 8 GPUs, large cluster
]

print(f"{'Per-device':>12} {'Grad Accum':>12} {'GPUs':>6} {'Effective BS':>14}")
print("-" * 50)
for pd_bs, ga, gpus in configs:
    result = compute_effective_batch_size(pd_bs, ga, gpus)
    print(f"{pd_bs:>12} {ga:>12} {gpus:>6} {result['effective_batch_size']:>14}")
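Under the hood, the trainer divides each micro-batch loss by the accumulation steps so the accumulated gradient matches the average over the effective batch. A minimal sketch of that loop, using a toy linear model and random data in place of the real model and dataset:

```python
import torch

# Toy stand-ins for the real model and data
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
grad_accum_steps = 8
num_optimizer_steps = 0

optimizer.zero_grad()
for micro_step in range(grad_accum_steps * 2):  # two full accumulation cycles
    x, y = torch.randn(4, 4), torch.randn(4, 1)  # micro-batch of 4 examples
    loss = loss_fn(model(x), y)
    # Scale the loss so accumulated gradients average over the effective
    # batch instead of summing micro-batch gradients
    (loss / grad_accum_steps).backward()
    if (micro_step + 1) % grad_accum_steps == 0:
        optimizer.step()  # one update per 8 micro-batches (effective batch = 32)
        optimizer.zero_grad()
        num_optimizer_steps += 1

print(num_optimizer_steps)  # 2 optimizer steps for 16 micro-batches
```

With `gradient_accumulation_steps=8` in the SFTConfig above, the trainer performs the same scaling automatically; you never write this loop yourself.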
3. Learning Rate Schedulers
The learning rate scheduler controls how the learning rate changes over the course of training. For fine-tuning, cosine annealing with warmup is the most commonly used and generally most effective scheduler. The warmup phase gradually increases the learning rate from zero to prevent early instability, while the cosine decay smoothly reduces it to help the model converge.
# Visualizing different schedulers
from transformers import get_scheduler
import torch

def visualize_schedulers(
    total_steps: int = 1000,
    warmup_steps: int = 100,
    learning_rate: float = 2e-5,
):
    """Compare learning rate schedules side by side."""
    schedules = {}
    for sched_type in ["cosine", "linear", "constant_with_warmup"]:
        # Create a dummy optimizer
        param = torch.nn.Parameter(torch.zeros(1))
        optimizer = torch.optim.AdamW([param], lr=learning_rate)
        scheduler = get_scheduler(
            name=sched_type,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )
        lrs = []
        for step in range(total_steps):
            lrs.append(optimizer.param_groups[0]["lr"])
            optimizer.step()
            scheduler.step()
        schedules[sched_type] = lrs

    # Print sample points
    checkpoints = [0, 50, 100, 250, 500, 750, 999]
    print(f"{'Step':>6} {'Cosine':>12} {'Linear':>12} {'Constant':>12}")
    print("-" * 44)
    for step in checkpoints:
        print(f"{step:>6} {schedules['cosine'][step]:>12.2e} "
              f"{schedules['linear'][step]:>12.2e} "
              f"{schedules['constant_with_warmup'][step]:>12.2e}")
    return schedules

schedules = visualize_schedulers()
4. Monitoring Training
Monitoring is not optional for fine-tuning. Without proper monitoring, you cannot detect overfitting, catastrophic forgetting, or training instability until after a costly training run completes. The two most widely used tools are Weights & Biases (W&B) and TensorBoard.
4.1 Key Metrics to Track
| Metric | What It Tells You | Warning Signs |
|---|---|---|
| Training loss | How well the model fits the training data | Loss spikes, NaN values, loss not decreasing |
| Validation loss | Generalization ability | Diverges from train loss (overfitting) |
| Learning rate | Current LR value from scheduler | Should match expected schedule shape |
| Gradient norm | Magnitude of gradients | Exploding (>10) or vanishing (<1e-6) |
| GPU memory | Memory utilization | OOM errors, memory leaks over time |
| Tokens per second | Training throughput | Sudden drops indicate bottlenecks |
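The gradient norm in the table is the global L2 norm over all parameter gradients, which is roughly how the value the Trainer logs as `grad_norm` is computed (before clipping). A self-contained sketch on a tiny model:

```python
import torch

def global_grad_norm(model: torch.nn.Module) -> float:
    """Global L2 norm over all parameter gradients: the square root of
    the sum of each gradient tensor's squared L2 norm."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().float().norm(2).item() ** 2
    return total ** 0.5

# Toy demonstration: for y = w . x with x = (1, 1, 1), each gradient
# component is 1, so the global norm is sqrt(3) ~ 1.7321.
model = torch.nn.Linear(3, 1, bias=False)
loss = model(torch.ones(1, 3)).sum()
loss.backward()
print(round(global_grad_norm(model), 4))  # 1.7321
```

When this value spikes above the warning threshold (>10) during a real run, `max_grad_norm=1.0` clipping is usually keeping training stable; persistent spikes still suggest lowering the learning rate.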
4.2 W&B Integration
import wandb
from transformers import TrainerCallback

# Initialize W&B project
wandb.init(
    project="llm-fine-tuning",
    name="llama-8b-sft-medical-v1",
    config={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "dataset": "medical-qa-10k",
        "learning_rate": 2e-5,
        "epochs": 3,
        "effective_batch_size": 32,
    },
)

class DetailedLoggingCallback(TrainerCallback):
    """Custom callback for detailed training metrics."""

    def __init__(self):
        self.last_train_loss = None

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        # Remember the latest training loss; on_log receives either
        # train logs or eval logs, never both in the same dict.
        if "loss" in logs:
            self.last_train_loss = logs["loss"]

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is None:
            return
        # Log the train/eval gap as an overfitting signal
        if self.last_train_loss is not None and "eval_loss" in metrics:
            gap = metrics["eval_loss"] - self.last_train_loss
            wandb.log({
                "train_eval_gap": gap,
                "overfitting_signal": gap > 0.1,
            }, step=state.global_step)
        # Log evaluation metrics with clear names
        eval_metrics = {
            f"eval/{k.replace('eval_', '')}": v
            for k, v in metrics.items()
        }
        wandb.log(eval_metrics, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        # Log final training summary
        wandb.log({
            "final/total_steps": state.global_step,
            "final/best_eval_loss": state.best_metric,
            "final/epochs_completed": state.epoch,
        })

# Add callback to trainer
# trainer = SFTTrainer(..., callbacks=[DetailedLoggingCallback()])
4.3 TensorBoard Alternative
# TensorBoard setup (alternative to W&B)
# In SFTConfig, set: report_to="tensorboard"
# Launch TensorBoard in a separate terminal:
#   tensorboard --logdir ./checkpoints/llama-sft/runs

# Custom TensorBoard logging
from torch.utils.tensorboard import SummaryWriter
from transformers import TrainerCallback

class TensorBoardCallback(TrainerCallback):
    """Custom TensorBoard logging callback for scalar metrics."""

    def __init__(self, log_dir="./tb_logs"):
        self.writer = SummaryWriter(log_dir)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        for key, value in logs.items():
            if isinstance(value, (int, float)):
                self.writer.add_scalar(f"training/{key}", value, state.global_step)

    def on_evaluate(self, args, state, control, metrics=None, model=None, **kwargs):
        if metrics:
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"eval/{key}", value, state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        self.writer.close()
Watch the train/eval loss gap. If training loss keeps decreasing but validation loss starts increasing, you are overfitting. Stop training and use the best checkpoint (the one with lowest validation loss). With load_best_model_at_end=True and metric_for_best_model="eval_loss" in your config, the trainer will automatically select the best checkpoint at the end of training.
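Transformers also ships an EarlyStoppingCallback that automates this stop-on-divergence check. A sketch assuming the SFTConfig from section 1.1 (it requires load_best_model_at_end=True, metric_for_best_model, and a step-based eval_strategy, all of which that config sets):

```python
from transformers import EarlyStoppingCallback

# Stop training if eval_loss fails to improve by at least 0.01
# for 3 consecutive evaluations.
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01,
)

# Pass it alongside any custom callbacks:
# trainer = SFTTrainer(..., callbacks=[early_stopping])
```

With this in place the run halts shortly after validation loss starts climbing, and load_best_model_at_end still restores the lowest-eval-loss checkpoint.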
5. Common Issues and Debugging
5.1 Troubleshooting Guide
Always run a sanity check before full training. Train for 10 to 20 steps on a tiny subset (50 to 100 examples) and verify that: (1) the loss decreases, (2) no OOM errors occur, (3) the tokenized examples look correct when decoded, and (4) the generated outputs are not garbled. This 5-minute check can save hours of wasted compute.
# Sanity check: verify training is working correctly
def run_sanity_check(trainer, tokenizer, dataset, num_samples=3):
    """Quick sanity check before committing to a full training run."""
    print("=" * 60)
    print("SANITY CHECK")
    print("=" * 60)

    # 1. Check a few tokenized examples
    print("\n1. Sample tokenized examples:")
    for i in range(min(num_samples, len(dataset))):
        example = dataset[i]
        messages = example["messages"]
        text = tokenizer.apply_chat_template(messages, tokenize=False)
        tokens = tokenizer(text)["input_ids"]
        print(f"  Example {i}: {len(tokens)} tokens")
        # Decode and check it looks reasonable
        decoded = tokenizer.decode(tokens[:50])
        print(f"  First 50 tokens: {decoded[:100]}...")

    # 2. Run a few training steps
    print("\n2. Running 10 training steps...")
    trainer.args.max_steps = 10
    trainer.args.logging_steps = 1
    trainer.train()

    # 3. Check loss trajectory
    logs = trainer.state.log_history
    losses = [l["loss"] for l in logs if "loss" in l]
    print(f"\n3. Loss trajectory: {[f'{l:.4f}' for l in losses]}")
    if len(losses) >= 2:
        if losses[-1] < losses[0]:
            print("  [PASS] Loss is decreasing")
        else:
            print("  [WARN] Loss is not decreasing; check learning rate")

    # 4. Generate a sample response
    print("\n4. Sample generation:")
    model = trainer.model
    model.eval()
    test_messages = [{"role": "user", "content": "Hello, how are you?"}]
    inputs = tokenizer.apply_chat_template(
        test_messages, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)
    with torch.no_grad():
        output = model.generate(
            inputs, max_new_tokens=50, do_sample=True, temperature=0.7
        )
    response = tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True)
    print(f"  Response: {response[:200]}")

    print("\n" + "=" * 60)
    print("Sanity check complete. Review above before full training.")
    print("=" * 60)
Key Takeaways
- SFT loss is masked: only response tokens contribute to the loss; prompt tokens are labeled -100 and ignored during backpropagation.
- Start with 2e-5 learning rate for 7B+ models, use cosine annealing with 5% to 10% warmup, and train for 2 to 3 epochs as a baseline.
- Gradient accumulation lets you simulate large batch sizes on limited hardware; effective batch size = per_device_batch x grad_accum_steps x num_gpus.
- Enable Flash Attention 2 and gradient checkpointing to maximize training efficiency and minimize memory usage.
- Monitor both train and eval loss at every checkpoint; a growing gap signals overfitting and calls for early stopping or more regularization.
- Run a sanity check (10 to 20 steps on a small subset) before every full training run to catch data format issues, OOM errors, and learning rate problems early.