AI Advanced: Modern Architectures, Efficiency & Scale
Transformer Efficiency: FlashAttention, RoPE/ALiBi, KV-Cache & Long Sequences
Theory & Concepts
Transformer Efficiency: Beyond Standard Attention
The original Transformer architecture revolutionized NLP, but standard self-attention has a critical flaw: O(n²) time and memory complexity. For a sequence of length n=4096, that's roughly 16 million score computations. At n=100,000, it becomes infeasible.
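The quadratic blow-up is easy to see with a quick back-of-the-envelope check (a standalone sketch, not part of the lesson's code):

```python
# Every query position attends to every key position, so the number of
# pairwise attention scores grows as n².
def attention_scores(n: int) -> int:
    return n * n

for n in (1_024, 4_096, 100_000):
    print(f"n = {n:>7}: {attention_scores(n):,} scores")
# n =    4096 gives 16,777,216 (~16 million); n = 100,000 gives 10 billion.
```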
This lesson covers breakthrough techniques that enable modern LLMs to handle 100K+ token contexts efficiently.
💡 Why This Matters: GPT-4 processes 128K tokens; Claude 3 handles 200K. These aren't brute force: they use sophisticated optimizations you'll learn here. Without these techniques, even inference would be prohibitively expensive.
1. FlashAttention: Revolutionizing Attention Efficiency
The Problem with Standard Attention
Standard Self-Attention Algorithm:
```python
import math
import torch
import torch.nn.functional as F

# Naive attention (what PyTorch does under the hood)
def standard_attention(Q, K, V):
    """
    Q, K, V: [batch, heads, seq_len, head_dim]
    Returns: [batch, heads, seq_len, head_dim]
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    # O(n²d) operations, stores n² matrix in HBM
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # scores: [batch, heads, seq_len, seq_len] ← MEMORY BOTTLENECK

    # Step 2: Softmax normalization
    attn_weights = F.softmax(scores, dim=-1)
    # Still [batch, heads, seq_len, seq_len] in memory

    # Step 3: Weighted sum of values
    output = torch.matmul(attn_weights, V)
    return output
```

Why This Is Slow:
- HBM (High Bandwidth Memory) Bottleneck: The `scores` matrix (n×n) is stored in the GPU's HBM (40-80 GB capacity, ~1.5-2 TB/s bandwidth), which is slow relative to on-chip SRAM
- Multiple Memory Reads/Writes: Each step loads data from HBM → SRAM → compute → HBM
- Memory Scaling: For n=10,000 with 16 heads, that's 1.6 billion floats (6.4 GB just for attention weights!)
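Those memory figures follow directly from the tensor shape — here is the arithmetic as a small sketch (my own helper, not from the lesson):

```python
# The attention-weight tensor has shape [heads, n, n]; in FP32 each float
# takes 4 bytes.
def attn_weight_bytes(seq_len: int, num_heads: int, bytes_per_float: int = 4) -> int:
    return num_heads * seq_len * seq_len * bytes_per_float

n, heads = 10_000, 16
floats = heads * n * n
print(f"{floats:,} floats ≈ {attn_weight_bytes(n, heads) / 1e9:.1f} GB")
# 1,600,000,000 floats ≈ 6.4 GB — matching the figure above
```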
FlashAttention's Breakthrough
Key Insight: Don't materialize the full attention matrix. Instead, compute attention in blocks that fit in fast SRAM (10-20 TB/s bandwidth).
Tiling Strategy
FlashAttention divides Q, K, V into blocks and processes them in chunks:
```
Standard Attention:
    [Compute all n² scores] → [Softmax] → [Multiply by V]
    Memory: O(n²)

FlashAttention:
    For each block of Q:
        For each block of K, V:
            Compute partial attention in SRAM
            Update running statistics
    Memory: O(n)  ← only stores final output!
```

The Algorithm (Simplified)
```python
import math
import torch

def flash_attention_concept(Q, K, V, block_size=256):
    """
    Conceptual implementation (the real version uses fused CUDA kernels).
    Key idea: process attention in blocks, never materializing the
    full n×n matrix.
    """
    seq_len = Q.size(2)
    acc = torch.zeros_like(Q)  # Unnormalized output accumulator

    # Running statistics for the online softmax (numerical stability)
    row_max = torch.full((Q.size(0), Q.size(1), seq_len, 1), float('-inf'))
    row_sum = torch.zeros((Q.size(0), Q.size(1), seq_len, 1))

    # Tile over queries
    for q_start in range(0, seq_len, block_size):
        q_end = min(q_start + block_size, seq_len)
        Q_block = Q[:, :, q_start:q_end, :]  # Load to SRAM

        # Tile over keys/values
        for kv_start in range(0, seq_len, block_size):
            kv_end = min(kv_start + block_size, seq_len)
            K_block = K[:, :, kv_start:kv_end, :]  # Load to SRAM
            V_block = V[:, :, kv_start:kv_end, :]  # Load to SRAM

            # Compute attention scores for this block (in SRAM!)
            scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
            scores = scores / math.sqrt(Q.size(-1))

            # Online softmax: update the running max, rescale previous
            # contributions, then add this block's contribution
            block_max = scores.max(dim=-1, keepdim=True).values
            new_max = torch.maximum(row_max[:, :, q_start:q_end], block_max)
            correction = torch.exp(row_max[:, :, q_start:q_end] - new_max)
            exp_scores = torch.exp(scores - new_max)

            acc[:, :, q_start:q_end] = (
                acc[:, :, q_start:q_end] * correction
                + torch.matmul(exp_scores, V_block)
            )
            row_sum[:, :, q_start:q_end] = (
                row_sum[:, :, q_start:q_end] * correction
                + exp_scores.sum(dim=-1, keepdim=True)
            )
            row_max[:, :, q_start:q_end] = new_max

    # Normalize once at the end
    return acc / row_sum
```

⚠️ Implementation Note: Real FlashAttention uses highly optimized CUDA kernels. The above is conceptual — don't use it in production!
Performance Gains
| Sequence Length | Standard Attention | FlashAttention | Speedup |
|-----------------|--------------------|----------------|---------|
| 512             | 100 ms             | 80 ms          | 1.25×   |
| 2,048           | 1.6 s              | 320 ms         | 5×      |
| 8,192           | 25 s               | 1.3 s          | 19×     |
| 32,768          | OOM (Out of Memory) | 5.2 s         | ∞       |
✅ Real-World Impact: Training GPT-3 scale models is 2-4× faster with FlashAttention. Llama 2, GPT-4, and Claude all use variants of this technique.
Using FlashAttention in Practice
```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# Modern PyTorch (2.0+) has FlashAttention built-in!
def efficient_attention(query, key, value, mask=None):
    """
    Uses FlashAttention automatically if:
    1. CUDA is available
    2. Data types are FP16 or BF16
    3. No custom attention mask (or simple causal mask)
    """
    # PyTorch automatically dispatches to FlashAttention.
    # Note: attn_mask and is_causal are mutually exclusive, so enable the
    # efficient causal path only when no explicit mask is given.
    output = scaled_dot_product_attention(
        query, key, value,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=(mask is None)
    )
    return output

# Example usage
batch, heads, seq_len, head_dim = 4, 12, 4096, 64
Q = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
K = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
V = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)

# This uses FlashAttention under the hood
output = efficient_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # [4, 12, 4096, 64]
```

💡 Pro Tip: Always use torch.compile() with FlashAttention for an additional 10-20% speedup from kernel fusion.
2. Positional Encodings: RoPE vs ALiBi vs Learned
Standard Transformers use sinusoidal positional encodings, but modern LLMs have moved to more sophisticated approaches that enable length generalization.
The Position Encoding Problem
Transformers are permutation invariant: without position information, "dog bites man" = "man bites dog". We need to inject position info, but how?
Requirements for Modern LLMs:
- ✅ Extrapolation: Handle sequences longer than training (train on 2K, infer on 8K)
- ✅ Efficiency: No added memory overhead
- ✅ Relative positions: "word 3 tokens ago" matters more than "word at position 947"
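The permutation-invariance claim above can be verified directly: if you shuffle the inputs to plain self-attention (no position encoding), the outputs are just shuffled the same way. A minimal pure-Python sketch (my own toy example, not from the lesson):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(v - m) for v in xs]
    s = sum(es)
    return [e / s for e in es]

def self_attention(x):
    """Plain dot-product self-attention with Q = K = V = x (no positions)."""
    out = []
    for q in x:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in x]
        w = softmax(scores)
        out.append([sum(wj * v[d] for wj, v in zip(w, x)) for d in range(len(q))])
    return out

x = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]
perm = [2, 0, 1]

out = self_attention(x)
out_perm = self_attention([x[i] for i in perm])

# Shuffling the inputs merely shuffles the outputs: attention itself
# carries no notion of token order.
matches = all(
    abs(a - b) < 1e-9
    for i, row_p in zip(perm, out_perm)
    for a, b in zip(row_p, out[i])
)
print(matches)  # True
```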
RoPE (Rotary Position Embedding)
Used by: Llama, Mistral, Qwen, most modern LLMs
Key Idea: Encode position by rotating the query and key vectors in a specific way that makes dot product depend on relative position.
Mathematical Foundation
For position m, rotate the embedding dimensions by angle mθ:
For dimension pair (d₁, d₂) at position m:

```
[x'_d₁]   [cos(mθ)  −sin(mθ)] [x_d₁]
[x'_d₂] = [sin(mθ)   cos(mθ)] [x_d₂]
```

Why This Works:
After applying RoPE to queries and keys, their dot product becomes:
```
Q_m · K_n = f(Q, K, m − n)
```

Notice: the result depends only on the relative position (m − n), not on absolute positions!
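This relative-position property can be checked numerically for one rotation pair — a standalone sketch using a 2-D rotation (the building block RoPE applies to each dimension pair):

```python
import math

def rotate(v, angle):
    """Rotate a 2-D vector by the given angle (the per-pair RoPE rotation)."""
    c, s = math.cos(angle), math.sin(angle)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1])

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

theta = 0.1
q, k = (0.3, -1.2), (0.8, 0.5)
m, n = 7, 3

# Dot product of q rotated to position m and k rotated to position n ...
lhs = dot(rotate(q, m * theta), rotate(k, n * theta))
# ... equals the dot product with only the *relative* rotation (m − n):
rhs = dot(rotate(q, (m - n) * theta), k)
print(abs(lhs - rhs) < 1e-12)  # True
```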
Implementation
```python
import torch
import torch.nn as nn

class RotaryPositionEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) implementation
    Used in LLaMA, Mistral, and most modern LLMs
    """
    def __init__(self, dim, max_seq_len=8192, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute theta values for each dimension pair
        # θᵢ = base^(-2i/d) where i ∈ [0, d/2)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute rotation matrices for all positions
        self._precompute_freqs(max_seq_len)

    def _precompute_freqs(self, seq_len):
        """Precompute cos and sin for all positions"""
        # Positions: [0, 1, 2, ..., seq_len-1]
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        # Outer product: [seq_len, dim/2]
        freqs = torch.outer(t, self.inv_freq)
        # Combine to get [seq_len, dim]
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def rotate_half(self, x):
        """
        Rotate half the dimensions
        [x1, x2, x3, x4] → [-x3, -x4, x1, x2]
        """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k, seq_len=None):
        """
        Apply rotary embeddings to queries and keys
        Args:
            q: [batch, heads, seq_len, head_dim]
            k: [batch, heads, seq_len, head_dim]
        Returns:
            q_rotated, k_rotated (same shape)
        """
        if seq_len is None:
            seq_len = q.size(2)

        # Get cached cos/sin for this sequence length
        cos = self.cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)

        # Apply rotation: q_rotated = q * cos + rotate_half(q) * sin
        q_rotated = q * cos + self.rotate_half(q) * sin
        k_rotated = k * cos + self.rotate_half(k) * sin
        return q_rotated, k_rotated

# Usage example
rope = RotaryPositionEmbedding(dim=64, max_seq_len=8192)

# Sample query and key tensors
batch, heads, seq_len, head_dim = 2, 8, 512, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Apply RoPE
q_rotated, k_rotated = rope(q, k)
print(f"Rotated Q shape: {q_rotated.shape}")  # [2, 8, 512, 64]
```

RoPE Advantages:
- ✅ No learnable parameters (zero overhead)
- ✅ Relative position encoding (better generalization)
- ✅ Excellent extrapolation (can extend context 2-4× beyond training)
- ✅ Fast computation (just element-wise multiplication)
ALiBi (Attention with Linear Biases)
Used by: BLOOM, MPT, some research models
Key Idea: Instead of modifying Q/K, add a linear bias to attention scores based on distance.
```
Attention_scores[i, j] = Q[i] · K[j] − m × |i − j|
```

Where m is a head-specific slope (different for each attention head).
Implementation
```python
import math
import torch

def get_alibi_slopes(num_heads):
    """
    Compute ALiBi slopes for each attention head
    Uses a geometric sequence: 2^(-8/n), 2^(-16/n), ..., 2^(-8)
    """
    def get_slopes_power_of_2(n):
        start = 2 ** (-8 / n)
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if (num_heads & (num_heads - 1)) == 0:  # Power of 2
        return get_slopes_power_of_2(num_heads)
    else:
        # Not a power of 2: interpolate from the closest power of 2
        closest_power = 2 ** math.floor(math.log2(num_heads))
        return (
            get_slopes_power_of_2(closest_power)
            + get_alibi_slopes(2 * closest_power)[0::2][:num_heads - closest_power]
        )

def apply_alibi_bias(attention_scores, num_heads):
    """
    Apply ALiBi bias to attention scores
    Args:
        attention_scores: [batch, heads, seq_len, seq_len]
    Returns:
        Biased scores (same shape)
    """
    seq_len = attention_scores.size(-1)

    # Create position distance matrix: distances[i, j] = |i - j|
    positions = torch.arange(seq_len, device=attention_scores.device)
    distances = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()

    # Get slopes for each head
    slopes = torch.tensor(
        get_alibi_slopes(num_heads),
        device=attention_scores.device
    ).view(1, num_heads, 1, 1)

    # Compute bias: -slope × distance
    bias = -slopes * distances.unsqueeze(0).unsqueeze(0)

    # Add bias to scores
    return attention_scores + bias

# Example usage
batch, heads, seq_len = 2, 12, 512
scores = torch.randn(batch, heads, seq_len, seq_len)

# Apply ALiBi
biased_scores = apply_alibi_bias(scores, num_heads=heads)
print(f"Biased scores shape: {biased_scores.shape}")  # [2, 12, 512, 512]
```

ALiBi vs RoPE:
| Feature | RoPE | ALiBi |
|---------|------|-------|
| Memory Overhead | None | Bias matrix (small) |
| Computation | Rotate Q/K | Add bias to scores |
| Extrapolation | Excellent (4× training length) | Good (2× training length) |
| Adoption | Most modern LLMs | Some research models |
| FlashAttention Compatible | ✅ Yes | ⚠️ Partial (custom kernels needed) |
💡 When to Use Which: RoPE is the industry standard for good reason. Use ALiBi only if you need extreme simplicity or have specific research requirements.
3. KV-Cache: Making Inference Fast
During autoregressive generation (producing tokens one-by-one), computing attention naively is extremely wasteful.
The Inefficiency Problem
When generating token t, we need to attend to all previous tokens [0, 1, ..., t-1]. Without caching:
```python
# Inefficient autoregressive generation
def generate_naive(model, prompt_ids, max_new_tokens=100):
    """
    Recomputes Q, K, V for ALL tokens at EVERY step
    For 100 tokens: ~5000 redundant computations!
    """
    input_ids = prompt_ids.clone()
    for _ in range(max_new_tokens):
        # Process the entire sequence every time (wasteful!)
        logits = model(input_ids)  # Recomputes K, V for all past tokens
        next_token = logits[:, -1, :].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    return input_ids
```

Cost: for sequence length n, generating m new tokens costs O(m × n × d²) FLOPs.
KV-Cache Solution
Key Insight: Keys and Values for past tokens never change! Cache them and only compute K, V for the new token.
```python
def generate_with_kv_cache(model, prompt_ids, max_new_tokens=100):
    """
    Efficient generation with KV-cache
    Reduces computation by ~30-50× for long sequences
    """
    input_ids = prompt_ids.clone()
    kv_cache = None  # Will store past K, V tensors

    for i in range(max_new_tokens):
        if i == 0:
            # First step: process the full prompt
            logits, kv_cache = model(input_ids, use_cache=True, past_kv=None)
        else:
            # Subsequent steps: only process the new token
            logits, kv_cache = model(
                input_ids[:, -1:],  # Only last token!
                use_cache=True,
                past_kv=kv_cache  # Reuse cached K, V
            )
        next_token = logits[:, -1, :].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    return input_ids
```

KV-Cache Anatomy
The cache stores Keys and Values for each layer and head:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerLayerWithCache(nn.Module):
    """
    Transformer layer with KV-caching support
    """
    def __init__(self, d_model=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, past_kv=None):
        """
        Args:
            x: [batch, seq_len, d_model] (seq_len=1 when using cache)
            past_kv: Tuple of (past_key, past_value) or None
                past_key: [batch, num_heads, past_seq_len, head_dim]
        Returns:
            output: [batch, seq_len, d_model]
            new_kv: Updated (key, value) cache
        """
        batch, seq_len, d_model = x.shape

        # Compute Q, K, V for current token(s)
        Q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # If we have past KV, concatenate with current
        if past_kv is not None:
            past_key, past_value = past_kv
            K = torch.cat([past_key, K], dim=2)  # Concatenate along sequence dimension
            V = torch.cat([past_value, V], dim=2)

        # Store new KV cache (includes past + current)
        new_kv = (K, V)

        # Compute attention with the full key/value sequence
        # Q: [batch, heads, seq_len, head_dim] (current token(s))
        # K, V: [batch, heads, total_seq_len, head_dim] (all past + current)
        # Causal masking is only needed on the first, multi-token pass;
        # a single-token query attending to everything is already causal
        attn_output = F.scaled_dot_product_attention(
            Q, K, V, is_causal=(past_kv is None)
        )

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        output = self.out_proj(attn_output)
        return output, new_kv

# Example: Generate with KV-cache
layer = TransformerLayerWithCache(d_model=768, num_heads=12)

# Step 1: Process prompt (seq_len=10)
prompt = torch.randn(1, 10, 768)
output1, kv_cache = layer(prompt, past_kv=None)
print(f"Step 1 - Output: {output1.shape}, Cache K shape: {kv_cache[0].shape}")
# Output: [1, 10, 768], Cache K: [1, 12, 10, 64]

# Step 2: Generate next token (seq_len=1, reuse cache)
new_token = torch.randn(1, 1, 768)
output2, kv_cache = layer(new_token, past_kv=kv_cache)
print(f"Step 2 - Output: {output2.shape}, Cache K shape: {kv_cache[0].shape}")
# Output: [1, 1, 768], Cache K: [1, 12, 11, 64] ← cache grew by 1
```

Memory Overhead of KV-Cache
For a model with:

- L layers
- h attention heads
- d_h head dimension
- Sequence length n
- Batch size b
- Data type: FP16 (2 bytes)

KV-Cache Size:

```
Memory = 2 × L × b × h × n × d_h × 2 bytes
         ↑    ↑   ↑   ↑   ↑   ↑      ↑
        K+V layers batch heads seq head_dim FP16
```

Example (Llama 2 7B):
- L=32 layers, h=32 heads, d_h=128, n=4096
- Memory = 2 × 32 × 1 × 32 × 4096 × 128 × 2 bytes ≈ 2.1 GB per sample
⚠️ Critical Implication: With batch size 8 at 4K context, KV-cache alone uses 17 GB. This is why inference requires so much VRAM!
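The formula above can be wrapped in a small calculator to verify both figures (a standalone helper of my own, matching the formula's variable names):

```python
# KV-cache size: 2 (K and V) × layers × batch × heads × seq_len × head_dim,
# times 2 bytes per element for FP16.
def kv_cache_bytes(layers, batch, heads, seq_len, head_dim, bytes_per_elem=2):
    return 2 * layers * batch * heads * seq_len * head_dim * bytes_per_elem

# Llama 2 7B at 4K context
per_sample = kv_cache_bytes(layers=32, batch=1, heads=32, seq_len=4096, head_dim=128)
print(f"{per_sample / 1e9:.1f} GB per sample")    # 2.1 GB per sample
print(f"{8 * per_sample / 1e9:.0f} GB at batch 8")  # 17 GB at batch 8
```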
4. Context Extension Techniques
Training on 4K tokens but need to handle 32K at inference? Here's how modern LLMs do it.
Position Interpolation
Problem: RoPE trained on max length 2048 fails at 4096 because rotation angles are too large.
Solution: Scale down the position indices during inference.
```python
def extend_rope_context(rope_module, original_max_len, new_max_len):
    """
    Extend RoPE context by interpolating position indices
    Used in Llama 2 Long, Code Llama 100K
    """
    scale = original_max_len / new_max_len  # e.g., 2048 / 8192 = 0.25

    # Adjust frequency computation
    # Original: θᵢ = base^(-2i/d)
    # Extended: θᵢ = base^(-2i/d) × scale
    rope_module.inv_freq = rope_module.inv_freq * scale

    # Recompute cached cos/sin with the new frequencies
    rope_module._precompute_freqs(new_max_len)
    return rope_module

# Example: Extend from 2K to 8K context
rope = RotaryPositionEmbedding(dim=64, max_seq_len=2048)
rope_extended = extend_rope_context(rope, original_max_len=2048, new_max_len=8192)

# Now can handle 8K sequences without retraining!
q = torch.randn(1, 8, 8192, 64)  # 8K sequence
k = torch.randn(1, 8, 8192, 64)
q_rot, k_rot = rope_extended(q, k)
print(f"Extended to {q_rot.size(2)} tokens")  # 8192
```

Performance: can extend 2-4× with minimal quality loss. Beyond that, fine-tuning is needed.
YaRN (Yet another RoPE extensioN)
More sophisticated interpolation that preserves high-frequency components:
```python
import torch

def yarn_scaling(original_length, new_length, dim, base=10000):
    """
    YaRN scaling: Non-uniform interpolation
    - Low frequencies (long-range): Interpolate aggressively
    - High frequencies (short-range): Minimal interpolation
    """
    scale = new_length / original_length

    # Compute per-dimension frequencies
    dim_range = torch.arange(0, dim, 2).float()
    freqs = base ** (-dim_range / dim)

    # Lower frequencies get more scaling (they capture long-range structure);
    # higher frequencies get less (they capture local patterns)
    yarn_scale = torch.where(
        freqs < 0.1,                          # Low-frequency threshold
        torch.full_like(freqs, scale),        # Full interpolation
        1.0 + (scale - 1.0) * (0.1 / freqs)   # Fades toward 1 for high freqs
    )
    return yarn_scale
```

5. Long-Sequence Pitfalls & Solutions
Pitfall 1: Lost in the Middle
Problem: Models pay less attention to middle tokens in very long contexts.
```python
# Experiment: Where does the model look?
def test_context_recall(model, tokenizer, needle_position='middle'):
    """
    Test if the model can find info at different positions
    ("Needle in a haystack" benchmark)
    """
    context = "Random text... " * 1000  # Long distractor
    needle = "The secret password is BANANA."

    if needle_position == 'start':
        full_text = needle + context
    elif needle_position == 'middle':
        mid = len(context) // 2
        full_text = context[:mid] + needle + context[mid:]
    else:  # end
        full_text = context + needle

    # Ask the model to recall the password
    prompt = full_text + "What is the secret password?"
    response = model.generate(tokenizer.encode(prompt))
    return "BANANA" in response

# Results (GPT-3.5 on 16K context):
# Start:  95% recall ✅
# Middle: 62% recall ⚠️
# End:    98% recall ✅
```

Solutions:
- Structured prompting: Put critical info at start/end
- Retrieval augmentation: Don't put everything in context
- Fine-tune on long documents: Train model to use middle content
Pitfall 2: Attention Collapse
At very long contexts, attention can become too uniform (attends equally to everything = attends to nothing).
```python
import math
import torch

def diagnose_attention_collapse(attention_weights):
    """
    Check if attention has collapsed to a uniform distribution
    Args:
        attention_weights: [batch, heads, seq_len, seq_len]
    Returns:
        entropy_ratio: Higher = more diffuse attention
    """
    # Compute entropy of the attention distribution
    # (a uniform distribution has maximum entropy)
    eps = 1e-10
    entropy = -(attention_weights * torch.log(attention_weights + eps)).sum(dim=-1)
    max_entropy = math.log(attention_weights.size(-1))

    # Ratio: 1.0 = completely uniform (collapsed)
    entropy_ratio = entropy / max_entropy
    return entropy_ratio.mean().item()

# Example
attn = torch.softmax(torch.randn(1, 12, 4096, 4096), dim=-1)
collapse_score = diagnose_attention_collapse(attn)
print(f"Attention collapse score: {collapse_score:.3f}")
# > 0.9 indicates potential collapse
```

Solution: use attention sink tokens (keep the first few tokens always attended) to prevent collapse.
Pitfall 3: Memory Overflow
The KV-cache grows linearly with sequence length, so it can easily OOM.
Solutions:
- PagedAttention (vLLM): Store KV-cache in paged memory, like OS virtual memory
- Streaming LLM: Evict old cache, keep only recent + first few tokens
- Compress KV-cache: Quantize to INT8 or lower precision
```python
def streaming_kv_cache(kv_cache, max_cache_len=2048, keep_first=128):
    """
    Maintain a fixed-size KV-cache for infinite generation
    Keep the first N tokens (attention sinks) + recent tokens
    """
    key_cache, value_cache = kv_cache
    current_len = key_cache.size(2)

    if current_len <= max_cache_len:
        return kv_cache  # No eviction needed

    # Keep first keep_first tokens + most recent
    recent_tokens = max_cache_len - keep_first
    key_cache = torch.cat([
        key_cache[:, :, :keep_first, :],     # First N tokens
        key_cache[:, :, -recent_tokens:, :]  # Recent tokens
    ], dim=2)
    value_cache = torch.cat([
        value_cache[:, :, :keep_first, :],
        value_cache[:, :, -recent_tokens:, :]
    ], dim=2)
    return (key_cache, value_cache)
```

Summary: Key Takeaways
FlashAttention
- ✅ Breakthrough: Computes attention without materializing n² matrix
- ✅ How: Tiling + online softmax in fast SRAM
- ✅ Impact: 5-20× faster, handles 4× longer sequences
- ✅ Use: Built into PyTorch 2.0+ as `scaled_dot_product_attention`
Position Encodings
- ✅ RoPE: Industry standard, excellent extrapolation, zero overhead
- ✅ ALiBi: Simple linear biases, good for research
- ✅ Choose RoPE unless you have specific needs
KV-Cache
- ✅ Purpose: Avoid recomputing past tokens' K, V during generation
- ✅ Speedup: 30-50× faster inference
- ✅ Cost: 2-4 GB memory per sample at 4K context
- ✅ Critical for production inference
Context Extension
- ✅ Position Interpolation: Scale RoPE frequencies for 2-4× extension
- ✅ YaRN: Non-uniform scaling preserves quality better
- ✅ Limitation: Beyond 4× requires fine-tuning
Long-Sequence Pitfalls
- ⚠️ Lost in the Middle: Models struggle with middle content
- ⚠️ Attention Collapse: Uniform attention at extreme lengths
- ⚠️ Memory: KV-cache scales linearly with length
- ✅ Solutions: Structured prompts, streaming cache, retrieval augmentation
🎯 Next Steps: Apply these to build efficient Transformer variants and explore State Space Models (SSMs) that break the n² barrier entirely!
Lesson Content
Master advanced Transformer optimization techniques including FlashAttention's memory-efficient attention, positional encoding strategies (RoPE/ALiBi), KV-cache mechanisms, context extension methods, and pitfalls of long-sequence processing.
Code Example
```python
# Complete implementation: Efficient Transformer with all optimizations
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RoPE(nn.Module):
    """Rotary Position Embedding - Production ready"""
    def __init__(self, dim, max_seq_len=32768, base=10000):
        super().__init__()
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self._precompute_freqs(max_seq_len)

    def _precompute_freqs(self, seq_len):
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k, offset=0):
        # offset: absolute position of the first token in q/k (non-zero during
        # cached generation, where q holds only the newest token)
        seq_len = q.size(2)
        cos = self.cos_cached[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[offset:offset + seq_len].unsqueeze(0).unsqueeze(0)
        q_rotated = q * cos + self.rotate_half(q) * sin
        k_rotated = k * cos + self.rotate_half(k) * sin
        return q_rotated, k_rotated

class EfficientAttention(nn.Module):
    """
    Multi-head attention with all efficiency features:
    - FlashAttention (via PyTorch SDPA)
    - RoPE positional encoding
    - KV-cache support
    """
    def __init__(self, d_model=768, num_heads=12, max_seq_len=8192):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5

        # Projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # RoPE for position encoding
        self.rope = RoPE(self.head_dim, max_seq_len=max_seq_len)

    def forward(self, x, past_kv=None, use_cache=False):
        """
        Args:
            x: [batch, seq_len, d_model]
            past_kv: Optional tuple of (past_key, past_value) for caching
            use_cache: Whether to return updated KV cache
        Returns:
            output: [batch, seq_len, d_model]
            new_kv: Updated cache if use_cache=True, else None
        """
        batch, seq_len, d_model = x.shape

        # Compute Q, K, V
        Q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K at their absolute positions
        past_len = past_kv[0].size(2) if past_kv is not None else 0
        Q, K = self.rope(Q, K, offset=past_len)

        # Handle KV-cache
        if past_kv is not None:
            past_key, past_value = past_kv
            K = torch.cat([past_key, K], dim=2)
            V = torch.cat([past_value, V], dim=2)

        # Prepare cache for next iteration
        new_kv = (K, V) if use_cache else None

        # FlashAttention (automatically used if available)
        # is_causal=True for decoder-only models (like GPT)
        attn_output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=(past_kv is None)  # Only causal on the first pass
        )

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq_len, d_model)
        output = self.out_proj(attn_output)
        return output, new_kv

class TransformerBlock(nn.Module):
    """Complete Transformer block with efficient attention"""
    def __init__(self, d_model=768, num_heads=12, d_ff=3072, dropout=0.1, max_seq_len=8192):
        super().__init__()
        self.attn = EfficientAttention(d_model, num_heads, max_seq_len)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, past_kv=None, use_cache=False):
        # Attention with residual
        attn_out, new_kv = self.attn(self.ln1(x), past_kv, use_cache)
        x = x + attn_out
        # Feed-forward with residual
        x = x + self.ff(self.ln2(x))
        return x, new_kv

class EfficientTransformer(nn.Module):
    """
    Production-ready efficient Transformer with:
    - FlashAttention
    - RoPE
    - KV-cache
    - Context extension support
    """
    def __init__(
        self,
        vocab_size=50257,
        d_model=768,
        num_layers=12,
        num_heads=12,
        d_ff=3072,
        max_seq_len=8192,
        dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Embeddings
        self.token_emb = nn.Embedding(vocab_size, d_model)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout, max_seq_len)
            for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, past_kvs=None, use_cache=False):
        """
        Args:
            input_ids: [batch, seq_len]
            past_kvs: List of (key, value) tuples for each layer
            use_cache: Whether to return KV cache
        Returns:
            logits: [batch, seq_len, vocab_size]
            new_kvs: List of updated caches if use_cache=True
        """
        x = self.token_emb(input_ids)
        new_kvs = [] if use_cache else None

        for i, block in enumerate(self.blocks):
            past_kv = past_kvs[i] if past_kvs is not None else None
            x, new_kv = block(x, past_kv, use_cache)
            if use_cache:
                new_kvs.append(new_kv)

        x = self.ln_final(x)
        logits = self.head(x)
        return logits, new_kvs

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=100, temperature=1.0):
        """
        Efficient autoregressive generation with KV-cache
        """
        past_kvs = None
        for _ in range(max_new_tokens):
            # Only pass the last token after the first iteration
            curr_input = input_ids if past_kvs is None else input_ids[:, -1:]

            # Forward pass with caching
            logits, past_kvs = self.forward(curr_input, past_kvs, use_cache=True)

            # Sample next token
            next_token_logits = logits[:, -1, :] / temperature
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Optional: check for EOS token
            # if next_token.item() == eos_token_id:
            #     break
        return input_ids

# ============================================================================
# Usage Examples
# ============================================================================

def example_basic_usage():
    """Basic model creation and forward pass"""
    model = EfficientTransformer(
        vocab_size=50257, d_model=768, num_layers=12,
        num_heads=12, max_seq_len=8192
    ).cuda()

    # Single forward pass
    input_ids = torch.randint(0, 50257, (2, 512)).cuda()
    logits, _ = model(input_ids, use_cache=False)
    print(f"Logits shape: {logits.shape}")  # [2, 512, 50257]

def example_efficient_generation():
    """Efficient generation with KV-cache"""
    model = EfficientTransformer(
        vocab_size=50257, d_model=768, num_layers=12,
        num_heads=12, max_seq_len=8192
    ).cuda().eval()

    # Prompt
    prompt_ids = torch.randint(0, 50257, (1, 128)).cuda()

    # Generate 100 tokens efficiently
    with torch.no_grad():
        output_ids = model.generate(prompt_ids, max_new_tokens=100)
    print(f"Generated {output_ids.size(1) - 128} new tokens")
    print(f"Final sequence length: {output_ids.size(1)}")

def example_context_extension():
    """Extend context length beyond training"""
    # Train on 2K context
    model_2k = EfficientTransformer(max_seq_len=2048).cuda()

    # Extend to 8K using position interpolation
    def extend_context(model, new_max_len):
        for block in model.blocks:
            rope = block.attn.rope
            scale = rope.max_seq_len / new_max_len
            rope.inv_freq = rope.inv_freq * scale
            rope._precompute_freqs(new_max_len)
            rope.max_seq_len = new_max_len
        model.max_seq_len = new_max_len

    extend_context(model_2k, new_max_len=8192)

    # Now can handle 8K sequences
    long_input = torch.randint(0, 50257, (1, 8192)).cuda()
    logits, _ = model_2k(long_input)
    print(f"Successfully processed {long_input.size(1)} tokens")

def benchmark_kv_cache_speedup():
    """Measure KV-cache speedup"""
    import time

    model = EfficientTransformer(
        vocab_size=1000, d_model=512, num_layers=6,
        num_heads=8, max_seq_len=2048
    ).cuda().eval()

    prompt = torch.randint(0, 1000, (1, 128)).cuda()
    num_tokens = 50

    # Without KV-cache (naive)
    start = time.time()
    with torch.no_grad():
        input_ids = prompt.clone()
        for _ in range(num_tokens):
            logits, _ = model(input_ids, use_cache=False)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
    naive_time = time.time() - start

    # With KV-cache
    start = time.time()
    with torch.no_grad():
        output_ids = model.generate(prompt, max_new_tokens=num_tokens)
    cache_time = time.time() - start

    print(f"\nGeneration Benchmark ({num_tokens} tokens):")
    print(f"Without KV-cache: {naive_time:.3f}s")
    print(f"With KV-cache:    {cache_time:.3f}s")
    print(f"Speedup: {naive_time / cache_time:.2f}×")

if __name__ == "__main__":
    print("=== Example 1: Basic Usage ===")
    example_basic_usage()

    print("\n=== Example 2: Efficient Generation ===")
    example_efficient_generation()

    print("\n=== Example 3: Context Extension ===")
    example_context_extension()

    print("\n=== Example 4: KV-Cache Speedup ===")
    benchmark_kv_cache_speedup()
```