Module 15 · Section 15.1

Knowledge Distillation for LLMs

Transferring capabilities from large teacher models into small, deployable students
★ Big Picture

Knowledge distillation is the art of making small models behave like large ones. A 70B-parameter teacher model contains vast knowledge but is expensive to serve. By training a smaller student model to mimic the teacher's output distribution (not just its final answers), the student can inherit much of the teacher's capability at a fraction of the inference cost. This technique has produced some of the most remarkable results in the LLM space: Microsoft's Phi-3 models, trained largely on GPT-4-generated synthetic data, show that a 3.8B model can compete with models several times its size. DeepSeek distilled its R1 reasoning model into compact variants that retain strong chain-of-thought abilities.

1. Classical Distillation Framework

1.1 The Teacher-Student Paradigm

Knowledge distillation (Hinton et al., 2015) trains a smaller "student" model to match the output probability distribution of a larger "teacher" model, rather than training the student solely on hard ground-truth labels. The key insight is that the teacher's probability distribution over all possible tokens contains far richer information than a single correct answer. When the teacher assigns 0.6 probability to "happy," 0.2 to "glad," and 0.1 to "joyful," these "soft" probabilities encode semantic relationships that hard labels cannot convey.

[Figure: a large teacher (70B) produces soft targets (T > 1) that, together with ground-truth labels, train a small student (7B). Distillation loss: L = α · KL(teacher_soft ‖ student_soft) + (1 − α) · CE(labels, student_hard), i.e. a T²-scaled soft-target loss plus the standard cross-entropy hard-label loss.]
Figure 1: The student learns from both the teacher's soft probability distribution and the ground-truth hard labels.

1.2 Temperature and Soft Targets

The temperature parameter T controls how "soft" the teacher's output distribution becomes. At T=1 (normal softmax), the teacher's distribution is peaked on the most likely token. As T increases, the distribution becomes smoother, revealing the relative probabilities of less likely tokens. This "dark knowledge" in the non-top predictions encodes the teacher's understanding of semantic similarity and uncertainty.

The softmax with temperature is computed as:

p_i = exp(z_i / T) / Σ_j exp(z_j / T)

where z_i are the logits (pre-softmax values). Common temperature values range from 1.5 to 4.0. Higher temperatures expose more of the teacher's knowledge but also introduce more noise. The distillation loss is scaled by T² to compensate for the gradient magnitude reduction caused by softening.
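A quick numeric sketch (with made-up toy logits for four candidate tokens) shows how raising T redistributes probability mass from the top token to its semantically related alternatives:

```python
import torch
import torch.nn.functional as F

# Toy logits for four candidate tokens, e.g. "happy", "glad", "joyful", "table"
logits = torch.tensor([4.0, 2.9, 2.2, -1.0])

for T in (1.0, 2.0, 4.0):
    probs = F.softmax(logits / T, dim=-1)
    # At T=1 the top token dominates; at T=4 the mass spreads out,
    # revealing the relative ordering of the less likely tokens.
    print(f"T={T}: {[round(p, 3) for p in probs.tolist()]}")
```

At T=1 the distribution is close to one-hot; at T=4 the "dark knowledge" in the lower-ranked tokens becomes visible to the student.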

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Combined distillation and task loss for LLM training."""

    def __init__(self, temperature=2.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha   # Weight for distillation vs hard label loss
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_logits, teacher_logits, labels):
        # Soft targets: soften both distributions with temperature
        T = self.temperature
        student_soft = F.log_softmax(student_logits / T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / T, dim=-1)

        # KL divergence loss (scaled by T^2); flatten so "batchmean"
        # averages per token instead of summing across the sequence
        distill_loss = self.kl_loss(
            student_soft.view(-1, student_soft.size(-1)),
            teacher_soft.view(-1, teacher_soft.size(-1)),
        ) * (T ** 2)

        # Hard label cross-entropy loss
        hard_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            ignore_index=-100,
        )

        # Combined loss
        return self.alpha * distill_loss + (1 - self.alpha) * hard_loss
◆ Key Insight

The temperature parameter is critical. At T=1, the teacher's distribution is so peaked that it provides little more information than a hard label. At T=4, the distribution is smooth enough to reveal semantic relationships between tokens. However, too high a temperature washes out the signal entirely. Start with T=2.0 and tune based on validation performance. The T² scaling factor in the loss ensures consistent gradient magnitudes regardless of temperature.

2. White-Box vs. Black-Box Distillation

The distillation approach depends on what access you have to the teacher model. White-box distillation requires access to the teacher's internal logits; black-box distillation works only with the teacher's text outputs.

| Aspect | White-Box Distillation | Black-Box Distillation |
|---|---|---|
| Teacher access | Full model weights and logits | API outputs (text only) |
| Loss signal | KL divergence on full distribution | Cross-entropy on generated text |
| Information richness | Very high (full probability distribution) | Lower (only top-1 output) |
| Typical teachers | Open-weight models (Llama, Mistral) | API models (GPT-4, Claude) |
| Quality ceiling | Higher (more teacher knowledge transferred) | Lower (limited to surface behavior) |
| Scalability | Limited by GPU memory for teacher | Limited by API cost and rate limits |

2.1 White-Box Distillation

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader

def white_box_distillation(
    teacher_model_id: str,
    student_model_id: str,
    train_dataset,
    temperature: float = 2.0,
    alpha: float = 0.5,
    epochs: int = 3,
    lr: float = 2e-5,
):
    """Train student to match teacher logit distribution."""

    # Load teacher (frozen, in eval mode)
    teacher = AutoModelForCausalLM.from_pretrained(
        teacher_model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    teacher.eval()
    for param in teacher.parameters():
        param.requires_grad = False

    # Load student (trainable)
    student = AutoModelForCausalLM.from_pretrained(
        student_model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )

    loss_fn = DistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr)
    dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    for epoch in range(epochs):
        student.train()
        total_loss = 0
        for batch in dataloader:
            input_ids = batch["input_ids"].to(student.device)
            labels = batch["labels"].to(student.device)

            # Get teacher logits (no gradient)
            with torch.no_grad():
                teacher_out = teacher(input_ids=input_ids)
                teacher_logits = teacher_out.logits

            # Get student logits
            student_out = student(input_ids=input_ids)
            student_logits = student_out.logits

            # Compute combined loss
            loss = loss_fn(student_logits, teacher_logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}: avg_loss = {total_loss/len(dataloader):.4f}")

    return student

2.2 Black-Box Distillation

When the teacher is an API model (GPT-4, Claude, Gemini), you cannot access logits. Instead, you generate a large dataset of high-quality (input, output) pairs from the teacher, then fine-tune the student on these pairs using standard supervised training. The quality of the distilled student depends heavily on the diversity and quality of the generated training data.

import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI()

async def generate_distillation_data(
    prompts: list[str],
    teacher_model: str = "gpt-4o",
    system_prompt: str = "You are a helpful assistant.",
    max_concurrent: int = 10,
) -> list[dict]:
    """Generate training data from API teacher for black-box distillation."""

    semaphore = asyncio.Semaphore(max_concurrent)

    async def call_teacher(prompt):
        async with semaphore:
            response = await client.chat.completions.create(
                model=teacher_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                max_tokens=2048,
            )
            return {
                "instruction": prompt,
                "response": response.choices[0].message.content,
                "teacher": teacher_model,
            }

    tasks = [call_teacher(p) for p in prompts]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Filter out errors
    valid = [r for r in results if isinstance(r, dict)]
    print(f"Generated {len(valid)}/{len(prompts)} training examples")
    return valid

# Generate dataset
# training_data = asyncio.run(generate_distillation_data(prompts))
# Then fine-tune student model on this data using standard SFT
⚠ Warning

Black-box distillation from proprietary API models raises important licensing considerations. Most API providers (OpenAI, Anthropic, Google) have terms of service that restrict using their outputs to train competing models. Always review the provider's usage policy before conducting distillation. Open-weight models (Llama, Mistral, Qwen) generally allow distillation, but check their specific licenses. The Llama 3 license, for example, allows derivative works but has specific restrictions for very large deployments.

3. Case Studies in LLM Distillation

3.1 Orca: Learning from Complex Explanations

Microsoft's Orca (2023) demonstrated that small models can dramatically improve by learning not just the teacher's answers but its reasoning process. Orca trained a 13B student on millions of examples from GPT-4 that included detailed chain-of-thought explanations, step-by-step reasoning, and self-correction. The key innovations were: using system prompts to elicit rich explanations from the teacher, curating diverse and challenging prompts, and training the student on the full reasoning trace rather than just the final answer.
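To make the system-prompt idea concrete, here is an illustrative sketch (the prompts are paraphrases in the spirit of Orca, not the paper's exact wording) of assembling a teacher request that elicits a full reasoning trace:

```python
# Illustrative, Orca-style system prompts: each instructs the teacher
# to expose its reasoning process, not just the final answer.
REASONING_SYSTEM_PROMPTS = [
    "You are an AI assistant. Think step by step and justify each step "
    "of your reasoning before giving the final answer.",
    "You are a helpful assistant. Explain your answer as if teaching a "
    "student, showing intermediate work and checking it for errors.",
]

def build_teacher_request(prompt: str, variant: int = 0) -> list[dict]:
    """Assemble a chat request that elicits a reasoning trace from the teacher."""
    return [
        {"role": "system", "content": REASONING_SYSTEM_PROMPTS[variant]},
        {"role": "user", "content": prompt},
    ]

messages = build_teacher_request("What is 17 * 24? Show your work.")
```

Rotating over several such system prompts during data generation yields traces with varied explanation styles, which Orca found helps the student generalize.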

3.2 Phi Series: Textbook-Quality Data

Microsoft's Phi models (Phi-1, Phi-1.5, Phi-2, Phi-3) showed that data quality matters more than model size. Rather than distilling on conversational data, the Phi team used GPT-4 to generate "textbook-quality" synthetic training data: carefully structured explanations, worked examples, and exercises across diverse topics. Phi-3 (3.8B parameters) achieves performance competitive with much larger models on reasoning benchmarks, demonstrating that a small model trained on exceptional data can outperform a larger model trained on mediocre data.

3.3 Distilled DeepSeek-R1

DeepSeek distilled their large R1 reasoning model (671B MoE) into a family of smaller dense models (1.5B, 7B, 8B, 14B, 32B, 70B). The distillation process used 800K samples of the R1 teacher's chain-of-thought reasoning traces. The distilled models retain strong mathematical and coding reasoning abilities, with the 32B distilled variant outperforming many larger models on math benchmarks. This demonstrates that reasoning capabilities, which were previously thought to require enormous scale, can be effectively compressed through distillation.

[Figure: distillation case study results. Orca (13B): matched ChatGPT on complex reasoning benchmarks. Phi-3 (3.8B): competitive with Llama-3 8B on MMLU and reasoning tasks. DeepSeek-R1 (32B): strong math reasoning from a 671B MoE teacher. Common design principles: (1) chain-of-thought: train on reasoning traces, not just final answers; (2) data quality: curated, diverse, challenging prompts yield better students; (3) data scale: hundreds of thousands to millions of teacher examples; (4) system prompts: instruct the teacher to explain, reason, and show work.]
Figure 2: Successful distillation projects share common principles: rich reasoning traces, high-quality data, and diverse prompts.
◆ Key Insight

The single most impactful lesson from distillation research is: distill the reasoning process, not just the answer. When a teacher model generates chain-of-thought explanations, step-by-step solutions, and self-corrections, the student learns much more effectively than from answer-only training data. This is why models like Orca and distilled DeepSeek-R1 dramatically outperform naive distillation approaches that only collect the teacher's final outputs.

4. Small-but-Capable Models

Distillation has enabled a new class of small models that achieve remarkable performance relative to their size. These models demonstrate that the right combination of architecture, training data, and distillation can produce efficient models for deployment on edge devices, mobile platforms, or high-throughput serving scenarios.

| Model Family | Sizes | Key Technique | Notable Capability |
|---|---|---|---|
| Phi (Microsoft) | 1.3B, 2.7B, 3.8B, 14B | Textbook-quality synthetic data | Strong reasoning for size |
| Gemma (Google) | 2B, 7B, 9B, 27B | Distilled from Gemini | Multilingual, coding |
| SmolLM (HF) | 135M, 360M, 1.7B | Curated web + synthetic data | Ultra-small deployment |
| Qwen2.5 (Alibaba) | 0.5B, 1.5B, 3B, 7B+ | Multi-stage distillation | Math, code, multilingual |
| Llama 3.2 (Meta) | 1B, 3B | Pruning + distillation | On-device, mobile |

5. Practical Distillation Pipeline

Here is a complete pipeline that combines data generation from an API teacher with student training, demonstrating the end-to-end black-box distillation workflow.

from datasets import Dataset
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

# Step 1: Prepare distillation dataset
# (Assume we have generated data from teacher API)
distillation_data = [
    {
        "messages": [
            {"role": "user", "content": "Explain gradient descent."},
            {"role": "assistant", "content": "Gradient descent is an optimization..."},
        ]
    },
    # ... thousands more examples
]

dataset = Dataset.from_list(distillation_data)

# Step 2: Configure student with LoRA (parameter-efficient)
student_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(student_id)

lora_config = LoraConfig(
    r=32,              # Higher rank for distillation
    lora_alpha=64,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

# Step 3: Train student on teacher-generated data
sft_config = SFTConfig(
    output_dir="./distilled-student",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,
    max_seq_length=4096,
    logging_steps=10,
    save_strategy="epoch",
)

trainer = SFTTrainer(
    model=student_id,
    args=sft_config,
    train_dataset=dataset,
    processing_class=tokenizer,
    peft_config=lora_config,
)

trainer.train()
print("Distillation complete. Evaluate student against teacher.")
ⓘ Note

For distillation, use a higher LoRA rank (32-64) than you would for standard fine-tuning. The student needs more capacity to absorb the teacher's knowledge. Also consider training for more epochs (3-5) with a larger and more diverse dataset. Distillation benefits from data scale more than standard fine-tuning because each example conveys information about the teacher's behavior across many dimensions.

6. Speculative Distillation

Speculative decoding uses a small "draft" model to propose multiple tokens at once, which a larger model then verifies in a single forward pass. Speculative distillation trains the draft model specifically to mimic the larger model's token distribution, improving the acceptance rate (how often the large model agrees with the draft) and thus the overall throughput. This technique turns distillation into a serving-time optimization rather than just a training-time technique.
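The verify-and-accept loop can be sketched with toy integer "models" (the real technique scores all draft positions in one batched forward pass, and speculative *sampling* uses a probabilistic accept/reject rule to preserve the target's distribution; this greedy sketch only illustrates the control flow):

```python
from typing import Callable

def speculative_step(
    prefix: list[int],
    draft_next: Callable[[list[int]], int],   # small distilled draft model (toy)
    target_next: Callable[[list[int]], int],  # large target model (toy)
    k: int = 4,
) -> list[int]:
    """One greedy speculative-decoding step: draft proposes k tokens,
    target verifies them and keeps the longest agreeing prefix."""
    # 1. Draft proposes k tokens autoregressively (cheap).
    proposal = list(prefix)
    for _ in range(k):
        proposal.append(draft_next(proposal))

    # 2. Target verifies; in a real system all k positions are scored in a
    #    single forward pass. The target's token is always kept, so the
    #    output matches target-only greedy decoding exactly.
    out = list(prefix)
    for i in range(len(prefix), len(proposal)):
        t = target_next(out)
        out.append(t)
        if t != proposal[i]:
            break  # first disagreement ends the step
    return out

# Toy "models" over integer tokens: both count upward, but the draft
# resets to 0 after token 5 while the target keeps counting.
target_next = lambda seq: (seq[-1] + 1) % 10
draft_next = lambda seq: 0 if seq[-1] == 5 else (seq[-1] + 1) % 10

print(speculative_step([1], draft_next, target_next, k=4))  # → [1, 2, 3, 4, 5]
```

Here the draft's four proposals are all accepted, so one verification pass advances the sequence by four tokens; the better the draft is distilled, the higher this acceptance rate and the larger the speedup.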

[Figure: speculative decoding with a distilled draft model. The draft model (1B, distilled) proposes 4 tokens t1..t4 quickly; the target model (70B) verifies them all in one forward pass; with 3 tokens accepted per pass, throughput improves roughly 3x.]
Figure 3: A distilled draft model proposes tokens that the target model verifies in parallel, multiplying throughput.

Section 15.1 Quiz

1. Why do soft targets (teacher probabilities with temperature) provide more useful training signal than hard labels?

Show Answer
Soft targets encode the teacher's uncertainty and the relative similarity between possible outputs. When the teacher assigns 0.6 probability to "happy" and 0.2 to "glad," this reveals that these words are semantically related. Hard labels (one-hot vectors) provide only the correct answer and no information about the relationships between alternatives. This "dark knowledge" in the non-top predictions helps the student learn richer representations.

2. What is the difference between white-box and black-box distillation, and which produces higher-quality students?

Show Answer
White-box distillation has access to the teacher's full logit distribution and uses KL divergence loss. Black-box distillation only has access to the teacher's text outputs and uses standard cross-entropy loss. White-box generally produces higher-quality students because the full probability distribution contains far more information than a single output token. However, black-box is the only option when the teacher is an API model without logit access.

3. What was the key insight from Microsoft's Orca model that improved distillation quality?

Show Answer
Orca demonstrated that training the student on the teacher's full reasoning process (chain-of-thought explanations, step-by-step solutions, self-corrections) is far more effective than training on just the final answers. By using system prompts to elicit detailed explanations from GPT-4, Orca created training data that taught the student how to reason, not just what to answer.

4. Why does the distillation loss include a T-squared scaling factor?

Show Answer
When temperature T is applied to the softmax, it reduces the magnitude of the gradients by a factor of T². Multiplying the KL divergence loss by T² compensates for this reduction, ensuring that the gradient magnitudes are consistent regardless of the temperature value. Without this scaling, higher temperatures would produce vanishingly small gradients and extremely slow training.

5. How does speculative distillation differ from standard distillation in terms of its goal?

Show Answer
Standard distillation aims to create a small model that replaces the large model entirely for inference. Speculative distillation creates a small "draft" model that works alongside the large model during inference, proposing token candidates that the large model verifies in parallel. The goal is not replacement but acceleration: by training the draft model to closely match the target's distribution, more proposed tokens are accepted per verification step, increasing throughput by 2-4x while maintaining the exact output quality of the large model.

Key Takeaways