Most real-world documents are longer than models were trained to handle. Legal contracts, research papers, codebases, and book manuscripts routinely exceed the 4K or 8K context windows that many models were originally trained with. Simply passing a longer sequence to a model trained on shorter sequences causes severe quality degradation because the positional encodings extrapolate into regions the model has never seen. This section covers the techniques for extending a model's effective context length: mathematical adjustments to positional encodings, continued pre-training on long documents, and practical chunking strategies for when extension is not enough.
1. The Long Context Challenge
Transformer models encode position information through positional embeddings or positional encodings. When a model trained with a maximum sequence length of 4,096 tokens receives a sequence of 8,192 tokens, the positions beyond 4,096 are "out of distribution." The model has never learned what those position values mean, leading to degraded attention patterns and poor generation quality.
1.1 Why Models Fail on Long Sequences
The failure mode depends on the type of positional encoding. Models using absolute positional embeddings (original BERT, GPT-2) have a hard limit: positions beyond the embedding table size simply do not exist. Models using Rotary Position Embeddings (RoPE), which are standard in modern LLMs like Llama, Mistral, and Qwen, can technically process longer sequences, but the rotation angles for unseen positions are extrapolated, causing attention scores to become increasingly noisy.
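The extrapolation problem is easy to see directly in RoPE's rotation angles. A minimal sketch, assuming the standard RoPE base of 10,000 and a head dimension of 64:

```python
import math

# RoPE rotates each pair of dims d at frequency base**(-2d / head_dim)
base, head_dim = 10000.0, 64
freqs = [base ** (-2 * d / head_dim) for d in range(head_dim // 2)]

# The slowest-rotating dimension pair carries coarse position information
slow = freqs[-1]
trained_max = 4096

# During training the model only ever saw angles in [0, trained_max * slow)
print(f"Max angle seen in training: {trained_max * slow:.3f} rad")
# A position twice as far produces an angle the model has never observed
print(f"Angle at position 8192:     {8192 * slow:.3f} rad  (out of distribution)")
```

For the high-frequency dimensions the angles wrap around 2π many times within the training window, but for these low-frequency dimensions every position past 4,096 lands in angle territory the model never learned to interpret.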
2. Context Extension Techniques
Several techniques have been developed to extend a model's effective context length without retraining from scratch. These techniques modify the positional encoding scheme so that longer sequences map to position values the model has already learned to handle.
2.1 RoPE Scaling (Linear Interpolation)
The simplest context extension method is linear scaling, also called position interpolation. Instead of using raw position indices (0, 1, 2, ..., 8191) for an 8K sequence, you scale them down to fit within the original training range: (0, 0.5, 1, 1.5, ..., 4095.5). This ensures all position values fall within the range the model was trained on.
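The remapping itself is nothing more than dividing each position index by the scale factor; a quick sketch for a 2x extension:

```python
# Position interpolation: divide raw position indices by the scale factor
# so an 8K sequence reuses the position range the model was trained on
trained_max, factor = 4096, 2.0

positions = range(8192)
scaled = [p / factor for p in positions]

# Every scaled position now lies strictly inside the trained range
assert max(scaled) < trained_max
print(f"Raw position 8191 -> scaled position {scaled[-1]}")
```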
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# Method 1: Linear scaling (Position Interpolation)
# Extend a 4K context model to handle 16K sequences
model_name = "meta-llama/Llama-2-7b-hf"  # a model genuinely trained with a 4K context
config = AutoConfig.from_pretrained(model_name)
# Set the RoPE scaling configuration
config.rope_scaling = {
"type": "linear",
"factor": 4.0, # Extend 4x: 4K -> 16K
}
# Update max position embeddings to match
config.max_position_embeddings = 16384 # 4096 * 4
model = AutoModelForCausalLM.from_pretrained(
model_name,
config=config,
torch_dtype="auto",
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Max positions: {model.config.max_position_embeddings}")
print(f"RoPE scaling: {model.config.rope_scaling}")
2.2 Dynamic NTK Scaling
Dynamic NTK (Neural Tangent Kernel) scaling is a more sophisticated approach that adjusts the frequency basis of RoPE dynamically based on the actual sequence length. Instead of uniformly compressing all positions, it increases the RoPE base so that low-frequency components (which encode coarse, long-range position information) are interpolated strongly, while high-frequency components (which encode fine-grained local distinctions) are left largely unchanged. This preserves local attention patterns better than uniform scaling.
# Method 2: Dynamic NTK scaling
config = AutoConfig.from_pretrained(model_name)
config.rope_scaling = {
"type": "dynamic",
"factor": 4.0, # Extend 4x
}
config.max_position_embeddings = 16384
# Dynamic NTK computes the scaling based on actual sequence length
# at inference time, so it adapts to varying input lengths
model_dynamic = AutoModelForCausalLM.from_pretrained(
model_name,
config=config,
torch_dtype="auto",
device_map="auto",
)
print(f"RoPE scaling: {model_dynamic.config.rope_scaling}")
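Under the hood, the dynamic variant recomputes the RoPE base from the current input length rather than touching the position indices. A sketch of the base-rescaling formula used by common open-source implementations (the exact constants in any given library are an assumption here), for a head dimension of 128:

```python
# Dynamic NTK: instead of scaling positions, rescale the RoPE frequency base
# as a function of the current sequence length
def dynamic_ntk_base(seq_len, base=10000.0, dim=128,
                     orig_max=4096, factor=4.0):
    if seq_len <= orig_max:
        return base  # within the trained range: no change
    scale = (factor * seq_len / orig_max) - (factor - 1)
    return base * scale ** (dim / (dim - 2))

# The base (and hence every frequency) adapts to the input length
for n in (4096, 8192, 16384):
    print(f"seq_len={n:6d} -> base={dynamic_ntk_base(n):,.0f}")
```

A larger base lowers the low frequencies the most (interpolating them), while the highest-frequency dimension, whose exponent is zero, is untouched regardless of the base.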
2.3 YaRN (Yet another RoPE extensioN)
YaRN combines NTK-aware scaling with attention temperature adjustment and a ramp function that smoothly transitions between unscaled high frequencies and interpolated low frequencies. It generally produces the best results among the scaling-only methods (no fine-tuning required) and is the default approach used by many model providers for context extension.
# Method 3: YaRN scaling
config = AutoConfig.from_pretrained(model_name)
config.rope_scaling = {
"type": "yarn",
"factor": 4.0,
"original_max_position_embeddings": 4096,
# YaRN-specific parameters
"attention_factor": None, # Auto-computed
"beta_fast": 32, # High-frequency boundary
"beta_slow": 1, # Low-frequency boundary
}
config.max_position_embeddings = 16384
model_yarn = AutoModelForCausalLM.from_pretrained(
model_name,
config=config,
torch_dtype="auto",
device_map="auto",
)
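When `attention_factor` is left as `None`, implementations typically derive the attention temperature from the scale factor. A sketch of the formula suggested in the YaRN paper (whether a given library uses exactly this default is an assumption):

```python
import math

# YaRN attention temperature: sharpens attention logits as the context is
# stretched, compensating for averaging over many more positions
def yarn_attention_factor(scale: float) -> float:
    if scale <= 1.0:
        return 1.0  # no extension, no adjustment
    return 0.1 * math.log(scale) + 1.0

for s in (1.0, 2.0, 4.0, 8.0):
    print(f"factor={s:>4} -> attention_factor={yarn_attention_factor(s):.3f}")
```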
| Method | How It Works | Quality | Fine-Tuning Needed |
|---|---|---|---|
| Linear scaling | Uniformly compress positions into training range | Good for 2x to 4x extension | Recommended (few hundred steps) |
| Dynamic NTK | Interpolate low frequencies, preserve high frequencies | Better than linear at higher ratios | Optional, improves with tuning |
| YaRN | NTK + attention temperature + frequency ramp | Best scaling-only method | Optional, best results with tuning |
| LongRoPE | Learned per-dimension scaling factors | State-of-the-art quality | Required (progressive training) |
Scaling alone gives you 2x to 4x; beyond that, you need fine-tuning. Position scaling methods work well for modest context extensions (4K to 16K, or 8K to 32K). For larger extensions (4K to 128K), you typically need to combine scaling with continued pre-training on long documents. The scaling fixes the positional encoding issue, and the fine-tuning teaches the model to actually attend across long distances effectively.
3. Continued Pre-Training for Long Context
For significant context extension (beyond 4x the original training length), the most reliable approach is continued pre-training on long documents with the new positional encoding configuration. This teaches the model to form meaningful attention patterns across the extended range.
3.1 LongLoRA: Efficient Long-Context Fine-Tuning
LongLoRA makes long-context fine-tuning affordable by combining LoRA (parameter-efficient fine-tuning) with shifted sparse attention. During training, attention is computed within local windows, with a shift pattern that creates information flow across the full sequence; at inference time, standard full attention is used. This reduces training memory from quadratic to linear in sequence length.
# LongLoRA-style training setup
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
model_name = "meta-llama/Llama-2-7b-hf"  # 4K-context base model
# Step 1: Configure extended context
config = AutoConfig.from_pretrained(model_name)
config.rope_scaling = {"type": "linear", "factor": 8.0}
config.max_position_embeddings = 32768 # 4K -> 32K
# Step 2: Load model
model = AutoModelForCausalLM.from_pretrained(
model_name,
config=config,
torch_dtype="auto",
device_map="auto",
attn_implementation="flash_attention_2",
)
# Step 3: Apply LoRA for parameter-efficient training
lora_config = LoraConfig(
r=32, # Rank (higher for long context)
lora_alpha=64,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj", # Attention layers
"gate_proj", "up_proj", "down_proj", # MLP layers
],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints trainable vs. total parameter counts
# Step 4: Prepare long-document training data
# Use documents that are genuinely long (books, papers, code repos)
# Each example should fill or exceed the target context length
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Step 5: Configure training for long sequences
sft_config = SFTConfig(
output_dir="./checkpoints/long-context-llama",
max_seq_length=32768, # Full target length
per_device_train_batch_size=1, # Long sequences need BS=1
gradient_accumulation_steps=16, # Effective batch = 16
num_train_epochs=1, # 1 epoch is usually enough
learning_rate=2e-5,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
bf16=True,
gradient_checkpointing=True,
logging_steps=5,
save_steps=50,
packing=True,
)
# trainer = SFTTrainer(
# model=model,
# args=sft_config,
# train_dataset=long_document_dataset,
# processing_class=tokenizer,
# )
# trainer.train()
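The shifted sparse attention idea can be illustrated without any model code. A minimal sketch in pure Python (a toy illustration of the grouping pattern, not the actual LongLoRA implementation):

```python
# Shifted sparse attention (sketch): half the heads attend within fixed
# token groups; the other half attend within groups shifted by half a group.
def attention_groups(seq_len: int, group: int, shifted: bool):
    offset = group // 2 if shifted else 0
    # Token i belongs to the group of its (optionally shifted) index
    return [((i + offset) % seq_len) // group for i in range(seq_len)]

seq_len, group = 8, 4
print("unshifted heads:", attention_groups(seq_len, group, shifted=False))
print("shifted heads:  ", attention_groups(seq_len, group, shifted=True))
# Tokens 3 and 4 sit in different unshifted groups but share a shifted
# group, so information flows across group boundaries during training.
```

Because the two head populations use interleaved group boundaries, every boundary in one pattern falls inside a group of the other, which is what lets local-window training still propagate information across the whole sequence.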
Memory requirements scale quadratically without Flash Attention. Standard attention requires O(n^2) memory where n is the sequence length. A 32K sequence uses 16x the memory of an 8K sequence. Flash Attention 2 reduces this to O(n) by computing attention in blocks without materializing the full attention matrix. Always use Flash Attention when training with long sequences; without it, even an 80GB A100 cannot handle sequences beyond approximately 16K tokens for a 7B model.
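The quadratic blow-up is easy to put in concrete numbers. A back-of-the-envelope sketch, assuming bf16 score matrices and 32 attention heads (per layer, before any other activations):

```python
# Rough memory for materializing full attention scores: n^2 entries per
# head, 2 bytes each in bf16 (per layer)
def attn_scores_gib(seq_len: int, num_heads: int = 32) -> float:
    return seq_len ** 2 * num_heads * 2 / 1024 ** 3

for n in (8192, 16384, 32768):
    print(f"{n:6d} tokens -> {attn_scores_gib(n):6.1f} GiB of attention scores")
```

Doubling the sequence length quadruples the score memory, which is why Flash Attention's block-wise computation (never materializing the full matrix) is the difference between feasible and infeasible at 32K.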
4. Chunking Strategies
When documents exceed even the extended context window, or when context extension is not practical, you need chunking strategies to process long texts in pieces. The quality of your chunking strategy has a significant impact on downstream performance, particularly for retrieval and summarization tasks.
4.1 Common Chunking Approaches
from typing import List
def chunk_with_overlap(
text: str,
chunk_size: int = 512,
overlap: int = 64,
tokenizer=None,
) -> List[dict]:
"""Chunk text with token-level overlap."""
if tokenizer is None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer.encode(text, add_special_tokens=False)
chunks = []
start = 0
while start < len(tokens):
end = min(start + chunk_size, len(tokens))
chunk_tokens = tokens[start:end]
chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
chunks.append({
"text": chunk_text,
"start_token": start,
"end_token": end,
"num_tokens": len(chunk_tokens),
})
# Move forward by (chunk_size - overlap)
start += chunk_size - overlap
# Stop if we have reached the end
if end >= len(tokens):
break
return chunks
def semantic_chunk(
text: str,
max_chunk_tokens: int = 512,
tokenizer=None,
) -> List[dict]:
"""Split text at semantic boundaries (paragraphs, sections)."""
import re
# Split on paragraph boundaries (double newlines) and section headers
segments = re.split(r'\n\n+|\n(?=#)', text)
segments = [s.strip() for s in segments if s.strip()]
if tokenizer is None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
chunks = []
current_segments = []
current_tokens = 0
for segment in segments:
seg_tokens = len(tokenizer.encode(segment, add_special_tokens=False))
if current_tokens + seg_tokens > max_chunk_tokens and current_segments:
# Save current chunk and start a new one
chunks.append({
"text": "\n\n".join(current_segments),
"num_tokens": current_tokens,
"num_segments": len(current_segments),
})
current_segments = []
current_tokens = 0
current_segments.append(segment)
current_tokens += seg_tokens
# Save the last chunk
if current_segments:
chunks.append({
"text": "\n\n".join(current_segments),
"num_tokens": current_tokens,
"num_segments": len(current_segments),
})
return chunks
# Example
sample_text = "First paragraph about topic A. " * 50 + "\n\n" + \
"Second paragraph about topic B. " * 30 + "\n\n" + \
"Third paragraph wrapping up. " * 20
overlap_chunks = chunk_with_overlap(sample_text, chunk_size=128, overlap=32)
semantic_chunks = semantic_chunk(sample_text, max_chunk_tokens=256)
print(f"Overlap chunking: {len(overlap_chunks)} chunks")
print(f"Semantic chunking: {len(semantic_chunks)} chunks")
Overlap size matters. Too little overlap (less than 10% of chunk size) risks losing context at boundaries. Too much overlap (more than 30%) creates excessive duplication and increases storage and compute costs. A good default is 10% to 20% overlap. For retrieval tasks, err toward more overlap; for summarization, less overlap usually suffices since each chunk is processed independently.
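The storage cost of overlap follows directly from the stride: each chunk advances by `chunk_size - overlap` tokens. A sketch of the chunk-count arithmetic:

```python
import math

# More overlap -> smaller stride -> more chunks and more duplicated tokens
def num_chunks(n_tokens: int, chunk_size: int, overlap: int) -> int:
    stride = chunk_size - overlap
    return max(1, math.ceil((n_tokens - overlap) / stride))

n = 10_000
for overlap in (0, 64, 128):  # 0%, 12.5%, 25% of a 512-token chunk
    print(f"overlap={overlap:3d} -> {num_chunks(n, 512, overlap)} chunks")
```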
5. The Lost-in-the-Middle Phenomenon
Even models that can technically process long contexts do not attend to all positions equally. Research has shown that language models exhibit a "lost-in-the-middle" effect: they attend strongly to information at the beginning and end of the context window but struggle to recall or use information in the middle. This has significant implications for how you structure long prompts and which context extension strategies you choose.
# Practical strategies for mitigating lost-in-the-middle
from typing import List, Dict
def reorder_context_for_retrieval(
query: str,
retrieved_passages: List[Dict],
strategy: str = "important_first_last"
) -> List[Dict]:
"""Reorder passages to mitigate the lost-in-the-middle effect."""
if strategy == "important_first_last":
# Place the most relevant passages at the start and end
# Less relevant passages go in the middle
sorted_passages = sorted(
retrieved_passages,
key=lambda p: p["relevance_score"],
reverse=True
)
n = len(sorted_passages)
reordered = [None] * n
# Alternate between start and end positions
left, right = 0, n - 1
for i, passage in enumerate(sorted_passages):
if i % 2 == 0:
reordered[left] = passage
left += 1
else:
reordered[right] = passage
right -= 1
return reordered
elif strategy == "reverse_rank":
# Put least relevant first, most relevant last
# (recency bias helps with last items)
return sorted(
retrieved_passages,
key=lambda p: p["relevance_score"]
)
return retrieved_passages
# Example: 10 passages ranked by relevance
passages = [
{"text": f"Passage {i}", "relevance_score": 1.0 - i * 0.1}
for i in range(10)
]
reordered = reorder_context_for_retrieval("query", passages)
positions = [(p["text"], f"score={p['relevance_score']:.1f}") for p in reordered]
for i, (text, score) in enumerate(positions):
position_label = "START" if i < 2 else "END" if i >= 8 else "middle"
print(f" Position {i:2d} [{position_label:6s}]: {text} ({score})")
Structure your prompts with the U-shape in mind. Place the most critical information (key instructions, the most relevant retrieved passages, essential context) at the very beginning and the very end of your prompt. Less critical supporting information can go in the middle. This simple reordering can improve retrieval accuracy by 10% to 20% on long-context tasks without any model changes.
Section 13.7 Quiz
Key Takeaways
- Position encoding is the bottleneck: models fail on sequences longer than their training window because positional encoding values are out of distribution.
- RoPE scaling methods (linear, dynamic NTK, YaRN) can extend context by 2x to 4x with minimal or no fine-tuning; larger extensions require continued pre-training.
- LongLoRA makes long-context fine-tuning affordable by combining LoRA adapters with shifted sparse attention during training.
- Flash Attention 2 is mandatory for long-context work because standard attention has quadratic memory requirements that quickly exceed GPU capacity.
- Chunking with overlap (10% to 20%) is the practical fallback when documents exceed the context window; semantic chunking at natural boundaries produces the highest quality.
- The lost-in-the-middle effect means models recall beginning and end information best; structure prompts accordingly by placing critical content at boundary positions.