Transformer Efficiency: FlashAttention, RoPE/ALiBi, KV-Cache & Long Sequences

45 min

Theory & Concepts

Transformer Efficiency: Beyond Standard Attention

The original Transformer architecture revolutionized NLP, but standard self-attention has a critical flaw: O(n²) time and memory complexity. For a sequence of length n=4096, that's roughly 16.8 million pairwise attention scores. At n=100,000, it becomes infeasible.

This lesson covers breakthrough techniques that enable modern LLMs to handle 100K+ token contexts efficiently.

💡 Why This Matters: GPT-4 processes 128K tokens, Claude 3 handles 200K. These aren't brute force; they rely on sophisticated optimizations you'll learn here. Without these techniques, even inference would be prohibitively expensive.


1. FlashAttention: Revolutionizing Attention Efficiency

The Problem with Standard Attention

Standard Self-Attention Algorithm:

python
import math
import torch
import torch.nn.functional as F

# Naive attention (what PyTorch does under the hood)
def standard_attention(Q, K, V):
    """
    Q, K, V: [batch, heads, seq_len, head_dim]
    Returns: [batch, heads, seq_len, head_dim]
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    # O(n²d) operations, stores n² matrix in HBM
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: [batch, heads, seq_len, seq_len] ← MEMORY BOTTLENECK

    # Step 2: Softmax normalization
    attn_weights = F.softmax(scores, dim=-1)
    # Still [batch, heads, seq_len, seq_len] in memory

    # Step 3: Weighted sum of values
    output = torch.matmul(attn_weights, V)
    return output

Why This Is Slow:

  1. HBM (High Bandwidth Memory) Bottleneck: The scores matrix (n×n) is stored in GPU's slow HBM (40-80 GB/s bandwidth)
  2. Multiple Memory Reads/Writes: Each step loads data from HBM → SRAM → compute → HBM
  3. Memory Scaling: For n=10,000 with 16 heads, that's 1.6 billion floats (6.4 GB just for attention weights!)
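The arithmetic behind point 3 is easy to check in a couple of lines (the helper name `attn_matrix_gb` is ours, for illustration only):

```python
def attn_matrix_gb(n, heads, batch=1, bytes_per=4):
    # Size of the full [batch, heads, n, n] FP32 score matrix, in GB
    return batch * heads * n * n * bytes_per / 1e9

print(attn_matrix_gb(10_000, 16))  # 6.4 GB for the attention weights alone
```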

FlashAttention's Breakthrough

Key Insight: Don't materialize the full attention matrix. Instead, compute attention in blocks that fit in fast SRAM (10-20 TB/s bandwidth).

Tiling Strategy

FlashAttention divides Q, K, V into blocks and processes them in chunks:

Standard Attention:
[Compute all n² scores] → [Softmax] → [Multiply by V]
Memory: O(n²)

FlashAttention:
For each block of Q:
    For each block of K, V:
        Compute partial attention in SRAM
        Update running statistics
Memory: O(n) ← only stores final output!

The Algorithm (Simplified)

python
import math
import torch

def flash_attention_concept(Q, K, V, block_size=256):
    """
    Conceptual implementation (real version uses CUDA kernels)
    Key idea: Process attention in blocks, never materialize full n×n matrix
    """
    batch, heads, seq_len, head_dim = Q.shape
    scale = 1.0 / math.sqrt(head_dim)

    # Unnormalized output accumulator plus running softmax statistics
    # (running max prevents overflow; running sum normalizes at the end)
    acc = torch.zeros_like(Q)
    row_max = torch.full((batch, heads, seq_len, 1), float('-inf'))
    row_sum = torch.zeros((batch, heads, seq_len, 1))

    # Tile over queries
    for q_start in range(0, seq_len, block_size):
        q_end = min(q_start + block_size, seq_len)
        Q_block = Q[:, :, q_start:q_end, :]  # Load to SRAM

        # Tile over keys/values
        for kv_start in range(0, seq_len, block_size):
            kv_end = min(kv_start + block_size, seq_len)
            K_block = K[:, :, kv_start:kv_end, :]  # Load to SRAM
            V_block = V[:, :, kv_start:kv_end, :]  # Load to SRAM

            # Compute attention scores for this block (in SRAM!)
            scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale

            # Online softmax: update running max, rescale old statistics
            block_max = scores.max(dim=-1, keepdim=True).values
            new_max = torch.maximum(row_max[:, :, q_start:q_end], block_max)
            correction = torch.exp(row_max[:, :, q_start:q_end] - new_max)

            exp_scores = torch.exp(scores - new_max)

            # Rescale previous accumulator and add this block's contribution
            acc[:, :, q_start:q_end] = (
                acc[:, :, q_start:q_end] * correction
                + torch.matmul(exp_scores, V_block)
            )
            row_sum[:, :, q_start:q_end] = (
                row_sum[:, :, q_start:q_end] * correction
                + exp_scores.sum(dim=-1, keepdim=True)
            )
            row_max[:, :, q_start:q_end] = new_max

    # Normalize once at the end
    return acc / row_sum

⚠️ Implementation Note: Real FlashAttention uses highly optimized CUDA kernels. The above is conceptual; don't use it in production!

Performance Gains

| Sequence Length | Standard Attention | FlashAttention | Speedup |
|-----------------|--------------------|----------------|---------|
| 512 | 100 ms | 80 ms | 1.25× |
| 2,048 | 1.6 s | 320 ms | 5× |
| 8,192 | 25 s | 1.3 s | 19× |
| 32,768 | OOM (Out of Mem) | 5.2 s | ∞ |

Real-World Impact: Training GPT-3 scale models is 2-4× faster with FlashAttention. Llama 2, GPT-4, and Claude all use variants of this technique.

Using FlashAttention in Practice

python
import torch
from torch.nn.functional import scaled_dot_product_attention

# Modern PyTorch (2.0+) has FlashAttention built-in!
def efficient_attention(query, key, value, mask=None):
    """
    Uses FlashAttention automatically if:
      1. CUDA is available
      2. Data types are FP16 or BF16
      3. No custom attention mask (or simple causal mask)
    """
    # PyTorch automatically dispatches to FlashAttention.
    # Note: attn_mask and is_causal=True cannot be combined, so only
    # enable causal masking when no explicit mask is given.
    output = scaled_dot_product_attention(
        query, key, value,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=(mask is None)  # Enables causal masking efficiently
    )
    return output

# Example usage
batch, heads, seq_len, head_dim = 4, 12, 4096, 64
Q = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
K = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
V = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)

# This uses FlashAttention under the hood
output = efficient_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # [4, 12, 4096, 64]

💡 Pro Tip: Always use torch.compile() with FlashAttention for an additional 10-20% speedup from kernel fusion.
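A minimal sketch of that tip (assumes PyTorch 2.0+; `torch.compile` is lazy, so the optimized kernel is only built on the first call, and the function name `attn` is ours):

```python
import torch
import torch.nn.functional as F

def attn(q, k, v):
    # Plain SDPA call; torch.compile can fuse the ops surrounding it
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Wrapping is cheap; compilation happens on the first invocation
compiled_attn = torch.compile(attn)
```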


2. Positional Encodings: RoPE vs ALiBi vs Learned

Standard Transformers use sinusoidal positional encodings, but modern LLMs have moved to more sophisticated approaches that enable length generalization.

The Position Encoding Problem

Transformers are permutation invariant: without position information, "dog bites man" = "man bites dog". We need to inject position information, but how?

Requirements for Modern LLMs:

  1. Extrapolation: Handle sequences longer than training (train on 2K, infer on 8K)
  2. Efficiency: No added memory overhead
  3. Relative positions: "word 3 tokens ago" matters more than "word at position 947"
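The permutation-invariance point can be verified directly: without positional information, self-attention is permutation-equivariant, so shuffling the sequence just shuffles the output identically. A minimal check using PyTorch's SDPA:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 5, 8)           # [batch, heads, seq, dim], no positions
perm = torch.tensor([2, 0, 4, 1, 3])  # an arbitrary reordering

out = F.scaled_dot_product_attention(x, x, x)
out_shuffled = F.scaled_dot_product_attention(x[:, :, perm], x[:, :, perm], x[:, :, perm])

# Shuffled input → identically shuffled output: no token "knows" its position
assert torch.allclose(out[:, :, perm], out_shuffled, atol=1e-5)
```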

RoPE (Rotary Position Embedding)

Used by: Llama, Mistral, Qwen, most modern LLMs

Key Idea: Encode position by rotating the query and key vectors in a specific way that makes dot product depend on relative position.

Mathematical Foundation

For position m, rotate each pair of embedding dimensions by angle mθ:

For dimension pair (d₁, d₂), position m:
[x'_d₁]   [cos(mθ)  -sin(mθ)] [x_d₁]
[x'_d₂] = [sin(mθ)   cos(mθ)] [x_d₂]

Why This Works:

After applying RoPE to queries and keys, their dot product becomes:

Q_m · K_n = f(Q, K, m - n)

Notice: Depends only on relative position (m - n), not absolute positions!
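This shift invariance can be checked numerically. The sketch below uses a standalone helper (`rope_apply` is our own name, implementing the rotate-half convention used by LLaMA-style models):

```python
import torch

def rope_apply(x, positions, base=10000.0):
    # x: [seq, dim]; rotate dimension pairs (i, i + dim/2) by positions * θᵢ
    dim = x.size(-1)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = positions[:, None].float() * inv_freq[None, :]  # [seq, dim/2]
    cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)
    sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * cos + rotated * sin

torch.manual_seed(0)
q, k = torch.randn(1, 64), torch.randn(1, 64)

# Score at (m=5, n=2) equals score at (m=105, n=102): only m − n = 3 matters
s1 = (rope_apply(q, torch.tensor([5])) * rope_apply(k, torch.tensor([2]))).sum()
s2 = (rope_apply(q, torch.tensor([105])) * rope_apply(k, torch.tensor([102]))).sum()
assert torch.allclose(s1, s2, atol=1e-4)
```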

Implementation

python
import torch
import torch.nn as nn

class RotaryPositionEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) implementation
    Used in LLaMA, Mistral, and most modern LLMs
    """
    def __init__(self, dim, max_seq_len=8192, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute theta values for each dimension pair
        # θᵢ = base^(-2i/d) where i ∈ [0, d/2)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute rotation tables for all positions
        self._precompute_freqs(max_seq_len)

    def _precompute_freqs(self, seq_len):
        """Precompute cos and sin for all positions"""
        # Positions: [0, 1, 2, ..., seq_len-1]
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        # Outer product: [seq_len, dim/2]
        freqs = torch.outer(t, self.inv_freq)
        # Combine to get [seq_len, dim]
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def rotate_half(self, x):
        """
        Rotate half the dimensions:
        [x1, x2, x3, x4] → [-x3, -x4, x1, x2]
        """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k, seq_len=None):
        """
        Apply rotary embeddings to queries and keys
        Args:
            q: [batch, heads, seq_len, head_dim]
            k: [batch, heads, seq_len, head_dim]
        Returns:
            q_rotated, k_rotated (same shape)
        """
        if seq_len is None:
            seq_len = q.size(2)

        # Get cached cos/sin for this sequence length
        cos = self.cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)

        # Apply rotation: q_rotated = q * cos + rotate_half(q) * sin
        q_rotated = q * cos + self.rotate_half(q) * sin
        k_rotated = k * cos + self.rotate_half(k) * sin
        return q_rotated, k_rotated

# Usage example
rope = RotaryPositionEmbedding(dim=64, max_seq_len=8192)

# Sample query and key tensors
batch, heads, seq_len, head_dim = 2, 8, 512, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Apply RoPE
q_rotated, k_rotated = rope(q, k)
print(f"Rotated Q shape: {q_rotated.shape}")  # [2, 8, 512, 64]

RoPE Advantages:

  • ✅ No learnable parameters (zero overhead)
  • ✅ Relative position encoding (better generalization)
  • ✅ Excellent extrapolation (can extend context 2-4× beyond training)
  • ✅ Fast computation (just element-wise multiplication)

ALiBi (Attention with Linear Biases)

Used by: BLOOM, MPT, some research models

Key Idea: Instead of modifying Q/K, add a linear bias to attention scores based on distance.

Attention_scores[i, j] = Q[i] · K[j] - m × |i - j|

Where m is a head-specific slope (different for each attention head).

Implementation

python
import math
import torch

def get_alibi_slopes(num_heads):
    """
    Compute ALiBi slopes for each attention head
    Uses a geometric sequence: 2^(-8/n), 2^(-16/n), ..., 2^(-8)
    """
    def get_slopes_power_of_2(n):
        start = 2 ** (-8 / n)
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if (num_heads & (num_heads - 1)) == 0:  # Power of 2
        return get_slopes_power_of_2(num_heads)
    else:  # Not power of 2: interpolate
        closest_power = 2 ** math.floor(math.log2(num_heads))
        return (
            get_slopes_power_of_2(closest_power) +
            get_alibi_slopes(2 * closest_power)[0::2][:num_heads - closest_power]
        )

def apply_alibi_bias(attention_scores, num_heads):
    """
    Apply ALiBi bias to attention scores
    Args:
        attention_scores: [batch, heads, seq_len, seq_len]
    Returns:
        Biased scores (same shape)
    """
    seq_len = attention_scores.size(-1)

    # Create position distance matrix: distances[i, j] = |i - j|
    positions = torch.arange(seq_len, device=attention_scores.device)
    distances = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()

    # Get slopes for each head
    slopes = torch.tensor(
        get_alibi_slopes(num_heads),
        device=attention_scores.device
    ).view(1, num_heads, 1, 1)

    # Compute bias: -slope × distance
    bias = -slopes * distances.unsqueeze(0).unsqueeze(0)

    # Add bias to scores
    return attention_scores + bias

# Example usage
batch, heads, seq_len = 2, 12, 512
scores = torch.randn(batch, heads, seq_len, seq_len)

# Apply ALiBi
biased_scores = apply_alibi_bias(scores, num_heads=heads)
print(f"Biased scores shape: {biased_scores.shape}")  # [2, 12, 512, 512]

ALiBi vs RoPE:

| Feature | RoPE | ALiBi |
|---------|------|-------|
| Memory Overhead | None | Bias matrix (small) |
| Computation | Rotate Q/K | Add bias to scores |
| Extrapolation | Excellent (4× training length) | Good (2× training length) |
| Adoption | Most modern LLMs | Some research models |
| FlashAttention Compatible | ✅ Yes | ⚠️ Partial (custom kernels needed) |

💡 When to Use Which: RoPE is the industry standard for good reason. Use ALiBi only if you need extreme simplicity or have specific research requirements.


3. KV-Cache: Making Inference Fast

During autoregressive generation (producing tokens one-by-one), computing attention naively is extremely wasteful.

The Inefficiency Problem

When generating token t, we need to attend to all previous tokens [0, 1, ..., t-1]. Without caching:

python
import torch

# Inefficient autoregressive generation
def generate_naive(model, prompt_ids, max_new_tokens=100):
    """
    Recomputes Q, K, V for ALL tokens at EVERY step
    For 100 tokens: ~5000 redundant computations!
    """
    input_ids = prompt_ids.clone()
    for _ in range(max_new_tokens):
        # Process entire sequence every time (wasteful!)
        logits = model(input_ids)  # Recomputes K, V for all past tokens
        next_token = logits[:, -1, :].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    return input_ids

Cost: For a prompt of length n, generating m new tokens this way recomputes projections for every past token at every step: roughly O(m × n × d²) FLOPs, versus O(m × d²) when past K, V are cached.

KV-Cache Solution

Key Insight: Keys and Values for past tokens never change! Cache them and only compute K, V for the new token.

python
def generate_with_kv_cache(model, prompt_ids, max_new_tokens=100):
    """
    Efficient generation with KV-cache
    Reduces computation by ~30-50× for long sequences
    """
    input_ids = prompt_ids.clone()
    kv_cache = None  # Will store past K, V tensors

    for i in range(max_new_tokens):
        if i == 0:
            # First step: process full prompt
            logits, kv_cache = model(input_ids, use_cache=True, past_kv=None)
        else:
            # Subsequent steps: only process new token
            logits, kv_cache = model(
                input_ids[:, -1:],  # Only last token!
                use_cache=True,
                past_kv=kv_cache  # Reuse cached K, V
            )
        next_token = logits[:, -1, :].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    return input_ids

KV-Cache Anatomy

The cache stores Keys and Values for each layer and head:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerLayerWithCache(nn.Module):
    """
    Transformer layer with KV-caching support
    """
    def __init__(self, d_model=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, past_kv=None):
        """
        Args:
            x: [batch, seq_len, d_model] (seq_len=1 when using cache)
            past_kv: Tuple of (past_key, past_value) or None
                past_key: [batch, num_heads, past_seq_len, head_dim]
        Returns:
            output: [batch, seq_len, d_model]
            new_kv: Updated (key, value) cache
        """
        batch, seq_len, d_model = x.shape

        # Compute Q, K, V for current token(s)
        Q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # If we have past KV, concatenate with current
        if past_kv is not None:
            past_key, past_value = past_kv
            K = torch.cat([past_key, K], dim=2)  # Concatenate along sequence dimension
            V = torch.cat([past_value, V], dim=2)

        # Store new KV cache (includes past + current)
        new_kv = (K, V)

        # Compute attention with full key/value sequence
        # Q: [batch, heads, seq_len, head_dim] (current token(s))
        # K, V: [batch, heads, total_seq_len, head_dim] (all past + current)
        # Causal masking applies on the multi-token prompt pass;
        # a single cached query may attend to everything
        attn_output = F.scaled_dot_product_attention(Q, K, V, is_causal=(past_kv is None))

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        output = self.out_proj(attn_output)
        return output, new_kv

# Example: Generate with KV-cache
layer = TransformerLayerWithCache(d_model=768, num_heads=12)

# Step 1: Process prompt (seq_len=10)
prompt = torch.randn(1, 10, 768)
output1, kv_cache = layer(prompt, past_kv=None)
print(f"Step 1 - Output: {output1.shape}, Cache K shape: {kv_cache[0].shape}")
# Output: [1, 10, 768], Cache K: [1, 12, 10, 64]

# Step 2: Generate next token (seq_len=1, reuse cache)
new_token = torch.randn(1, 1, 768)
output2, kv_cache = layer(new_token, past_kv=kv_cache)
print(f"Step 2 - Output: {output2.shape}, Cache K shape: {kv_cache[0].shape}")
# Output: [1, 1, 768], Cache K: [1, 12, 11, 64] ← cache grew by 1

Memory Overhead of KV-Cache

For a model with:

  • L layers
  • h attention heads
  • d_h head dimension
  • Sequence length n
  • Batch size b
  • Data type: FP16 (2 bytes)

KV-Cache Size:

Memory = 2  ×  L    ×  b   ×  h   ×  n  ×  d_h    ×  2 bytes
         ↑     ↑       ↑      ↑      ↑     ↑         ↑
        K+V  layers  batch  heads  seq  head_dim   FP16

Example (Llama 2 7B):

  • L=32 layers, h=32 heads, d_h=128, n=4096
  • Memory = 2 × 32 × 1 × 32 × 4096 × 128 × 2 = 2.1 GB per sample
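The formula above can be sanity-checked in a few lines (the helper name `kv_cache_bytes` is ours):

```python
def kv_cache_bytes(num_layers, batch, heads, seq_len, head_dim, bytes_per_elem=2):
    # Leading factor 2 accounts for K and V; bytes_per_elem=2 for FP16
    return 2 * num_layers * batch * heads * seq_len * head_dim * bytes_per_elem

# Llama 2 7B-style shape at 4K context, batch size 1
mem = kv_cache_bytes(num_layers=32, batch=1, heads=32, seq_len=4096, head_dim=128)
print(f"{mem / 1e9:.2f} GB")  # ≈ 2.15 GB per sample
```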

⚠️ Critical Implication: With batch size 8 at 4K context, KV-cache alone uses 17 GB. This is why inference requires so much VRAM!


4. Context Extension Techniques

Training on 4K tokens but need to handle 32K at inference? Here's how modern LLMs do it.

Position Interpolation

Problem: RoPE trained on max length 2048 fails at 4096 because rotation angles are too large.

Solution: Scale down the position indices during inference.

python
def extend_rope_context(rope_module, original_max_len, new_max_len):
    """
    Extend RoPE context by interpolating position indices
    Used in Llama 2 Long, Code Llama 100K
    """
    scale = original_max_len / new_max_len  # e.g., 2048 / 8192 = 0.25

    # Adjust frequency computation
    # Original: θᵢ = base^(-2i/d)
    # Extended: θᵢ = base^(-2i/d) × scale
    rope_module.inv_freq = rope_module.inv_freq * scale

    # Recompute cached cos/sin with new frequencies
    rope_module._precompute_freqs(new_max_len)
    return rope_module

# Example: Extend from 2K to 8K context
rope = RotaryPositionEmbedding(dim=64, max_seq_len=2048)
rope_extended = extend_rope_context(rope, original_max_len=2048, new_max_len=8192)

# Now can handle 8K sequences without retraining!
q = torch.randn(1, 8, 8192, 64)  # 8K sequence
k = torch.randn(1, 8, 8192, 64)
q_rot, k_rot = rope_extended(q, k)
print(f"Extended to {q_rot.size(2)} tokens")  # 8192

Performance: Can extend 2-4× with minimal quality loss. Beyond that, fine-tuning is needed.

YaRN (Yet another RoPE extensioN)

More sophisticated interpolation that preserves high-frequency components:

python
import torch

def yarn_scaling(original_length, new_length, dim, base=10000):
    """
    YaRN scaling: Non-uniform interpolation
    - Low frequencies (long-range): Interpolate aggressively
    - High frequencies (short-range): Minimal interpolation
    """
    scale = new_length / original_length

    # Compute per-dimension frequencies
    dim_range = torch.arange(0, dim, 2).float()
    freqs = base ** (-dim_range / dim)

    # Lower frequencies get more scaling (they capture long-range)
    # Higher frequencies get less scaling (they capture local patterns)
    yarn_scale = torch.where(
        freqs < 0.1,  # Low frequency threshold
        torch.full_like(freqs, scale),  # Full interpolation
        1.0 + (scale - 1.0) * (0.1 / freqs)  # Fades toward 1 as frequency grows
    )
    return yarn_scale

5. Long-Sequence Pitfalls & Solutions

Pitfall 1: Lost in the Middle

Problem: Models pay less attention to middle tokens in very long contexts.

python
# Experiment: Where does the model look?
def test_context_recall(model, tokenizer, needle_position='middle'):
    """
    Test if model can find info at different positions
    "Needle in haystack" benchmark
    """
    context = "Random text... " * 1000  # Long distractor
    needle = "The secret password is BANANA."

    if needle_position == 'start':
        full_text = needle + context
    elif needle_position == 'middle':
        mid = len(context) // 2
        full_text = context[:mid] + needle + context[mid:]
    else:  # end
        full_text = context + needle

    # Ask model to recall the password
    prompt = full_text + "\nWhat is the secret password?"
    response = model.generate(tokenizer.encode(prompt))
    return "BANANA" in response

# Results (GPT-3.5 on 16K context):
# Start: 95% recall ✅
# Middle: 62% recall ⚠️
# End: 98% recall ✅

Solutions:

  1. Structured prompting: Put critical info at start/end
  2. Retrieval augmentation: Don't put everything in context
  3. Fine-tune on long documents: Train model to use middle content

Pitfall 2: Attention Collapse

At very long contexts, attention can become too uniform (attends equally to everything = attends to nothing).

python
import math
import torch

def diagnose_attention_collapse(attention_weights):
    """
    Check if attention has collapsed to a uniform distribution
    Args:
        attention_weights: [batch, heads, seq_len, seq_len]
    Returns:
        entropy ratio: higher = more diffuse attention
    """
    # Compute entropy of attention distribution
    # Uniform distribution has max entropy
    eps = 1e-10
    entropy = -(attention_weights * torch.log(attention_weights + eps)).sum(dim=-1)
    max_entropy = math.log(attention_weights.size(-1))

    # Ratio: 1.0 = completely uniform (collapsed)
    entropy_ratio = entropy / max_entropy
    return entropy_ratio.mean().item()

# Example
attn = torch.softmax(torch.randn(1, 12, 4096, 4096), dim=-1)
collapse_score = diagnose_attention_collapse(attn)
print(f"Attention collapse score: {collapse_score:.3f}")
# > 0.9 indicates potential collapse

Solution: Use attention sink tokens (always keep the first few tokens in the attended window; this prevents collapse).

Pitfall 3: Memory Overflow

The KV-cache grows linearly with sequence length and can easily cause out-of-memory errors.

Solutions:

  1. PagedAttention (vLLM): Store KV-cache in paged memory, like OS virtual memory
  2. Streaming LLM: Evict old cache, keep only recent + first few tokens
  3. Compress KV-cache: Quantize to INT8 or lower precision
python
import torch

def streaming_kv_cache(kv_cache, max_cache_len=2048, keep_first=128):
    """
    Maintain fixed-size KV-cache for infinite generation
    Keep first N tokens (attention sinks) + recent tokens
    """
    key_cache, value_cache = kv_cache
    current_len = key_cache.size(2)

    if current_len <= max_cache_len:
        return kv_cache  # No eviction needed

    # Keep first keep_first tokens + most recent
    recent_tokens = max_cache_len - keep_first
    key_cache = torch.cat([
        key_cache[:, :, :keep_first, :],     # First N tokens (attention sinks)
        key_cache[:, :, -recent_tokens:, :]  # Recent tokens
    ], dim=2)
    value_cache = torch.cat([
        value_cache[:, :, :keep_first, :],
        value_cache[:, :, -recent_tokens:, :]
    ], dim=2)
    return (key_cache, value_cache)

Summary: Key Takeaways

FlashAttention

  • Breakthrough: Computes attention without materializing n² matrix
  • How: Tiling + online softmax in fast SRAM
  • Impact: 5-20× faster, handles 4× longer sequences
  • Use: Built into PyTorch 2.0+ scaled_dot_product_attention

Position Encodings

  • RoPE: Industry standard, excellent extrapolation, zero overhead
  • ALiBi: Simple linear biases, good for research
  • Choose RoPE unless you have specific needs

KV-Cache

  • Purpose: Avoid recomputing past tokens' K, V during generation
  • Speedup: 30-50× faster inference
  • Cost: 2-4 GB memory per sample at 4K context
  • Critical for production inference

Context Extension

  • Position Interpolation: Scale RoPE frequencies for 2-4× extension
  • YaRN: Non-uniform scaling preserves quality better
  • Limitation: Beyond 4× requires fine-tuning

Long-Sequence Pitfalls

  • ⚠️ Lost in the Middle: Models struggle with middle content
  • ⚠️ Attention Collapse: Uniform attention at extreme lengths
  • ⚠️ Memory: KV-cache scales linearly with length
  • Solutions: Structured prompts, streaming cache, retrieval augmentation

🎯 Next Steps: Apply these to build efficient Transformer variants and explore State Space Models (SSMs) that break the n² barrier entirely!

Lesson Content

Master advanced Transformer optimization techniques including FlashAttention's memory-efficient attention, positional encoding strategies (RoPE/ALiBi), KV-cache mechanisms, context extension methods, and pitfalls of long-sequence processing.

Code Example

python
# Complete implementation: Efficient Transformer with all optimizations
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RoPE(nn.Module):
    """Rotary Position Embedding - Production ready"""
    def __init__(self, dim, max_seq_len=32768, base=10000):
        super().__init__()
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self._precompute_freqs(max_seq_len)

    def _precompute_freqs(self, seq_len):
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k, offset=0):
        # offset = number of cached past tokens, so incrementally decoded
        # tokens are rotated with their correct absolute positions
        seq_len = q.size(2)
        cos = self.cos_cached[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        q_rotated = q * cos + self.rotate_half(q) * sin
        k_rotated = k * cos + self.rotate_half(k) * sin
        return q_rotated, k_rotated

class EfficientAttention(nn.Module):
    """
    Multi-head attention with all efficiency features:
    - FlashAttention (via PyTorch SDPA)
    - RoPE positional encoding
    - KV-cache support
    """
    def __init__(self, d_model=768, num_heads=12, max_seq_len=8192):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5

        # Projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # RoPE for position encoding
        self.rope = RoPE(self.head_dim, max_seq_len=max_seq_len)

    def forward(self, x, past_kv=None, use_cache=False):
        """
        Args:
            x: [batch, seq_len, d_model]
            past_kv: Optional tuple of (past_key, past_value) for caching
            use_cache: Whether to return updated KV cache
        Returns:
            output: [batch, seq_len, d_model]
            new_kv: Updated cache if use_cache=True, else None
        """
        batch, seq_len, d_model = x.shape

        # Compute Q, K, V
        Q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K, offset by the number of cached tokens
        past_len = past_kv[0].size(2) if past_kv is not None else 0
        Q, K = self.rope(Q, K, offset=past_len)

        # Handle KV-cache
        if past_kv is not None:
            past_key, past_value = past_kv
            K = torch.cat([past_key, K], dim=2)
            V = torch.cat([past_value, V], dim=2)

        # Prepare cache for next iteration
        new_kv = (K, V) if use_cache else None

        # FlashAttention (automatically used if available)
        # Causal on the prompt pass; a single cached query attends to all past
        attn_output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=(past_kv is None)
        )

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq_len, d_model)
        output = self.out_proj(attn_output)
        return output, new_kv

class TransformerBlock(nn.Module):
    """Complete Transformer block with efficient attention"""
    def __init__(self, d_model=768, num_heads=12, d_ff=3072, dropout=0.1, max_seq_len=8192):
        super().__init__()
        self.attn = EfficientAttention(d_model, num_heads, max_seq_len=max_seq_len)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        # Feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, past_kv=None, use_cache=False):
        # Attention with residual
        attn_out, new_kv = self.attn(self.ln1(x), past_kv, use_cache)
        x = x + attn_out
        # Feed-forward with residual
        x = x + self.ff(self.ln2(x))
        return x, new_kv

class EfficientTransformer(nn.Module):
    """
    Production-ready efficient Transformer with:
    - FlashAttention
    - RoPE
    - KV-cache
    - Context extension support
    """
    def __init__(
        self,
        vocab_size=50257,
        d_model=768,
        num_layers=12,
        num_heads=12,
        d_ff=3072,
        max_seq_len=8192,
        dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Embeddings
        self.token_emb = nn.Embedding(vocab_size, d_model)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout, max_seq_len=max_seq_len)
            for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, past_kvs=None, use_cache=False):
        """
        Args:
            input_ids: [batch, seq_len]
            past_kvs: List of (key, value) tuples for each layer
            use_cache: Whether to return KV cache
        Returns:
            logits: [batch, seq_len, vocab_size]
            new_kvs: List of updated caches if use_cache=True
        """
        x = self.token_emb(input_ids)
        new_kvs = [] if use_cache else None

        for i, block in enumerate(self.blocks):
            past_kv = past_kvs[i] if past_kvs is not None else None
            x, new_kv = block(x, past_kv, use_cache)
            if use_cache:
                new_kvs.append(new_kv)

        x = self.ln_final(x)
        logits = self.head(x)
        return logits, new_kvs

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=100, temperature=1.0):
        """Efficient autoregressive generation with KV-cache"""
        past_kvs = None
        for _ in range(max_new_tokens):
            # Only pass last token after first iteration
            curr_input = input_ids if past_kvs is None else input_ids[:, -1:]

            # Forward pass with caching
            logits, past_kvs = self.forward(curr_input, past_kvs, use_cache=True)

            # Sample next token
            next_token_logits = logits[:, -1, :] / temperature
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Optional: check for EOS token
            # if next_token.item() == eos_token_id:
            #     break
        return input_ids

# ============================================================================
# Usage Examples
# ============================================================================
def example_basic_usage():
    """Basic model creation and forward pass"""
    model = EfficientTransformer(
        vocab_size=50257,
        d_model=768,
        num_layers=12,
        num_heads=12,
        max_seq_len=8192
    ).cuda()

    # Single forward pass
    input_ids = torch.randint(0, 50257, (2, 512)).cuda()
    logits, _ = model(input_ids, use_cache=False)
    print(f"Logits shape: {logits.shape}")  # [2, 512, 50257]

def example_efficient_generation():
    """Efficient generation with KV-cache"""
    model = EfficientTransformer(
        vocab_size=50257,
        d_model=768,
        num_layers=12,
        num_heads=12,
        max_seq_len=8192
    ).cuda().eval()

    # Prompt
    prompt_ids = torch.randint(0, 50257, (1, 128)).cuda()

    # Generate 100 tokens efficiently
    with torch.no_grad():
        output_ids = model.generate(prompt_ids, max_new_tokens=100)
    print(f"Generated {output_ids.size(1) - 128} new tokens")
    print(f"Final sequence length: {output_ids.size(1)}")

def example_context_extension():
    """Extend context length beyond training"""
    # Train on 2K context
    model_2k = EfficientTransformer(max_seq_len=2048).cuda()

    # Extend to 8K using position interpolation
    def extend_context(model, new_max_len):
        for block in model.blocks:
            rope = block.attn.rope
            scale = rope.max_seq_len / new_max_len
            rope.inv_freq = rope.inv_freq * scale
            rope._precompute_freqs(new_max_len)
            rope.max_seq_len = new_max_len
        model.max_seq_len = new_max_len

    extend_context(model_2k, new_max_len=8192)

    # Now can handle 8K sequences
    long_input = torch.randint(0, 50257, (1, 8192)).cuda()
    logits, _ = model_2k(long_input)
    print(f"Successfully processed {long_input.size(1)} tokens")

def benchmark_kv_cache_speedup():
    """Measure KV-cache speedup"""
    import time
    model = EfficientTransformer(
        vocab_size=1000,
        d_model=512,
        num_layers=6,
        num_heads=8,
        max_seq_len=2048
    ).cuda().eval()

    prompt = torch.randint(0, 1000, (1, 128)).cuda()
    num_tokens = 50

    # Without KV-cache (naive)
    start = time.time()
    with torch.no_grad():
        input_ids = prompt.clone()
        for _ in range(num_tokens):
            logits, _ = model(input_ids, use_cache=False)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
    naive_time = time.time() - start

    # With KV-cache
    start = time.time()
    with torch.no_grad():
        output_ids = model.generate(prompt, max_new_tokens=num_tokens)
    cache_time = time.time() - start

    print(f"\nGeneration Benchmark ({num_tokens} tokens):")
    print(f"Without KV-cache: {naive_time:.3f}s")
    print(f"With KV-cache: {cache_time:.3f}s")
    print(f"Speedup: {naive_time / cache_time:.2f}×")

if __name__ == "__main__":
    print("=== Example 1: Basic Usage ===")
    example_basic_usage()
    print("\n=== Example 2: Efficient Generation ===")
    example_efficient_generation()
    print("\n=== Example 3: Context Extension ===")
    example_context_extension()
    print("\n=== Example 4: KV-Cache Speedup ===")
    benchmark_kv_cache_speedup()