LLMs are expensive to pretrain and quickly become outdated. New knowledge emerges, domains evolve, and organizational needs change. Continual learning addresses how to update a model with new information without expensive retraining from scratch, while avoiding "catastrophic forgetting" of previously learned capabilities. This section covers continual pre-training on domain corpora, vocabulary extension for specialized terminology, replay methods, regularization techniques like Elastic Weight Consolidation, and progressive training strategies. These methods are essential for maintaining production LLMs that must adapt to changing requirements over time.
1. The Catastrophic Forgetting Problem
When you fine-tune or continue training an LLM on new domain-specific data, the model rapidly adapts to the new distribution but simultaneously degrades on its original capabilities. This phenomenon, called catastrophic forgetting, occurs because gradient updates that optimize for new data push the weights away from the regions that encode prior knowledge. The more you train on new data, the more aggressively the model forgets.
This is not just a theoretical concern. In practice, a model that undergoes continual pre-training on medical literature may lose its ability to write code, answer general knowledge questions, or follow instructions properly. The challenge is to absorb new domain knowledge while preserving the broad capabilities that make the model useful.
Catastrophic forgetting is more severe in full fine-tuning than in parameter-efficient methods. LoRA naturally mitigates forgetting because the base model weights remain frozen, and the low-rank adapter can only make limited modifications. This is one reason why LoRA-based continual learning is increasingly preferred over full-parameter continual pre-training for domain adaptation.
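To make the mechanism concrete, here is a framework-free sketch of a LoRA-style linear layer (illustrative only; the name `LoRALinear` is our own, and in practice a library like peft handles this). The frozen weight W never changes, and because B starts at zero the adapted layer initially matches the base layer exactly:

```python
import numpy as np

class LoRALinear:
    """Minimal LoRA-style linear layer: y = x @ (W + (alpha/r) * B @ A).T
    W is frozen; only A and B would receive gradient updates."""
    def __init__(self, in_features, out_features, r=4, alpha=8, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.standard_normal((out_features, in_features)) * 0.02  # frozen base weight
        self.A = rng.standard_normal((r, in_features)) * 0.01             # trainable down-projection
        self.B = np.zeros((out_features, r))                              # trainable up-projection, zero-init
        self.scale = alpha / r

    def forward(self, x):
        # The base contribution x @ W.T is always preserved; the adapter
        # can only add a rank-r correction on top of it.
        return x @ (self.W + self.scale * (self.B @ self.A)).T
```

Because B is zero-initialized, the model's behavior is untouched until training moves the adapter weights, and removing the adapter always recovers the original model.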
2. Continual Pre-Training
Continual pre-training (CPT) extends the standard pre-training objective (next-token prediction) on domain-specific corpora. Unlike instruction fine-tuning, which teaches the model to follow a new format, CPT injects new factual knowledge and domain vocabulary into the model's weights. This is the primary technique for creating domain-specific foundation models (for example, a medical LLM or a financial LLM).
2.1 Data Preparation for CPT
```python
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer


def prepare_cpt_dataset(
    domain_data_path: str,
    general_data_path: str,
    replay_ratio: float = 0.1,
    tokenizer_id: str = "meta-llama/Meta-Llama-3-8B",
    max_seq_length: int = 4096,
):
    """Prepare CPT dataset with replay data to reduce forgetting."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # Load domain-specific data
    domain_ds = load_dataset("text", data_files=domain_data_path, split="train")
    print(f"Domain data: {len(domain_ds)} documents")

    # Load general replay data (subset of original pretraining mix)
    general_ds = load_dataset("text", data_files=general_data_path, split="train")

    # Sample replay data proportionally
    n_replay = int(len(domain_ds) * replay_ratio)
    general_ds = general_ds.shuffle(seed=42).select(range(min(n_replay, len(general_ds))))
    print(f"Replay data: {len(general_ds)} documents ({replay_ratio*100:.0f}%)")

    # Combine and shuffle
    combined = concatenate_datasets([domain_ds, general_ds])
    combined = combined.shuffle(seed=42)

    # Tokenize
    def tokenize(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_seq_length,
            padding=False,
        )

    tokenized = combined.map(tokenize, batched=True, remove_columns=["text"])
    return tokenized
```
2.2 Training Configuration for CPT
Continual pre-training requires careful hyperparameter selection. The learning rate should be significantly lower than initial pre-training (typically 1e-5 to 5e-5, compared to 1e-4 to 3e-4 for original pre-training) to avoid destabilizing the pretrained weights. Training for too many epochs on domain data accelerates forgetting, so most practitioners use 1-2 passes over the domain corpus.
```python
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling,
)
import torch

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token

# CPT-specific training arguments
training_args = TrainingArguments(
    output_dir="./cpt-medical-llama",
    num_train_epochs=1,             # Typically 1-2 epochs for CPT
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # Effective batch size = 32
    learning_rate=2e-5,             # Much lower than pre-training
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.1,
    bf16=True,
    logging_steps=50,
    save_strategy="steps",
    save_steps=500,
    max_grad_norm=1.0,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False  # Causal LM (next-token prediction)
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=cpt_dataset,  # output of prepare_cpt_dataset
    data_collator=data_collator,
)
trainer.train()
```
After continual pre-training, you typically need a second stage of supervised fine-tuning (SFT) on instruction data to restore the model's ability to follow instructions. CPT optimizes for next-token prediction on raw text, which can degrade the model's chat/instruction-following behavior. The standard pipeline is: base model → CPT on domain data → SFT on domain-specific instructions.
3. Vocabulary Extension
Domain-specific corpora often contain terminology that the base tokenizer splits into many sub-word tokens. Medical terms like "electroencephalography" might become 5+ tokens, wasting context window capacity and reducing the model's ability to learn associations involving these terms. Vocabulary extension adds new tokens for frequently occurring domain terms, improving both efficiency and representation quality.
3.1 When to Extend Vocabulary
| Signal | Action | Example |
|---|---|---|
| Common domain terms split into 4+ tokens | Add as single token | "electroencephalography" → 1 token |
| Non-Latin scripts or specialized notation | Add character/symbol tokens | Chemical formulas, musical notation |
| High-frequency domain abbreviations | Add as tokens | "COPD", "MRI", "ICU" |
| Terms already tokenize efficiently (1-2 tokens) | Do not extend | "cancer", "patient", "diagnosis" |
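The first row of the table can be automated. The helper below (a hypothetical name, not a library function) flags terms that the current tokenizer fragments heavily; `encode` is any callable returning token IDs, e.g. `lambda s: tokenizer.encode(s, add_special_tokens=False)`:

```python
def select_tokens_to_add(terms, encode, min_subwords=4):
    """Return the terms that `encode` splits into `min_subwords` or more
    pieces -- candidates for vocabulary extension. Terms that already
    tokenize into 1-3 pieces are left alone."""
    return [term for term in terms if len(encode(term)) >= min_subwords]


# Demonstration with a stand-in encoder that chunks every 4 characters;
# a real tokenizer's encode method would be passed instead.
fake_encode = lambda s: [s[i:i + 4] for i in range(0, len(s), 4)]
candidates = select_tokens_to_add(["electroencephalography", "cancer"], fake_encode)
```

Running the selection against a representative sample of the domain corpus, rather than a hand-picked term list, avoids adding tokens that look important but are actually rare.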
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


def extend_vocabulary(
    model_id: str,
    new_tokens: list[str],
    init_strategy: str = "mean",
):
    """Extend model vocabulary with domain-specific tokens."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16
    )

    # Check current tokenization
    for token in new_tokens[:5]:
        ids = tokenizer.encode(token, add_special_tokens=False)
        print(f"  '{token}' -> {len(ids)} tokens: {ids}")

    # Record each term's subword decomposition BEFORE adding it to the
    # tokenizer; afterward, encode() would return the new single-token ID.
    subword_ids = {
        token: tokenizer.encode(token, add_special_tokens=False)
        for token in new_tokens
    }

    # Add new tokens to tokenizer
    num_added = tokenizer.add_tokens(new_tokens)
    print(f"Added {num_added} new tokens to vocabulary")

    # Resize model embeddings
    model.resize_token_embeddings(len(tokenizer))

    # Initialize new token embeddings
    embeddings = model.get_input_embeddings()
    if init_strategy == "mean":
        # Initialize as mean of existing embeddings (stable default)
        with torch.no_grad():
            existing_mean = embeddings.weight[:-num_added].mean(dim=0)
            for i in range(num_added):
                embeddings.weight[-num_added + i] = existing_mean
    elif init_strategy == "subword":
        # Initialize as mean of the original subword embeddings
        # (better for compound terms)
        with torch.no_grad():
            for i, token in enumerate(new_tokens):
                old_ids = subword_ids[token]
                if old_ids:
                    embeddings.weight[-num_added + i] = embeddings.weight[old_ids].mean(dim=0)

    print(f"Vocabulary size: {len(tokenizer)}")
    return model, tokenizer


# Example: extend for medical domain
medical_tokens = [
    "electroencephalography", "immunohistochemistry",
    "pharmacokinetics", "bronchoscopy",
    "echocardiogram", "thrombocytopenia",
]
model, tokenizer = extend_vocabulary(
    "meta-llama/Meta-Llama-3-8B",
    medical_tokens,
    init_strategy="subword",
)
```
After extending the vocabulary, you must continue training the model for the new embeddings to become useful. The initialized embeddings (whether from mean or subword strategies) are only approximations. Without further training, the new tokens will produce poor-quality outputs. Plan for at least several thousand gradient steps on data containing the new tokens before the model uses them effectively.
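One common tactic for this warm-up phase is to train only the embedding matrices at first, freezing the transformer body so the newly initialized rows can settle before full CPT resumes. The sketch below uses a hypothetical helper name and assumes the Hugging Face `get_input_embeddings`/`get_output_embeddings` accessors:

```python
def freeze_all_but_embeddings(model):
    """Freeze every parameter except the input/output embedding layers.
    Intended as a warm-up right after resize_token_embeddings, so the
    newly added rows are trained before the rest of the network."""
    # Freeze everything first
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze the input embedding matrix (contains the new rows)
    for param in model.get_input_embeddings().parameters():
        param.requires_grad = True
    # Unfreeze the output projection if it is a separate (untied) matrix
    output_embeddings = model.get_output_embeddings()
    if output_embeddings is not None:
        for param in output_embeddings.parameters():
            param.requires_grad = True
```

After a few hundred to a few thousand steps in this mode, unfreeze the full model and continue with standard CPT.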
4. Replay Methods
Replay is the most straightforward defense against catastrophic forgetting: mix a portion of the original training data (or data from the same distribution) into the domain-specific training stream. By continuously exposing the model to general-domain examples, you anchor the weights near their original values for non-domain capabilities.
4.1 Replay Strategies
```python
from torch.utils.data import Dataset
import random


class ReplayDataset(Dataset):
    """Dataset that mixes domain data with replay data."""

    def __init__(self, domain_dataset, replay_dataset, replay_ratio=0.1):
        self.domain = domain_dataset
        self.replay = replay_dataset
        self.replay_ratio = replay_ratio
        # Calculate sizes for interleaving
        self.total_size = len(domain_dataset)
        self.n_replay_per_epoch = int(self.total_size * replay_ratio)
        # Pre-sample replay indices for this epoch
        self._resample_replay()

    def _resample_replay(self):
        """Resample replay indices (call at each epoch start)."""
        replay_indices = random.sample(
            range(len(self.replay)),
            min(self.n_replay_per_epoch, len(self.replay)),
        )
        # Interleave: every ~(1/ratio) steps, insert a replay sample
        self.schedule = []
        replay_iter = iter(replay_indices)
        interval = int(1 / self.replay_ratio) if self.replay_ratio > 0 else 999999
        for i in range(self.total_size):
            self.schedule.append(("domain", i))
            if (i + 1) % interval == 0:
                try:
                    self.schedule.append(("replay", next(replay_iter)))
                except StopIteration:
                    pass

    def __len__(self):
        return len(self.schedule)

    def __getitem__(self, idx):
        source, data_idx = self.schedule[idx]
        if source == "domain":
            return self.domain[data_idx]
        return self.replay[data_idx]
```
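The interleaving logic can be checked in isolation. This standalone helper (a hypothetical name, mirroring the `_resample_replay` logic above) builds the same schedule from dataset sizes alone:

```python
import random


def build_replay_schedule(n_domain, n_replay_pool, replay_ratio=0.1, seed=42):
    """Return a list of ("domain", i) / ("replay", j) pairs with one
    replay sample inserted after every ~1/replay_ratio domain samples."""
    rng = random.Random(seed)
    n_replay = int(n_domain * replay_ratio)
    replay_iter = iter(rng.sample(range(n_replay_pool), min(n_replay, n_replay_pool)))
    interval = int(1 / replay_ratio) if replay_ratio > 0 else None
    schedule = []
    for i in range(n_domain):
        schedule.append(("domain", i))
        if interval and (i + 1) % interval == 0:
            try:
                schedule.append(("replay", next(replay_iter)))
            except StopIteration:
                pass
    return schedule


# With 100 domain documents and replay_ratio=0.1, the schedule contains
# 110 entries: a replay sample after every 10th domain sample.
schedule = build_replay_schedule(n_domain=100, n_replay_pool=50, replay_ratio=0.1)
```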
5. Elastic Weight Consolidation (EWC)
Elastic Weight Consolidation (Kirkpatrick et al., 2017) adds a regularization term to the loss function that penalizes changes to parameters that were important for previous tasks. It estimates each parameter's "importance" using the Fisher Information Matrix, which measures how much the loss changes when a parameter is perturbed. Important parameters get a stronger regularization penalty, anchoring them near their original values while allowing less important parameters to adapt freely.
The EWC loss adds a quadratic penalty term:
L_total = L_task + (λ / 2) · Σᵢ Fᵢ · (θᵢ − θᵢ*)²
where Fᵢ is the Fisher information for parameter i, θᵢ* is the original parameter value, and λ controls the regularization strength.
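Plugging small numbers into the formula shows how Fisher information shapes the penalty. Below, the first parameter is important for the old task (F = 2.0) and the second unimportant (F = 0.01); even though the second parameter moves five times farther, the first dominates the penalty:

```python
def ewc_penalty(fisher, theta, theta_star, lam):
    """Quadratic EWC penalty: (lam / 2) * sum_i F_i * (theta_i - theta_star_i)^2"""
    return (lam / 2) * sum(
        f * (t - ts) ** 2 for f, t, ts in zip(fisher, theta, theta_star)
    )


penalty = ewc_penalty(
    fisher=[2.0, 0.01],      # parameter importances
    theta=[1.2, 0.5],        # current values after domain training
    theta_star=[1.0, -0.5],  # values from the original model
    lam=1000,
)
# First term:  500 * 2.0  * 0.2**2 = 40  (small move, high importance)
# Second term: 500 * 0.01 * 1.0**2 = 5   (large move, low importance)
```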
```python
import torch


class EWCRegularizer:
    """Elastic Weight Consolidation for continual learning."""

    def __init__(self, model, dataloader, device, n_samples=200):
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(model, dataloader, device, n_samples)

    def _compute_fisher(self, model, dataloader, device, n_samples):
        """Estimate diagonal Fisher Information Matrix."""
        fisher = {
            n: torch.zeros_like(p)
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        model.eval()
        count = 0
        for batch in dataloader:
            if count >= n_samples:
                break
            input_ids = batch["input_ids"].to(device)
            outputs = model(input_ids=input_ids, labels=input_ids)
            loss = outputs.loss
            loss.backward()
            for n, p in model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    fisher[n] += p.grad.detach() ** 2
            model.zero_grad()
            count += input_ids.size(0)
        # Normalize by the number of samples seen
        for n in fisher:
            fisher[n] /= count
        return fisher

    def penalty(self, model, lambda_ewc=1000):
        """Compute EWC regularization penalty."""
        loss = 0
        for n, p in model.named_parameters():
            if n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return (lambda_ewc / 2) * loss


# Usage in training loop:
# ewc = EWCRegularizer(model, general_dataloader, device)
# for batch in domain_dataloader:
#     loss = model(**batch).loss + ewc.penalty(model)
```
Computing the full Fisher Information Matrix for a billion-parameter model is prohibitively expensive. In practice, EWC uses a diagonal approximation (only the diagonal of the Fisher matrix), which can be computed efficiently with a single pass over a small dataset. The lambda_ewc hyperparameter typically ranges from 100 to 10000 and requires tuning. Too low and forgetting persists; too high and the model cannot adapt to the new domain.
6. Progressive Training and Curriculum Approaches
Progressive training structures the continual learning process as a sequence of carefully designed stages, each building on the previous one. Rather than training on all domain data at once, you create a curriculum that gradually shifts the distribution from general to domain-specific.
6.1 Multi-Stage Domain Adaptation Pipeline
A typical pipeline proceeds in four stages: continual pre-training on the domain corpus (with replay data mixed in), supervised fine-tuning on domain-specific instructions, optional preference alignment, and evaluation on both domain and general benchmarks before the model is promoted to production.
6.2 Curriculum Design Principles
Effective curriculum design for continual learning follows several principles. Start with data that is closest to the original pre-training distribution and gradually shift toward the target domain. Within the domain data, present easier examples first (shorter texts, simpler concepts) and progress to harder examples. This gradual transition gives the model time to adjust its internal representations without abrupt distribution shifts.
| Strategy | Description | When to Use |
|---|---|---|
| Data mixing schedule | Start with 50% general / 50% domain, end with 10% / 90% | Large domain shift |
| Learning rate warmup | Start very low (1e-6), gradually increase to target LR | Preventing early destabilization |
| Layer-wise learning rates | Lower LR for early layers, higher for later layers | Preserving foundational features |
| LoRA rank scheduling | Start with low rank (4), increase to final rank (16-32) | Memory-constrained setups |
| Multi-epoch curriculum | Epoch 1: easy domain data; Epoch 2: hard domain data | Diverse difficulty levels in domain |
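The data-mixing row of the table can be implemented as a simple anneal. The helper below is a hypothetical sketch assuming a linear schedule; cosine or stepwise schedules work equally well:

```python
def mixing_schedule(step, total_steps, start_general=0.5, end_general=0.1):
    """Linearly anneal the fraction of general-domain data from
    start_general (e.g. 50%) at step 0 down to end_general (e.g. 10%)
    at the final step. Returns (general_fraction, domain_fraction)."""
    frac = step / max(total_steps - 1, 1)
    general = start_general + (end_general - start_general) * frac
    return general, 1.0 - general


# Start of training: 50% general / 50% domain
start_mix = mixing_schedule(0, 100)
# End of training: 10% general / 90% domain
end_mix = mixing_schedule(99, 100)
```

At each step, the returned fractions would drive sampling from the general and domain data pools, e.g. via weighted dataset interleaving.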
The most practical approach to continual domain adaptation for most teams is: use LoRA (not full fine-tuning) for each adaptation stage, keep the base model frozen, and maintain a library of composable adapters. This sidesteps catastrophic forgetting entirely because the base weights never change. You can then merge adapters using the techniques from Section 15.2 to combine capabilities, or swap them dynamically at serving time.
Section 15.3 Quiz
1. What is catastrophic forgetting, and why is it particularly problematic for continual pre-training?
2. Why is the learning rate for continual pre-training typically much lower than for initial pre-training?
3. When should you extend the model's vocabulary, and what must you do afterward?
4. How does Elastic Weight Consolidation (EWC) prevent forgetting, and what is its main limitation?
5. Why do many practitioners prefer LoRA-based adaptation over full-parameter continual pre-training for domain adaptation?
Key Takeaways
- Catastrophic forgetting degrades general capabilities when training on domain data. The severity increases with training duration and is worse for full fine-tuning than for parameter-efficient methods.
- Continual pre-training uses next-token prediction on domain corpora to inject new knowledge. Use a lower learning rate (1e-5 to 5e-5) and limit to 1-2 epochs to balance adaptation and retention.
- Replay methods mix general-domain data into the training stream. A 10-20% replay ratio prevents most forgetting for typical domain adaptation scenarios.
- Vocabulary extension improves efficiency for domain-specific terminology, but requires continued training for new embeddings to become effective.
- Elastic Weight Consolidation penalizes changes to important parameters using Fisher Information, providing a principled regularization approach that does not require storing replay data.
- Progressive training structures adaptation as a multi-stage pipeline (CPT, SFT, alignment, evaluation) with curriculum-based data scheduling.
- LoRA-based adaptation is often the best practical choice because frozen base weights eliminate forgetting entirely, and adapters can be composed or swapped for modular domain specialization.
- Always evaluate on both domain and general benchmarks after continual learning to verify that domain gains have not come at the cost of general degradation.