Module 15 · Section 15.3

Continual Learning & Domain Adaptation

Adapting LLMs to new domains over time without forgetting their general capabilities
★ Big Picture

LLMs are expensive to pretrain and quickly become outdated. New knowledge emerges, domains evolve, and organizational needs change. Continual learning addresses how to update a model with new information without expensive retraining from scratch, while avoiding "catastrophic forgetting" of previously learned capabilities. This section covers continual pre-training on domain corpora, vocabulary extension for specialized terminology, replay methods, regularization techniques like Elastic Weight Consolidation, and progressive training strategies. These methods are essential for maintaining production LLMs that must adapt to changing requirements over time.

1. The Catastrophic Forgetting Problem

When you fine-tune or continue training an LLM on new domain-specific data, the model rapidly adapts to the new distribution but simultaneously degrades on its original capabilities. This phenomenon, called catastrophic forgetting, occurs because gradient updates that optimize for new data push the weights away from the regions that encode prior knowledge. The more you train on new data, the more aggressively the model forgets.

This is not just a theoretical concern. In practice, a model that undergoes continual pre-training on medical literature may lose its ability to write code, answer general knowledge questions, or follow instructions properly. The challenge is to absorb new domain knowledge while preserving the broad capabilities that make the model useful.

[Figure: two performance curves plotted against training steps on domain data. Domain performance rises while general performance falls, with under-adapted, sweet-spot, and over-adapted (forgotten) regions marked.]
Figure 1: As domain performance improves, general capabilities degrade. The goal is to find (and maintain) the sweet spot.
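One pragmatic way to operationalize the sweet spot is to evaluate every checkpoint on both a domain benchmark and a general benchmark, then pick the checkpoint that maximizes a weighted combination of the two scores. A minimal sketch (the checkpoint tuples and weights here are illustrative, not from the source):

```python
def pick_sweet_spot(checkpoints, domain_weight=0.5):
    """Pick the training step whose checkpoint best balances domain vs. general scores.

    checkpoints: list of (step, domain_score, general_score) tuples.
    domain_weight: how much to favor domain performance (0.0-1.0).
    """
    best = max(
        checkpoints,
        key=lambda c: domain_weight * c[1] + (1 - domain_weight) * c[2],
    )
    return best[0]

# Hypothetical eval results: domain improves while general degrades
evals = [(500, 0.40, 0.80), (1000, 0.55, 0.74), (1500, 0.62, 0.60), (2000, 0.66, 0.45)]
print(pick_sweet_spot(evals))  # 1000
```

With equal weighting, step 1000 wins here: later checkpoints gain less on the domain than they lose on general capabilities.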
◆ Key Insight

Catastrophic forgetting is more severe in full fine-tuning than in parameter-efficient methods. LoRA naturally mitigates forgetting because the base model weights remain frozen, and the low-rank adapter can only make limited modifications. This is one reason why LoRA-based continual learning is increasingly preferred over full-parameter continual pre-training for domain adaptation.

2. Continual Pre-Training

Continual pre-training (CPT) extends the standard pre-training objective (next-token prediction) on domain-specific corpora. Unlike instruction fine-tuning, which teaches the model to follow a new format, CPT injects new factual knowledge and domain vocabulary into the model's weights. This is the primary technique for creating domain-specific foundation models (for example, a medical LLM or a financial LLM).

2.1 Data Preparation for CPT

from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer

def prepare_cpt_dataset(
    domain_data_path: str,
    general_data_path: str,
    replay_ratio: float = 0.1,
    tokenizer_id: str = "meta-llama/Meta-Llama-3-8B",
    max_seq_length: int = 4096,
):
    """Prepare CPT dataset with replay data to reduce forgetting."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # Load domain-specific data
    domain_ds = load_dataset("text", data_files=domain_data_path, split="train")
    print(f"Domain data: {len(domain_ds)} documents")

    # Load general replay data (subset of original pretraining mix)
    general_ds = load_dataset("text", data_files=general_data_path, split="train")

    # Sample replay data proportionally
    n_replay = int(len(domain_ds) * replay_ratio)
    general_ds = general_ds.shuffle(seed=42).select(range(min(n_replay, len(general_ds))))
    print(f"Replay data: {len(general_ds)} documents ({replay_ratio*100:.0f}%)")

    # Combine and shuffle
    combined = concatenate_datasets([domain_ds, general_ds])
    combined = combined.shuffle(seed=42)

    # Tokenize
    def tokenize(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_seq_length,
            padding=False,
        )

    tokenized = combined.map(tokenize, batched=True, remove_columns=["text"])
    return tokenized

2.2 Training Configuration for CPT

Continual pre-training requires careful hyperparameter selection. The learning rate should be significantly lower than initial pre-training (typically 1e-5 to 5e-5, compared to 1e-4 to 3e-4 for original pre-training) to avoid destabilizing the pretrained weights. Training for too many epochs on domain data accelerates forgetting, so most practitioners use 1-2 passes over the domain corpus.

from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling
)
import torch

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token

# CPT-specific training arguments
training_args = TrainingArguments(
    output_dir="./cpt-medical-llama",
    num_train_epochs=1,                # Typically 1-2 epochs for CPT
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,    # Effective batch size = 32
    learning_rate=2e-5,               # Much lower than pre-training
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.1,
    bf16=True,
    logging_steps=50,
    save_strategy="steps",
    save_steps=500,
    max_grad_norm=1.0,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False   # Causal LM (next-token prediction)
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=cpt_dataset,        # from prepare_cpt_dataset() in Section 2.1
    data_collator=data_collator,
)

trainer.train()
ⓘ Note

After continual pre-training, you typically need a second stage of instruction fine-tuning (SFT) to restore the model's ability to follow instructions. CPT optimizes for next-token prediction on raw text, which can degrade the model's chat/instruction-following behavior. The standard pipeline is: base model → CPT on domain data → SFT on domain-specific instructions.

3. Vocabulary Extension

Domain-specific corpora often contain terminology that the base tokenizer splits into many sub-word tokens. Medical terms like "electroencephalography" might become 5+ tokens, wasting context window capacity and reducing the model's ability to learn associations involving these terms. Vocabulary extension adds new tokens for frequently occurring domain terms, improving both efficiency and representation quality.

3.1 When to Extend Vocabulary

Signal | Action | Example
Common domain terms split into 4+ tokens | Add as single token | "electroencephalography" → 1 token
Non-Latin scripts or specialized notation | Add character/symbol tokens | Chemical formulas, musical notation
High-frequency domain abbreviations | Add as tokens | "COPD", "MRI", "ICU"
Terms already tokenize efficiently (1-2 tokens) | Do not extend | "cancer", "patient", "diagnosis"
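The first signal in the table can be checked automatically: run candidate terms through the tokenizer and keep those that split into too many pieces. A small helper sketch, parameterized over any tokenize function (the toy 4-character tokenizer below is a stand-in for a real BPE tokenizer, used only for illustration):

```python
def find_extension_candidates(tokenize_fn, terms, min_tokens=4):
    """Return terms that the tokenizer splits into min_tokens or more pieces.

    tokenize_fn: any callable mapping a string to a list of tokens, e.g.
    lambda s: tokenizer.encode(s, add_special_tokens=False) for a HF tokenizer.
    """
    return [t for t in terms if len(tokenize_fn(t)) >= min_tokens]

# Toy tokenizer that splits every 4 characters (stand-in for a real BPE tokenizer)
fake_tokenize = lambda s: [s[i:i + 4] for i in range(0, len(s), 4)]
terms = ["MRI", "cancer", "electroencephalography"]
print(find_extension_candidates(fake_tokenize, terms))  # ['electroencephalography']
```

Running this over a frequency-ranked list of domain terms gives a principled shortlist for vocabulary extension instead of hand-picking tokens.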
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

def extend_vocabulary(
    model_id: str,
    new_tokens: list[str],
    init_strategy: str = "mean",
):
    """Extend model vocabulary with domain-specific tokens."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16
    )

    # Record the original sub-word tokenization BEFORE adding the new tokens;
    # after add_tokens(), encode() returns the new single-token IDs instead
    old_subword_ids = {}
    for token in new_tokens:
        ids = tokenizer.encode(token, add_special_tokens=False)
        old_subword_ids[token] = ids
        print(f"  '{token}' -> {len(ids)} tokens: {ids}")

    # Add new tokens to tokenizer
    num_added = tokenizer.add_tokens(new_tokens)
    print(f"Added {num_added} new tokens to vocabulary")

    # Resize model embeddings to match the enlarged vocabulary
    model.resize_token_embeddings(len(tokenizer))

    # Initialize new token embeddings
    if init_strategy == "mean":
        # Initialize as mean of existing embeddings (stable default)
        with torch.no_grad():
            embeddings = model.get_input_embeddings()
            existing_mean = embeddings.weight[:-num_added].mean(dim=0)
            for i in range(num_added):
                embeddings.weight[-num_added + i] = existing_mean

    elif init_strategy == "subword":
        # Initialize as mean of the original sub-word embeddings
        # (often better for compound terms)
        with torch.no_grad():
            embeddings = model.get_input_embeddings()
            for i, token in enumerate(new_tokens):
                old_ids = old_subword_ids[token]
                if old_ids:
                    subword_mean = embeddings.weight[old_ids].mean(dim=0)
                    embeddings.weight[-num_added + i] = subword_mean

    print(f"Vocabulary size: {len(tokenizer)}")
    return model, tokenizer

# Example: extend for medical domain
medical_tokens = [
    "electroencephalography", "immunohistochemistry",
    "pharmacokinetics", "bronchoscopy",
    "echocardiogram", "thrombocytopenia",
]
model, tokenizer = extend_vocabulary(
    "meta-llama/Meta-Llama-3-8B",
    medical_tokens,
    init_strategy="subword",
)
⚠ Warning

After extending the vocabulary, you must continue training the model for the new embeddings to become useful. The initialized embeddings (whether from mean or subword strategies) are only approximations. Without further training, the new tokens will produce poor-quality outputs. Plan for at least several thousand gradient steps on data containing the new tokens before the model uses them effectively.

4. Replay Methods

Replay is the most straightforward defense against catastrophic forgetting: mix a portion of the original training data (or data from the same distribution) into the domain-specific training stream. By continuously exposing the model to general-domain examples, you anchor the weights near their original values for non-domain capabilities.

4.1 Replay Strategies

Uniform mixing: a fixed ratio throughout training (e.g., 90% domain, 10% general). A simple, effective default.
Curriculum replay: start with mostly general data and gradually shift toward domain data. Smoother adaptation.
Generative replay: the model generates its own replay data on the fly, so no old data needs to be stored.

Recommended replay ratios: ~5% for light adaptation (format/style); 10-20% for domain knowledge; 30-50% for a heavy domain shift (e.g., a new language).
Figure 2: Different replay strategies trade off forgetting prevention against domain adaptation speed.
from torch.utils.data import Dataset, DataLoader
import random

class ReplayDataset(Dataset):
    """Dataset that mixes domain data with replay data."""

    def __init__(self, domain_dataset, replay_dataset, replay_ratio=0.1):
        self.domain = domain_dataset
        self.replay = replay_dataset
        self.replay_ratio = replay_ratio

        # Calculate sizes for interleaving
        self.total_size = len(domain_dataset)
        self.n_replay_per_epoch = int(self.total_size * replay_ratio)

        # Pre-sample replay indices for this epoch
        self._resample_replay()

    def _resample_replay(self):
        """Resample replay indices (call at each epoch start)."""
        replay_indices = random.sample(
            range(len(self.replay)),
            min(self.n_replay_per_epoch, len(self.replay))
        )
        # Interleave: every ~(1/ratio) steps, insert a replay sample
        self.schedule = []
        replay_iter = iter(replay_indices)
        interval = int(1 / self.replay_ratio) if self.replay_ratio > 0 else 999999
        for i in range(self.total_size):
            self.schedule.append(("domain", i))
            if (i + 1) % interval == 0:
                try:
                    self.schedule.append(("replay", next(replay_iter)))
                except StopIteration:
                    pass

    def __len__(self):
        return len(self.schedule)

    def __getitem__(self, idx):
        source, data_idx = self.schedule[idx]
        if source == "domain":
            return self.domain[data_idx]
        else:
            return self.replay[data_idx]
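A quick sanity check on the interleaving logic: with replay_ratio=0.1 we expect roughly one replay sample for every ten domain samples. This self-contained re-implementation of the schedule-building step (not the class itself) makes that easy to verify:

```python
def build_schedule(n_domain, n_replay_available, replay_ratio=0.1):
    """Mirror of the interleaving logic: one replay item every ~1/ratio domain items."""
    interval = int(1 / replay_ratio) if replay_ratio > 0 else 10**9
    schedule, used = [], 0
    for i in range(n_domain):
        schedule.append(("domain", i))
        # Insert a replay sample at the end of every interval, while supply lasts
        if (i + 1) % interval == 0 and used < n_replay_available:
            schedule.append(("replay", used))
            used += 1
    return schedule

sched = build_schedule(100, 50, replay_ratio=0.1)
n_replay = sum(1 for src, _ in sched if src == "replay")
print(n_replay, len(sched))  # 10 110
```

The effective replay fraction is 10/110 ≈ 9%, slightly below the nominal ratio because replay items extend the epoch rather than displacing domain items.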

5. Elastic Weight Consolidation (EWC)

Elastic Weight Consolidation (Kirkpatrick et al., 2017) adds a regularization term to the loss function that penalizes changes to parameters that were important for previous tasks. It estimates each parameter's "importance" using the Fisher Information Matrix, which measures how much the loss changes when a parameter is perturbed. Important parameters get a stronger regularization penalty, anchoring them near their original values while allowing less important parameters to adapt freely.

The EWC loss adds a quadratic penalty term:

L_total = L_task + (λ / 2) · Σ_i F_i · (θ_i − θ_i*)²

where F_i is the Fisher information for parameter i, θ_i* is the original parameter value, and λ controls the regularization strength.

import torch
import torch.nn as nn
from copy import deepcopy

class EWCRegularizer:
    """Elastic Weight Consolidation for continual learning."""

    def __init__(self, model, dataloader, device, n_samples=200):
        self.params = {
            n: p.clone().detach()
            for n, p in model.named_parameters()
            if p.requires_grad
        }
        self.fisher = self._compute_fisher(model, dataloader, device, n_samples)

    def _compute_fisher(self, model, dataloader, device, n_samples):
        """Estimate diagonal Fisher Information Matrix."""
        fisher = {
            n: torch.zeros_like(p)
            for n, p in model.named_parameters()
            if p.requires_grad
        }

        model.eval()
        count = 0
        for batch in dataloader:
            if count >= n_samples:
                break
            input_ids = batch["input_ids"].to(device)
            outputs = model(input_ids=input_ids, labels=input_ids)
            loss = outputs.loss
            loss.backward()

            for n, p in model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    fisher[n] += p.grad.detach() ** 2

            model.zero_grad()
            count += input_ids.size(0)

        # Normalize by sample count; the absolute scale of the Fisher estimate
        # is absorbed into lambda_ewc, so this is a convention, not a requirement
        for n in fisher:
            fisher[n] /= count

        return fisher

    def penalty(self, model, lambda_ewc=1000):
        """Compute EWC regularization penalty."""
        loss = 0
        for n, p in model.named_parameters():
            if n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return (lambda_ewc / 2) * loss

# Usage in training loop:
# ewc = EWCRegularizer(model, general_dataloader, device)
# for batch in domain_dataloader:
#     loss = model(**batch).loss + ewc.penalty(model)
ⓘ Note

Computing the full Fisher Information Matrix for a billion-parameter model is prohibitively expensive. In practice, EWC uses a diagonal approximation (only the diagonal of the Fisher matrix), which can be computed efficiently with a single pass over a small dataset. The lambda_ewc hyperparameter typically ranges from 100 to 10000 and requires tuning. Too low and forgetting persists; too high and the model cannot adapt to the new domain.
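The mechanics are easier to see on a toy model. This sketch (hypothetical random data, not from the source) estimates a diagonal Fisher from squared gradients, anchors the parameters, and shows that the penalty is exactly zero at the anchor and positive once the weights drift:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
data = torch.randn(16, 4)
targets = torch.randn(16, 2)

# Diagonal Fisher estimate: squared gradients of the loss at the "old" optimum
loss = nn.functional.mse_loss(model(data), targets)
loss.backward()
fisher = {n: p.grad.detach() ** 2 for n, p in model.named_parameters()}
anchor = {n: p.clone().detach() for n, p in model.named_parameters()}
model.zero_grad()

def ewc_penalty(model, fisher, anchor, lam=1000.0):
    """Quadratic EWC penalty: (lam/2) * sum_i F_i * (theta_i - theta_i*)^2."""
    total = torch.tensor(0.0)
    for n, p in model.named_parameters():
        total = total + (fisher[n] * (p - anchor[n]) ** 2).sum()
    return (lam / 2) * total

print(ewc_penalty(model, fisher, anchor).item())  # 0.0 at the anchor
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1)  # simulate weight drift during domain training
print(ewc_penalty(model, fisher, anchor).item() > 0)  # True
```

Parameters with larger squared gradients contribute more to the penalty for the same amount of drift, which is exactly the "importance weighting" EWC is built on.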

6. Progressive Training and Curriculum Approaches

Progressive training structures the continual learning process as a sequence of carefully designed stages, each building on the previous one. Rather than training on all domain data at once, you create a curriculum that gradually shifts the distribution from general to domain-specific.

6.1 Multi-Stage Domain Adaptation Pipeline

[Figure: progressive domain adaptation pipeline.]
Stage 1 (CPT): domain corpus + 20% replay.
Stage 2 (SFT): domain instructions + general instructions.
Stage 3 (Align): domain DPO/RLHF + safety alignment.
Stage 4 (Eval): domain benchmarks + general benchmarks.
Each stage can use LoRA adapters that are merged sequentially or kept separate for modularity.
Figure 3: A multi-stage pipeline progressively adapts the model from general to domain-specific capabilities.

6.2 Curriculum Design Principles

Effective curriculum design for continual learning follows several principles. Start with data that is closest to the original pre-training distribution and gradually shift toward the target domain. Within the domain data, present easier examples first (shorter texts, simpler concepts) and progress to harder examples. This gradual transition gives the model time to adjust its internal representations without abrupt distribution shifts.

Strategy | Description | When to Use
Data mixing schedule | Start with 50% general / 50% domain, end with 10% / 90% | Large domain shift
Learning rate warmup | Start very low (1e-6), gradually increase to target LR | Preventing early destabilization
Layer-wise learning rates | Lower LR for early layers, higher for later layers | Preserving foundational features
LoRA rank scheduling | Start with low rank (4), increase to final rank (16-32) | Memory-constrained setups
Multi-epoch curriculum | Epoch 1: easy domain data; Epoch 2: hard domain data | Diverse difficulty levels in domain
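The data mixing schedule in the table can be implemented as a simple anneal of the general-data fraction over training. A sketch, assuming a linear schedule (the 50% → 10% endpoints match the table; other shapes like cosine work too):

```python
def mix_ratio(step, total_steps, start_general=0.5, end_general=0.1):
    """Linearly anneal the fraction of general (replay) data over training."""
    frac = min(step / max(total_steps, 1), 1.0)
    return start_general + frac * (end_general - start_general)

for step in (0, 500, 1000):
    print(f"step {step}: {mix_ratio(step, 1000):.0%} general")
# step 0: 50% general
# step 500: 30% general
# step 1000: 10% general
```

At each step, a data loader can then draw a general-domain sample with probability mix_ratio(step, total_steps) and a target-domain sample otherwise.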
◆ Key Insight

The most practical approach to continual domain adaptation for most teams is: use LoRA (not full fine-tuning) for each adaptation stage, keep the base model frozen, and maintain a library of composable adapters. This sidesteps catastrophic forgetting entirely because the base weights never change. You can then merge adapters using the techniques from Section 15.2 to combine capabilities, or swap them dynamically at serving time.
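An adapter library can be as simple as one LoRA config per domain over a shared frozen base. A configuration sketch using the peft package (the adapter names, ranks, and target modules below are illustrative choices, not prescriptions):

```python
from peft import LoraConfig

# One LoRA config per domain adapter; the frozen base model is shared by all
ADAPTER_CONFIGS = {
    "medical": LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05,
                          target_modules=["q_proj", "v_proj"],
                          task_type="CAUSAL_LM"),
    "finance": LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
                          target_modules=["q_proj", "v_proj"],
                          task_type="CAUSAL_LM"),
}

# At serving time (sketch): load the shared base once, then attach/switch adapters:
#   model.load_adapter("adapters/medical", adapter_name="medical")
#   model.set_adapter("medical")
```

Because each adapter is trained against the same frozen base, adapters never interfere with one another during training, and the base model's general capabilities are preserved by construction.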

Section 15.3 Quiz

1. What is catastrophic forgetting, and why is it particularly problematic for continual pre-training?

Show Answer
Catastrophic forgetting occurs when gradient updates that optimize for new data push model weights away from regions that encode prior knowledge, causing the model to lose previously learned capabilities. It is particularly problematic for continual pre-training because CPT modifies all model parameters to absorb new domain knowledge, and the extended training on domain-specific data can substantially degrade general capabilities like instruction following, coding, and general knowledge.

2. Why is the learning rate for continual pre-training typically much lower than for initial pre-training?

Show Answer
A lower learning rate (typically 1e-5 to 5e-5, versus 1e-4 to 3e-4 for initial pre-training) makes smaller parameter updates per step, reducing the risk of destabilizing the pretrained representations. Large learning rates cause aggressive weight changes that quickly overwrite previously learned features. The lower rate allows the model to gradually incorporate new domain knowledge while preserving its existing capabilities.

3. When should you extend the model's vocabulary, and what must you do afterward?

Show Answer
Extend the vocabulary when domain-specific terms are consistently tokenized into 4+ sub-word tokens, when specialized notation or non-Latin scripts are poorly represented, or when high-frequency domain abbreviations would benefit from single-token representation. After extending the vocabulary, you must continue training the model for the new token embeddings to become useful. The initialized embeddings are only approximations, and the model needs gradient updates on data containing the new tokens to learn proper representations.

4. How does Elastic Weight Consolidation (EWC) prevent forgetting, and what is its main limitation?

Show Answer
EWC adds a regularization penalty that discourages changes to parameters that were important for previous tasks. It estimates importance using the diagonal of the Fisher Information Matrix. Parameters with high Fisher information (meaning small changes to them cause large changes in loss) receive strong regularization. The main limitation is that the diagonal Fisher approximation is crude for billion-parameter models, the lambda hyperparameter requires careful tuning, and computing even the diagonal Fisher requires a pass over representative data from previous tasks.

5. Why do many practitioners prefer LoRA-based adaptation over full-parameter continual pre-training for domain adaptation?

Show Answer
LoRA-based adaptation naturally mitigates catastrophic forgetting because the base model weights remain completely frozen. Only the small adapter weights are trained, which limits the model's ability to stray from its pretrained behavior. Additionally, LoRA adapters can be composed, swapped, or merged, providing a modular approach to domain adaptation. This avoids the need for replay data, EWC regularization, or other forgetting mitigation strategies, while requiring far less compute and memory than full-parameter continual pre-training.

Key Takeaways