Module 13 · Section 13.7

Adapting Models for Long Text

Extending context windows with RoPE scaling and position interpolation, continued pre-training for long contexts, chunking strategies, and the lost-in-the-middle phenomenon
★ Big Picture

Many real-world documents are longer than models were trained to handle. Legal contracts, research papers, codebases, and book manuscripts routinely exceed the 4K or 8K context windows that many models were originally trained with. Simply passing a longer sequence to a model trained on shorter sequences typically causes severe quality degradation, because the positional encodings must extrapolate into regions the model has never seen. This section covers the techniques for extending a model's effective context length: mathematical adjustments to positional encodings, continued pre-training on long documents, and practical chunking strategies for when extension is not enough.

1. The Long Context Challenge

Transformer models encode position information through positional embeddings or positional encodings. When a model trained with a maximum sequence length of 4,096 tokens receives a sequence of 8,192 tokens, the positions beyond 4,096 are "out of distribution." The model has never learned what those position values mean, leading to degraded attention patterns and poor generation quality.

1.1 Why Models Fail on Long Sequences

The failure mode depends on the type of positional encoding. Models using absolute positional embeddings (original BERT, GPT-2) have a hard limit: positions beyond the embedding table size simply do not exist. Models using Rotary Position Embeddings (RoPE), which are standard in modern LLMs like Llama, Mistral, and Qwen, can technically process longer sequences, but the rotation angles for unseen positions are extrapolated, causing attention scores to become increasingly noisy.

[Figure: Model Quality vs. Sequence Length (Without Context Extension). X-axis: sequence length (1K, 2K, 4K, 8K, 16K tokens); y-axis: perplexity (lower is better). Beyond the 4K training window, positions are out of distribution and perplexity explodes; with context extension, perplexity stays low.]
Figure 13.15: Without context extension, perplexity degrades sharply beyond the training window. Context extension techniques maintain quality at longer lengths.
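To make the extrapolation problem concrete, here is a small pure-Python sketch of the RoPE frequency schedule. It assumes Llama-2-style values (head dimension 128, base 10000, a 4,096-token training window); the names `rope_frequencies` and `max_angle` are illustrative, not a library API. RoPE rotates each pair of dimensions i by the angle position × θ_i, with θ_i = base^(-2i/d), so the slowest-rotating pair only ever sees a narrow band of angles during training.

```python
import math

# Assumed Llama-2-style settings (illustrative, not read from any config file)
HEAD_DIM = 128
BASE = 10000.0
TRAIN_LEN = 4096


def rope_frequencies(head_dim: int = HEAD_DIM, base: float = BASE) -> list[float]:
    """Rotation frequency theta_i = base ** (-2i / head_dim) for each dimension pair."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def max_angle(seq_len: int, theta: float) -> float:
    """Largest rotation angle (radians) a pair reaches at positions 0..seq_len-1."""
    return (seq_len - 1) * theta


thetas = rope_frequencies()
slowest = thetas[-1]  # lowest-frequency pair: encodes coarse, long-range position

seen_in_training = max_angle(TRAIN_LEN, slowest)
needed_at_16k = max_angle(4 * TRAIN_LEN, slowest)

print(f"fastest pair frequency: {thetas[0]:.1f} rad/token")
print(f"slowest pair, max angle in training: {seen_in_training:.3f} rad "
      f"({seen_in_training / (2 * math.pi):.3f} turns)")
print(f"slowest pair, max angle at 16K tokens: {needed_at_16k:.3f} rad")
```

The slowest pair covers well under a full turn within 4K positions, so the roughly four times larger angles it reaches at 16K are rotations the model never observed during training.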

2. Context Extension Techniques

Several techniques have been developed to extend a model's effective context length without retraining from scratch. These techniques modify the positional encoding scheme so that longer sequences map to position values the model has already learned to handle.

2.1 RoPE Scaling (Linear Interpolation)

The simplest context extension method is linear scaling, also called position interpolation. Instead of using raw position indices (0, 1, 2, ..., 8191) for an 8K sequence, you scale them down to fit within the original training range: (0, 0.5, 1, 1.5, ..., 4095.5). This ensures all position values fall within the range the model was trained on.
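The arithmetic in the paragraph above fits in a few lines. This is a minimal sketch (the function names `interpolate_positions` and `rope_angle` are illustrative): dividing every index by the scaling factor squeezes 8,192 positions into the original 4,096-position range. Because RoPE's angle is position × θ_i, dividing the position by the factor is the same as dividing every frequency θ_i by it, which is why linear scaling treats all frequencies uniformly.

```python
import math


def interpolate_positions(seq_len: int, factor: float) -> list[float]:
    """Linear position interpolation: divide each index by the scaling factor."""
    return [pos / factor for pos in range(seq_len)]


def rope_angle(position: float, theta: float, factor: float = 1.0) -> float:
    """Rotation angle for one dimension pair after linear scaling."""
    return (position / factor) * theta


positions = interpolate_positions(8192, factor=2.0)
print(positions[:4])   # [0.0, 0.5, 1.0, 1.5]
print(positions[-1])   # 4095.5

# Every interpolated position lies inside the original 4K training range
assert max(positions) < 4096

# Scaling the position by 1/factor equals scaling the frequency by 1/factor
assert math.isclose(rope_angle(6000, 0.01, factor=2.0), rope_angle(6000, 0.01 / 2.0))
```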

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# Method 1: Linear scaling (Position Interpolation)
# Extend a 4K-context model to handle 16K sequences.
# Llama 2 was pretrained with a 4,096-token window (Llama 3.1 already supports 128K).
model_name = "meta-llama/Llama-2-7b-hf"

config = AutoConfig.from_pretrained(model_name)

# Set the RoPE scaling configuration
config.rope_scaling = {
    "type": "linear",
    "factor": 4.0,  # Extend 4x: 4K -> 16K
}
# Update max position embeddings to match
config.max_position_embeddings = 16384  # 4096 * 4

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype="auto",
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Max positions: {model.config.max_position_embeddings}")
print(f"RoPE scaling: {model.config.rope_scaling}")
Max positions: 16384
RoPE scaling: {'type': 'linear', 'factor': 4.0}

2.2 Dynamic NTK Scaling

Dynamic NTK (Neural Tangent Kernel) scaling is a more sophisticated approach. Rather than shrinking the position indices, NTK-aware scaling enlarges the RoPE base, which distributes the interpolation unevenly across frequencies: high-frequency components (which encode fine-grained, local position distinctions) are left almost unchanged, while low-frequency components (which encode coarse, long-range position information) are compressed the most. This preserves local attention patterns better than uniform scaling. The "dynamic" variant recomputes the scaling from the actual sequence length at inference time, so inputs that fit within the original training window are processed without any scaling.

# Method 2: Dynamic NTK scaling
config = AutoConfig.from_pretrained(model_name)
config.rope_scaling = {
    "type": "dynamic",
    "factor": 4.0,  # Extend 4x
}
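# Leave max_position_embeddings at the model's original pretraining window:
# Hugging Face's dynamic scaling only activates once the input exceeds this
# value, so raising it to 16384 would disable scaling for 4K-16K inputs.

# Dynamic NTK computes the scaling based on actual sequence length
# at inference time, so it adapts to varying input lengths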

model_dynamic = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype="auto",
    device_map="auto",
)

print(f"RoPE scaling: {model_dynamic.config.rope_scaling}")
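A pure-Python sketch shows how NTK-aware scaling treats different frequencies. It assumes the NTK-aware rule of enlarging the RoPE base to base × α^(d/(d−2)) and, for the dynamic part, a simplified α = max(1, seq_len / original_len); library implementations, including Hugging Face's, derive the scale from the sequence length with their own formula. The function names and the Llama-2-style constants are illustrative assumptions.

```python
import math


def ntk_frequencies(head_dim: int, base: float, alpha: float) -> list[float]:
    """NTK-aware RoPE: enlarge the base instead of shrinking the positions."""
    new_base = base * alpha ** (head_dim / (head_dim - 2))
    return [new_base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def dynamic_alpha(seq_len: int, original_len: int) -> float:
    """Simplified dynamic rule: no scaling until the input outgrows the original window."""
    return max(1.0, seq_len / original_len)


HEAD_DIM, BASE, ORIGINAL_LEN = 128, 10000.0, 4096  # assumed Llama-2-style values
original = ntk_frequencies(HEAD_DIM, BASE, alpha=1.0)

for seq_len in (2048, 16384):
    alpha = dynamic_alpha(seq_len, ORIGINAL_LEN)
    scaled = ntk_frequencies(HEAD_DIM, BASE, alpha)
    high_ratio = scaled[0] / original[0]    # fastest pair (local detail)
    low_ratio = scaled[-1] / original[-1]   # slowest pair (long-range position)
    print(f"seq_len={seq_len:>5}  alpha={alpha:.1f}  "
          f"high-freq ratio={high_ratio:.3f}  low-freq ratio={low_ratio:.3f}")
```

With this base change the ratio for pair i is α^(-2i/(d-2)): the fastest pair (i = 0) keeps its frequency exactly, and the slowest pair (i = d/2 - 1) is divided by exactly α, matching what linear interpolation would do for that pair alone. A 2K input is not scaled at all.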

2.3 YaRN (Yet another RoPE extensioN)

YaRN combines NTK-aware scaling with an attention temperature adjustment and a ramp function that smoothly transitions between unscaled high frequencies and fully interpolated low frequencies. It often gives the best results among methods usable without fine-tuning, improves further with a short fine-tune, and has been adopted by several open model families for context extension.

# Method 3: YaRN scaling
config = AutoConfig.from_pretrained(model_name)
config.rope_scaling = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 4096,
    # YaRN-specific parameters
    "attention_factor": None,  # Auto-computed
    "beta_fast": 32.0,         # Dims rotating >32 times in the original window stay unscaled
    "beta_slow": 1.0,          # Dims rotating <1 time are fully interpolated
}
config.max_position_embeddings = 16384

model_yarn = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype="auto",
    device_map="auto",
)
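The frequency ramp can be sketched per dimension pair, following the formulation in the YaRN paper: count how many full rotations r a pair completes within the original window, leave pairs with r above `beta_fast` unscaled, divide pairs with r below `beta_slow` by the scaling factor (as linear interpolation would), and blend linearly in between. The temperature term is the paper's 0.1·ln(s) + 1 multiplier applied to both queries and keys. This is a simplified illustration with hypothetical function names; library implementations may differ in details such as how the ramp boundaries map to dimension indices.

```python
import math


def yarn_frequencies(head_dim: int, base: float, scale: float, original_len: int,
                     beta_fast: float = 32.0, beta_slow: float = 1.0) -> list[float]:
    """Blend interpolated and original RoPE frequencies with a per-pair ramp."""
    scaled = []
    for i in range(head_dim // 2):
        theta = base ** (-2 * i / head_dim)
        wavelength = 2 * math.pi / theta
        rotations = original_len / wavelength  # turns completed in the original window
        if rotations < beta_slow:
            gamma = 0.0   # long wavelength: fully interpolate
        elif rotations > beta_fast:
            gamma = 1.0   # short wavelength: keep as is
        else:
            gamma = (rotations - beta_slow) / (beta_fast - beta_slow)
        scaled.append((1 - gamma) * theta / scale + gamma * theta)
    return scaled


def yarn_attention_scale(scale: float) -> float:
    """Multiplier applied to both queries and keys (YaRN's temperature term)."""
    return 0.1 * math.log(scale) + 1.0


freqs = yarn_frequencies(128, 10000.0, scale=4.0, original_len=4096)
original = [10000.0 ** (-2 * i / 128) for i in range(64)]
print(f"fastest pair unchanged: {math.isclose(freqs[0], original[0])}")
print(f"slowest pair ratio:     {freqs[-1] / original[-1]:.3f}")
print(f"attention multiplier:   {yarn_attention_scale(4.0):.4f}")
```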
Method | How It Works | Quality | Fine-Tuning Needed
Linear scaling | Uniformly compress positions into the training range | Good for 2x to 4x extension | Recommended (a short fine-tune, on the order of 1,000 steps)
Dynamic NTK | Enlarge the RoPE base: low frequencies interpolated, high frequencies nearly unchanged; scale set by current length | Better than linear at higher ratios | Optional, improves with tuning
YaRN | NTK-style frequency ramp + attention temperature | Often the best scaling-only method | Optional, best results with tuning
LongRoPE | Searched, non-uniform per-dimension scaling factors | Strong at very long lengths (up to 2M tokens reported) | Required (progressive extension with fine-tuning)
🔑 Key Insight

Scaling alone gives you 2x to 4x; beyond that, you need fine-tuning. Position scaling methods work well for modest context extensions (4K to 16K, or 8K to 32K). For larger extensions (4K to 128K), you typically need to combine scaling with continued pre-training on long documents. The scaling fixes the positional encoding issue, and the fine-tuning teaches the model to actually attend across long distances effectively.

3. Continued Pre-Training for Long Context

For significant context extension (beyond 4x the original training length), the most reliable approach is continued pre-training on long documents with the new positional encoding configuration. This teaches the model to form meaningful attention patterns across the extended range.

3.1 LongLoRA: Efficient Long-Context Fine-Tuning

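LongLoRA makes long-context fine-tuning affordable by combining LoRA (parameter-efficient fine-tuning) with shifted sparse attention (S²-Attn) during training. S²-Attn splits the sequence into groups and computes attention only within each group; in half of the attention heads the groups are shifted by half a group size, so information flows across group boundaries. At inference time, standard full attention is used. Because each token attends only within its group during training, the attention cost grows linearly rather than quadratically with sequence length. LongLoRA also makes the embedding and normalization layers trainable, which its authors found necessary for LoRA-based long-context adaptation to work well.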
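The grouping pattern behind S²-Attn can be sketched in a few lines of pure Python. This is a hypothetical helper that only lists which token positions share an attention group (causal masking inside each group is not shown, and no attention is computed). In the shifted half of the heads the sequence is rolled by half a group, so each shifted group straddles a boundary between two unshifted groups, and the last shifted group wraps around to the start of the sequence.

```python
def attention_groups(seq_len: int, group_size: int, shifted: bool = False) -> list[list[int]]:
    """List the token positions that attend to each other within one S2-Attn group."""
    if seq_len % group_size != 0:
        raise ValueError("seq_len must be a multiple of group_size")
    # Shifted heads see the sequence rolled by half a group
    offset = group_size // 2 if shifted else 0
    order = [(i + offset) % seq_len for i in range(seq_len)]
    return [order[s:s + group_size] for s in range(0, seq_len, group_size)]


plain = attention_groups(8, group_size=4)
shifted = attention_groups(8, group_size=4, shifted=True)
print(plain)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(shifted)  # [[2, 3, 4, 5], [6, 7, 0, 1]]
```

Tokens 3 and 4 fall in different groups for the unshifted heads but share a group for the shifted heads, which is how information crosses group boundaries during training.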

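# LongLoRA-style training setup (simplified): extended RoPE + LoRA fine-tuning.
# S²-Attn and LongLoRA's trainable embedding/norm layers are not included here.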
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig

model_name = "meta-llama/Llama-2-7b-hf"  # 4K pretraining window

# Step 1: Configure extended context
config = AutoConfig.from_pretrained(model_name)
config.rope_scaling = {"type": "linear", "factor": 8.0}
config.max_position_embeddings = 32768  # 4K -> 32K

# Step 2: Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)

# Step 3: Apply LoRA for parameter-efficient training
lora_config = LoraConfig(
    r=32,                       # LoRA rank (a tunable choice)
    lora_alpha=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # Attention layers
        "gate_proj", "up_proj", "down_proj",       # MLP layers
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # Prints the counts itself (returns None)

# Step 4: Prepare long-document training data
# Use documents that are genuinely long (books, papers, code repos)
# Each example should fill or exceed the target context length

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Step 5: Configure training for long sequences
sft_config = SFTConfig(
    output_dir="./checkpoints/long-context-llama",
    max_seq_length=32768,             # Full target length (named max_length in newer TRL)
    per_device_train_batch_size=1,    # Long sequences need BS=1
    gradient_accumulation_steps=16,   # Effective batch = 16
    num_train_epochs=1,               # 1 epoch is usually enough
    learning_rate=2e-5,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=5,
    save_steps=50,
    packing=True,
)

# trainer = SFTTrainer(
#     model=model,
#     args=sft_config,
#     train_dataset=long_document_dataset,
#     processing_class=tokenizer,
# )
# trainer.train()
⚠ Warning

Memory requirements scale quadratically without Flash Attention. Standard attention materializes an n x n score matrix for every head, so its memory grows as O(n^2) in the sequence length n: the attention matrix for a 32K sequence is 16x larger than for an 8K sequence. Flash Attention 2 computes attention in blocks without materializing the full matrix, which reduces attention memory to O(n) (the compute is still O(n^2)). Always use Flash Attention when training with long sequences; without it, the attention matrices of a 7B model alone can exhaust an 80GB A100 at sequence lengths in the tens of thousands of tokens.
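A back-of-the-envelope calculation shows where the quadratic cost bites. This sketch assumes 32 attention heads (as in Llama-2-7B), 2-byte bf16 scores, and a batch size of 1, and it counts only one layer's score matrices; the softmax output kept for the backward pass and all other activations come on top.

```python
def attention_scores_gib(seq_len: int, num_heads: int = 32, bytes_per_value: int = 2) -> float:
    """Memory (GiB) of one layer's n x n attention score matrices, batch size 1."""
    return num_heads * seq_len * seq_len * bytes_per_value / 2**30


for n in (8_192, 16_384, 32_768):
    print(f"{n:>6} tokens: {attention_scores_gib(n):5.1f} GiB per layer")
```

At 32K tokens a single layer's scores take 64 GiB, most of an 80GB card before weights, gradients, or any other activations are counted.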

4. Chunking Strategies

When documents exceed even the extended context window, or when context extension is not practical, you need chunking strategies to process long texts in pieces. The quality of your chunking strategy has a significant impact on downstream performance, particularly for retrieval and summarization tasks.

4.1 Common Chunking Approaches

[Figure: Chunking Strategies Comparison. Fixed-Size (Naive): 512-token chunks that can split mid-sentence. Overlapping: overlap preserves context at boundaries. Semantic: splits at natural boundaries (§1, §2 longer, §3).]
Strategy | Speed | Quality | Duplication | Chunk size | Best for
Fixed-size | Fastest | Lowest | None | Uniform | Prototyping
Overlapping | Fast | Good | Some | Uniform + overlap | General use
Semantic | Slower (needs NLP) | Best | None | Variable | Production RAG
Figure 13.16: Comparison of chunking strategies from simple fixed-size to semantic boundary detection
from typing import List

def chunk_with_overlap(
    text: str,
    chunk_size: int = 512,
    overlap: int = 64,
    tokenizer=None,
) -> List[dict]:
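    """Chunk text into windows of chunk_size tokens that share `overlap` tokens."""
    if not 0 <= overlap < chunk_size:
        # overlap >= chunk_size would stop the window from advancing (infinite loop)
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")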
    if tokenizer is None:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    start = 0

    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        chunk_tokens = tokens[start:end]
        chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)

        chunks.append({
            "text": chunk_text,
            "start_token": start,
            "end_token": end,
            "num_tokens": len(chunk_tokens),
        })

        # Move forward by (chunk_size - overlap)
        start += chunk_size - overlap

        # Stop if we have reached the end
        if end >= len(tokens):
            break

    return chunks

def semantic_chunk(
    text: str,
    max_chunk_tokens: int = 512,
    tokenizer=None,
) -> List[dict]:
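    """Split text at paragraph and header boundaries, packing whole segments
    into chunks of up to max_chunk_tokens. A single segment longer than
    max_chunk_tokens is kept as one oversized chunk rather than split."""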
    import re

    # Split on paragraph boundaries (double newlines) and section headers
    segments = re.split(r'\n\n+|\n(?=#)', text)
    segments = [s.strip() for s in segments if s.strip()]

    if tokenizer is None:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    chunks = []
    current_segments = []
    current_tokens = 0

    for segment in segments:
        seg_tokens = len(tokenizer.encode(segment, add_special_tokens=False))

        if current_tokens + seg_tokens > max_chunk_tokens and current_segments:
            # Save current chunk and start a new one
            chunks.append({
                "text": "\n\n".join(current_segments),
                "num_tokens": current_tokens,
                "num_segments": len(current_segments),
            })
            current_segments = []
            current_tokens = 0

        current_segments.append(segment)
        current_tokens += seg_tokens

    # Save the last chunk
    if current_segments:
        chunks.append({
            "text": "\n\n".join(current_segments),
            "num_tokens": current_tokens,
            "num_segments": len(current_segments),
        })

    return chunks

# Example
sample_text = "First paragraph about topic A. " * 50 + "\n\n" + \
              "Second paragraph about topic B. " * 30 + "\n\n" + \
              "Third paragraph wrapping up. " * 20

overlap_chunks = chunk_with_overlap(sample_text, chunk_size=128, overlap=32)
semantic_chunks = semantic_chunk(sample_text, max_chunk_tokens=256)

print(f"Overlap chunking: {len(overlap_chunks)} chunks")
print(f"Semantic chunking: {len(semantic_chunks)} chunks")
📝 Note

Overlap size matters. Too little overlap (less than about 10% of chunk size) risks losing context at boundaries. Too much overlap (more than about 30%) creates excessive duplication and increases storage and compute costs. A common starting point is 10% to 20% overlap. For retrieval tasks, err toward more overlap; for summarization, less overlap usually suffices since each chunk is processed independently.
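To see what an overlap percentage costs, the hypothetical helper below replays the same sliding-window loop as `chunk_with_overlap` on a token count alone and reports how many chunks it produces and how many tokens end up stored in total (overlapping tokens are counted in every chunk that holds them). The 10,000-token document is an arbitrary example.

```python
def overlap_cost(num_tokens: int, chunk_size: int, overlap: int) -> tuple[int, float]:
    """Return (number of chunks, stored tokens / original tokens) for sliding-window chunking."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    if num_tokens <= 0:
        return 0, 0.0
    start, num_chunks, stored = 0, 0, 0
    while True:
        end = min(start + chunk_size, num_tokens)
        num_chunks += 1
        stored += end - start
        if end >= num_tokens:
            break
        start += chunk_size - overlap
    return num_chunks, stored / num_tokens


for overlap in (0, 64, 128):  # 0%, 12.5%, 25% of a 512-token chunk
    chunks, ratio = overlap_cost(10_000, chunk_size=512, overlap=overlap)
    print(f"overlap={overlap:>3}: {chunks} chunks, {ratio:.2f}x tokens stored")
```

For this document, 12.5% overlap stores about 14% more tokens than no overlap, and 25% overlap about 32% more.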

5. The Lost-in-the-Middle Phenomenon

Even models that can technically process long contexts do not use all positions equally. Liu et al. (2023), "Lost in the Middle: How Language Models Use Long Contexts," found that language models perform best when relevant information appears at the beginning or end of the context and markedly worse when it sits in the middle, producing a U-shaped accuracy curve. This has significant implications for how you structure long prompts and which context extension strategies you choose.

[Figure: Lost-in-the-Middle: Attention Distribution Over Position. X-axis: position in context window (start, middle, end); y-axis: retrieval accuracy. High recall at the start and end, with a "lost" zone in the middle.]
Figure 13.17: Models recall information better from the start and end of the context, with a U-shaped accuracy curve
# Practical strategies for mitigating lost-in-the-middle
from typing import List, Dict

def reorder_context_for_retrieval(
    query: str,
    retrieved_passages: List[Dict],
    strategy: str = "important_first_last"
) -> List[Dict]:
    """Reorder passages to mitigate the lost-in-the-middle effect."""

    if strategy == "important_first_last":
        # Place the most relevant passages at the start and end
        # Less relevant passages go in the middle
        sorted_passages = sorted(
            retrieved_passages,
            key=lambda p: p["relevance_score"],
            reverse=True
        )

        n = len(sorted_passages)
        reordered = [None] * n

        # Alternate between start and end positions
        left, right = 0, n - 1
        for i, passage in enumerate(sorted_passages):
            if i % 2 == 0:
                reordered[left] = passage
                left += 1
            else:
                reordered[right] = passage
                right -= 1

        return reordered

    elif strategy == "reverse_rank":
        # Put least relevant first, most relevant last
        # (recency bias helps with last items)
        return sorted(
            retrieved_passages,
            key=lambda p: p["relevance_score"]
        )

    return retrieved_passages

# Example: 10 passages ranked by relevance
passages = [
    {"text": f"Passage {i}", "relevance_score": 1.0 - i * 0.1}
    for i in range(10)
]

reordered = reorder_context_for_retrieval("query", passages)
positions = [(p["text"], f"score={p['relevance_score']:.1f}") for p in reordered]
for i, (text, score) in enumerate(positions):
    position_label = "START" if i < 2 else "END" if i >= 8 else "middle"
    print(f"  Position {i:2d} [{position_label:6s}]: {text} ({score})")
  Position  0 [START ]: Passage 0 (score=1.0)
  Position  1 [START ]: Passage 2 (score=0.8)
  Position  2 [middle]: Passage 4 (score=0.6)
  Position  3 [middle]: Passage 6 (score=0.4)
  Position  4 [middle]: Passage 8 (score=0.2)
  Position  5 [middle]: Passage 9 (score=0.1)
  Position  6 [middle]: Passage 7 (score=0.3)
  Position  7 [middle]: Passage 5 (score=0.5)
  Position  8 [END   ]: Passage 3 (score=0.7)
  Position  9 [END   ]: Passage 1 (score=0.9)
🔑 Key Insight

Structure your prompts with the U-shape in mind. Place the most critical information (key instructions, the most relevant retrieved passages, essential context) at the very beginning and the very end of your prompt, and put less critical supporting information in the middle. This reordering requires no model changes and can recover some of the accuracy lost when key information sits mid-context; measure the gain on your own task.
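One way to apply this advice when assembling a prompt is sketched below. The helper `build_long_context_prompt` and its layout are hypothetical: it places the instructions and question at the start, the supporting passages in the middle, and repeats the question at the end so the task sits at both high-recall edges. Passages can be ordered first with `reorder_context_for_retrieval` to also move the strongest passages toward the edges.

```python
def build_long_context_prompt(instructions: str, passages: list[str], question: str) -> str:
    """Put the task at both edges of the prompt and supporting text in the middle."""
    parts = [instructions, f"Question: {question}"]
    parts += [f"[Passage {i + 1}]\n{text}" for i, text in enumerate(passages)]
    parts.append(f"Question (repeated): {question}")
    return "\n\n".join(parts)


prompt = build_long_context_prompt(
    instructions="Answer using only the passages provided.",
    passages=["Passage A text...", "Passage B text..."],
    question="Which clause limits liability?",
)
print(prompt)
```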

Section 13.7 Quiz

1. Why does a model trained with a 4K context window produce poor results on an 8K sequence, even though it can technically process the tokens?
The model's positional encodings (typically RoPE) encode position information using mathematical functions that the model learned to interpret during pre-training. Positions beyond 4,096 produce encoding values that the model has never seen during training, making them "out of distribution." The attention mechanism relies on these position encodings to compute relative distances between tokens, so out-of-range positions produce noisy, meaningless attention patterns that degrade output quality.
2. What is the key difference between linear RoPE scaling and Dynamic NTK scaling?
Linear scaling uniformly compresses all position indices by the scaling factor (e.g., dividing all positions by 4 to fit 16K into a 4K range), which is equivalent to dividing every RoPE frequency by the same factor. Dynamic NTK scaling instead enlarges the RoPE base, so different frequency components are scaled by different amounts: high-frequency components (which encode fine-grained position distinctions) are left largely unchanged, while low-frequency components (which encode coarse position information) are interpolated the most. It also computes the scaling from the actual sequence length, so inputs within the original window are not scaled at all. This preserves local attention patterns better, resulting in higher quality at larger scaling factors.
3. Why is Flash Attention essential for long-context training?
Standard self-attention materializes the full n x n attention matrix, so its memory grows as O(n^2), where n is the sequence length: doubling the sequence length quadruples it. The attention matrix for a 32K sequence is 16x larger than for an 8K sequence, which can exceed the capacity of even 80GB GPUs for 7B+ models. Flash Attention computes exact attention in blocks using a tiling algorithm that never materializes the full n x n matrix, reducing attention memory to O(n) (compute remains O(n^2)). This makes long-context training feasible on available hardware.
4. What is the "lost-in-the-middle" phenomenon, and how can you mitigate it?
The lost-in-the-middle phenomenon is the observation that language models recall information much better from the beginning and end of the context window than from the middle, producing a U-shaped accuracy curve. Mitigation strategies include: (1) placing the most important information at the start and end of the prompt, (2) reordering retrieved passages so the most relevant appear at boundary positions, (3) using hierarchical summarization that processes chunks independently before combining, and (4) fine-tuning specifically on tasks that require middle-context retrieval.
5. When should you use chunking strategies instead of context extension?
Use chunking when: (1) the document exceeds even the extended context window (e.g., a 500-page book exceeds the context window of most models), (2) you cannot afford the compute cost of long-context fine-tuning, (3) you need to process many documents and long-context inference is too slow, or (4) your task is naturally decomposable into independent chunks (e.g., searching for specific facts rather than reasoning over the entire document). Context extension is better when the task requires global reasoning across the full document, such as summarizing all themes in a long report or detecting contradictions across sections.

Key Takeaways