AI Intermediate: Deep Learning & Neural Networks
Autograd & Computational Graphs: From Theory to Debugging
Theory & Concepts
Understanding Autograd: The Engine Behind Deep Learning
Automatic differentiation (autograd) is the cornerstone of modern deep learning. It enables neural networks to learn by automatically computing gradients of complex functions, with no manual calculus required.
💡 Why This Matters: Without autograd, training even a simple 3-layer neural network would require hours of manual derivative calculations. Autograd does this in milliseconds, making deep learning practical.
What is Autograd?
Automatic Differentiation is NOT:
- ❌ Symbolic differentiation (like Wolfram Alpha does)
- ❌ Numerical differentiation (finite differences)
- ✅ It's an algorithmic technique that applies the chain rule automatically
Key Concept: Autograd builds a computational graph as your code executes, tracking every operation. Then it traverses this graph backwards to compute gradients efficiently.
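The recorded graph can be inspected directly: in PyTorch, every non-leaf tensor carries a grad_fn attribute pointing at the operation that produced it. A minimal sketch, using the f(x, y) = (x + y) × x example from below:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

a = x + y   # recorded as an Add node
f = a * x   # recorded as a Mul node

# Each non-leaf tensor remembers the operation that created it,
# and next_functions links backwards through the graph:
print(f.grad_fn)                 # MulBackward0 object
print(f.grad_fn.next_functions)  # links back to the Add node and the leaf x
```

Printing grad_fn is a quick way to confirm that a tensor is still attached to the graph; if it is None, gradients will not flow to it.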
Computational Graphs Explained
Think of a computational graph as a blueprint of your computation:
- Nodes: variables and operations
- Edges: data flow between operations
Simple Example: f(x, y) = (x + y) × x
Input: x=2, y=3

Graph:

```
    x=2   y=3
      \   /
       (+) → a=5
        |
        |   x=2
        |  /
       (×) → f=10
```

Forward Pass: compute f(x, y) = 10
Backward Pass: compute ∂f/∂x and ∂f/∂y using the chain rule
The Chain Rule: Foundation of Backpropagation
For composite functions, the chain rule states:
If z = f(y) and y = g(x), then:
dz/dx = (dz/dy) × (dy/dx)

Multi-variable version (used in neural networks):

∂L/∂w = Σᵢ (∂L/∂yᵢ) × (∂yᵢ/∂w)

Where:
- L = Loss function (what we want to minimize)
- w = Weight parameter
- yᵢ = Intermediate values in the computation graph
ℹ️ Intuition: The gradient at a node is the sum, over all paths from that node to the output, of the product of the local derivatives along each path.
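The chain rule can be checked numerically. A small sketch in plain Python, composing z = f(y) = y² with y = g(x) = 3x + 1, so dz/dx = (dz/dy) × (dy/dx) = 2y × 3:

```python
def g(x): return 3 * x + 1
def f(y): return y ** 2

def chain_rule_grad(x):
    """dz/dx via the chain rule: (dz/dy) * (dy/dx)."""
    y = g(x)
    dz_dy = 2 * y  # derivative of f at y
    dy_dx = 3      # derivative of g at x
    return dz_dy * dy_dx

def numerical_grad(x, eps=1e-6):
    """Centered finite difference of the composite z = f(g(x))."""
    return (f(g(x + eps)) - f(g(x - eps))) / (2 * eps)

x = 2.0
print(chain_rule_grad(x))  # 42.0
print(numerical_grad(x))   # ≈ 42.0
```

This is exactly what autograd does at scale: multiply local derivatives along the graph and sum over paths.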
PyTorch vs JAX: Two Approaches to Autograd
PyTorch: Define-by-Run (Dynamic Graphs)
- Graph is built as code executes
- Different graph for each forward pass
- Easier to debug (feels like normal Python)
- Perfect for variable-length sequences, conditional logic
JAX: JIT Compilation (Static Graphs)
- Graph is traced once, then compiled
- Same graph for all inputs (must be same shape)
- Blazing fast execution (XLA compilation)
- Better for production deployment
⚠️ Critical Difference: PyTorch rebuilds the graph every iteration. JAX traces it once. Choose based on your flexibility vs speed needs.
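The define-by-run behaviour can be seen with data-dependent control flow. A minimal PyTorch sketch, where an ordinary Python branch decides which operations get recorded, so each input can produce a different graph:

```python
import torch

def piecewise(x):
    # Ordinary Python control flow, evaluated at run time:
    if x.item() > 0:
        return x ** 2    # this input's graph records a Pow node
    else:
        return -3 * x    # this input's graph records a Mul node instead

for val in (2.0, -2.0):
    x = torch.tensor(val, requires_grad=True)
    y = piecewise(x)
    y.backward()
    print(val, x.grad.item())  # 2.0 4.0, then -2.0 -3.0
```

In JAX, the same branch on a traced value would fail or be baked in at trace time, which is the flexibility-vs-speed trade-off described above.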
Common Pitfalls & Debugging Strategy
1. Detached Tensors (No Gradient Flow)
Problem: Operations that break the computational graph
```python
x = torch.tensor([2.0], requires_grad=True)
y = x.detach()  # ❌ Gradient won't flow through y
z = y ** 2
z.backward()    # RuntimeError: z does not require grad; x.grad stays None
```

Fix: avoid .detach(), .numpy(), or .item() in the middle of a computation.
2. In-Place Operations
Problem: Modifying tensors in-place can corrupt gradients
```python
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
x += 1  # ❌ RuntimeError: in-place operation on a leaf that requires grad
```

Fix: use x = x + 1 instead of x += 1.
3. Shape Mismatches
Problem: Broadcasting can hide shape errors until backprop
```python
# Forward pass works (broadcasts automatically)
x = torch.randn(32, 10)                  # batch of 32 samples
w = torch.randn(10, requires_grad=True)  # ❌ should be (10, 1)
y = x * w  # works, but gives the wrong shape (32, 10)
# Backward pass crashes or gives wrong gradients
```

Fix: always verify tensor shapes with .shape before and after operations.
4. NaN/Inf in Gradients
Most common causes:
- Division by zero: 1 / (x - x)
- Log of zero/negative: log(0) or log(-1)
- Exploding gradients: very large learning rates
- Numerical overflow: exp(1000)
⚠️ Critical Debugging Tip: Use torch.autograd.set_detect_anomaly(True) to pinpoint the exact operation that produces NaN.
Gradient Checking: Verifying Your Autograd
Always verify autograd implementation with numerical gradients:
Finite Difference Approximation:
f'(x) ≈ [f(x + ε) - f(x - ε)] / (2ε)

where ε is a small value (e.g., 1e-5).
If autograd gradient ≈ numerical gradient (within 1e-5), you're good!
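As a concrete sketch of the check, here is the finite-difference approximation applied to f(x) = x³, whose exact derivative is 3x²:

```python
def f(x):
    return x ** 3

def analytic_grad(x):
    return 3 * x ** 2  # exact derivative of x**3

def numeric_grad(x, eps=1e-5):
    # Centered finite difference: [f(x + eps) - f(x - eps)] / (2 * eps)
    return (f(x + eps) - f(x - eps)) / (2 * eps)

for x in (0.5, 1.0, 2.0):
    a, n = analytic_grad(x), numeric_grad(x)
    assert abs(a - n) < 1e-5, (x, a, n)
print("gradient check passed")
```

The same pattern, looping over each element of a parameter tensor, verifies an autograd gradient end to end.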
When to Use Manual Gradients
Most of the time, use autograd. But manual gradients are needed for:
- Custom CUDA kernels (low-level GPU operations)
- Non-differentiable operations (argmax, sampling)
- Memory-constrained scenarios (gradient checkpointing)
Mental Model: Autograd as a Recording System
Think of autograd as a video recorder:
- Press Record: requires_grad=True
- Perform operations (forward pass): everything is recorded
- Play backwards: .backward() computes gradients from the recording
- Get the gradients: .grad extracts what was computed
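The four recorder steps map onto four lines of PyTorch. A minimal sketch:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)  # 1. press Record
y = x ** 2 + 2 * x                         # 2. forward pass is recorded
y.backward()                               # 3. play the recording backwards
print(x.grad.item())                       # 4. read the gradient: 2x + 2 = 8.0
```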
Summary
Key Takeaways:
- Autograd builds a computational graph during forward pass
- Backward pass applies chain rule automatically
- PyTorch = dynamic (flexible), JAX = static (fast)
- Common bugs: detached tensors, in-place ops, shape mismatches, NaN/Infs
- Use anomaly detection and gradient checking for debugging
- The chain rule is the mathematical foundation of everything
Remember: Master autograd debugging = Master deep learning implementation!
Lesson Content
Master automatic differentiation (autograd) and computational graphs in PyTorch and JAX. Learn practical debugging techniques for gradients, shape mismatches, and numerical instabilities (NaNs/Infs).
Code Example
```python
# Autograd & Computational Graphs: Complete Practical Guide
# From basics to advanced debugging techniques

import torch
import torch.nn as nn
import numpy as np
from typing import Tuple

print("=" * 90)
print("AUTOGRAD & COMPUTATIONAL GRAPHS: PRACTICAL DEBUGGING GUIDE")
print("=" * 90)
print()

# =============================================================================
# PART 1: Understanding Computational Graphs
# =============================================================================
print("1. COMPUTATIONAL GRAPH VISUALIZATION")
print("-" * 90)

# Simple example: f(x, y) = (x + y) * x
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

# Forward pass (graph is built automatically)
a = x + y  # Intermediate node
f = a * x  # Output node

print(f"Forward Pass:")
print(f"  x = {x.item()}")
print(f"  y = {y.item()}")
print(f"  a = x + y = {a.item()}")
print(f"  f = a * x = {f.item()}")
print()

# Backward pass (traverse graph backwards)
f.backward()

print(f"Backward Pass (Gradients):")
print(f"  ∂f/∂x = {x.grad.item():.4f}")
print(f"  ∂f/∂y = {y.grad.item():.4f}")
print()

# Verify with manual calculation
print(f"Manual Verification:")
print(f"  ∂f/∂x = ∂/∂x[(x+y)×x] = (x+y) + x = {(x.item() + y.item()) + x.item()}")
print(f"  ∂f/∂y = ∂/∂y[(x+y)×x] = x = {x.item()}")
print(f"  ✓ Autograd matches manual calculation!")
print()

# =============================================================================
# PART 2: Tracking Gradient Flow
# =============================================================================
print("2. GRADIENT FLOW TRACKING")
print("-" * 90)

def demonstrate_gradient_flow():
    """Show how gradients flow through operations"""
    # Create input with gradient tracking
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

    # Complex computation
    y = x ** 2    # Element-wise square
    z = y.mean()  # Average

    print(f"Forward:")
    print(f"  x = {x.data.numpy()}")
    print(f"  y = x² = {y.data.numpy()}")
    print(f"  z = mean(y) = {z.item():.4f}")
    print()

    # Check what requires gradients
    print(f"Gradient Tracking Status:")
    print(f"  x.requires_grad = {x.requires_grad}")
    print(f"  y.requires_grad = {y.requires_grad}")
    print(f"  z.requires_grad = {z.requires_grad}")
    print()

    # Compute gradients
    z.backward()
    print(f"Gradients:")
    print(f"  dz/dx = {x.grad.numpy()}")
    print()

    # Manual verification
    # z = mean(x²) = (x₁² + x₂² + x₃²) / 3
    # dz/dx₁ = 2x₁/3
    manual_grad = 2 * x.data.numpy() / 3
    print(f"Manual Calculation: dz/dx = 2x/3 = {manual_grad}")
    print(f"  ✓ Match: {np.allclose(x.grad.numpy(), manual_grad)}")
    print()

demonstrate_gradient_flow()

# =============================================================================
# PART 3: Common Bug #1 - Detached Tensors
# =============================================================================
print("3. DEBUGGING: DETACHED TENSORS (Broken Gradient Flow)")
print("-" * 90)

def bug_detached_tensor():
    """Demonstrate the detached tensor bug"""
    print("❌ WRONG: Using .detach() breaks gradient flow")
    x = torch.tensor([2.0], requires_grad=True)
    y = x.detach()  # Breaks the chain!
    z = y ** 2
    try:
        z.backward()  # z has no grad_fn, so this raises
    except RuntimeError as e:
        print(f"  RuntimeError: {str(e)[:60]}...")
    print(f"  x.grad = {x.grad}  # None - gradient didn't flow!")
    print()

    print("✓ CORRECT: Keep tensor attached")
    x = torch.tensor([2.0], requires_grad=True)
    y = x  # No detach
    z = y ** 2
    z.backward()
    print(f"  x.grad = {x.grad.item()}  # 4.0 - gradient flows correctly!")
    print()

bug_detached_tensor()

# =============================================================================
# PART 4: Common Bug #2 - In-Place Operations
# =============================================================================
print("4. DEBUGGING: IN-PLACE OPERATIONS")
print("-" * 90)

def bug_inplace_operation():
    """Demonstrate in-place operation bug"""
    print("❌ WRONG: In-place modification corrupts gradients")
    try:
        x = torch.tensor([2.0], requires_grad=True)
        y = x ** 2
        x += 1  # In-place modification
        y.backward()
    except RuntimeError as e:
        print(f"  RuntimeError: {str(e)[:80]}...")
    print()

    print("✓ CORRECT: Create new tensor instead")
    x = torch.tensor([2.0], requires_grad=True)
    y = x ** 2
    x_new = x + 1  # Creates new tensor
    y.backward()
    print(f"  x.grad = {x.grad.item()}  # Works!")
    print()

bug_inplace_operation()

# =============================================================================
# PART 5: Common Bug #3 - Shape Mismatches
# =============================================================================
print("5. DEBUGGING: SHAPE MISMATCHES")
print("-" * 90)

def debug_shape_mismatch():
    """Demonstrate shape debugging techniques"""
    print("❌ WRONG: Shape mismatch hidden by broadcasting")
    batch_size = 32
    features = 10

    x = torch.randn(batch_size, features, requires_grad=True)
    w = torch.randn(features, requires_grad=True)  # Wrong shape!

    # Forward pass works due to broadcasting
    y = x * w  # Shape becomes (32, 10) - probably not intended!
    loss = y.sum()

    print(f"  x.shape = {x.shape}")
    print(f"  w.shape = {w.shape}")
    print(f"  y.shape = {y.shape}  # Broadcasting happened!")
    print()

    loss.backward()
    print(f"  w.grad.shape = {w.grad.shape}  # Gradient shape might be wrong")
    print()

    print("✓ CORRECT: Explicit shape checking")
    x = torch.randn(batch_size, features, requires_grad=True)
    w = torch.randn(features, 1, requires_grad=True)  # Correct shape

    # Add assertion for shape validation
    assert w.shape == (features, 1), f"Expected ({features}, 1), got {w.shape}"

    y = x @ w  # Matrix multiply: (32, 10) @ (10, 1) -> (32, 1)
    loss = y.sum()

    print(f"  x.shape = {x.shape}")
    print(f"  w.shape = {w.shape}")
    print(f"  y.shape = {y.shape}  # Shape is now correct")
    print()

    loss.backward()
    print(f"  w.grad.shape = {w.grad.shape}  # Gradient shape matches parameter")
    print()

debug_shape_mismatch()

# =============================================================================
# PART 6: Common Bug #4 - NaN/Inf Detection
# =============================================================================
print("6. DEBUGGING: NaN/Inf DETECTION & PREVENTION")
print("-" * 90)

def detect_nan_gradients():
    """Demonstrate NaN detection techniques"""
    print("❌ Operations that cause NaN:")

    # Example 1: Division by zero
    x = torch.tensor([1.0], requires_grad=True)
    y = 1.0 / (x - x)  # Division by zero!
    print(f"  1/(x-x) = {y.item()}")  # inf

    # Example 2: Log of zero
    x = torch.tensor([0.0], requires_grad=True)
    y = torch.log(x)  # log(0) = -inf
    print(f"  log(0) = {y.item()}")

    # Example 3: Square root of negative
    x = torch.tensor([-1.0], requires_grad=True)
    y = torch.sqrt(x)  # sqrt(-1) = nan
    print(f"  sqrt(-1) = {y.item()}")
    print()

    print("✓ CORRECT: Using anomaly detection")
    # Enable anomaly detection (slower but helps debugging)
    torch.autograd.set_detect_anomaly(True)
    try:
        x = torch.tensor([1.0], requires_grad=True)
        y = (x - 1) * torch.log(x - 1)  # 0 * log(0) -> NaN in backward
        y.backward()
    except RuntimeError as e:
        print(f"  Anomaly detected: {str(e)[:80]}...")
    torch.autograd.set_detect_anomaly(False)
    print()

    print("✓ PREVENTION: Numerical stability techniques")
    # Technique 1: Add epsilon for numerical stability
    x = torch.tensor([0.0], requires_grad=True)
    eps = 1e-8
    y = torch.log(x + eps)  # Prevents log(0)
    print(f"  log(x + ε) = {y.item():.4f}  # Stable!")

    # Technique 2: Clamp values to safe range
    x = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
    x_safe = torch.clamp(x, min=eps)  # Ensure all values ≥ ε
    y = torch.log(x_safe)
    print(f"  log(clamp(x)) = {y.data.numpy()}  # All finite!")
    print()

detect_nan_gradients()

# =============================================================================
# PART 7: Gradient Checking (Numerical Verification)
# =============================================================================
print("7. GRADIENT CHECKING: Verifying Autograd Correctness")
print("-" * 90)

def gradient_check(func, x: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Verify autograd gradients against numerical gradients

    Args:
        func: Function to compute
        x: Input tensor
        eps: Small perturbation for finite differences

    Returns:
        (analytical_grad, numerical_grad)
    """
    # Analytical gradient (autograd)
    x.requires_grad = True
    y = func(x)
    y.backward()
    analytical_grad = x.grad.clone()
    x.grad.zero_()

    # Numerical gradient (finite differences)
    numerical_grad = torch.zeros_like(x)
    for i in range(x.numel()):
        # f(x + eps)
        x_plus = x.clone().detach()
        x_plus.view(-1)[i] += eps
        f_plus = func(x_plus).item()

        # f(x - eps)
        x_minus = x.clone().detach()
        x_minus.view(-1)[i] -= eps
        f_minus = func(x_minus).item()

        # Centered finite difference
        numerical_grad.view(-1)[i] = (f_plus - f_minus) / (2 * eps)

    return analytical_grad, numerical_grad

# Test on a complex function
def test_function(x):
    """Complex function: f(x) = sum(x³ - 2x² + 5x)"""
    return (x**3 - 2*x**2 + 5*x).sum()

x_test = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
analytical, numerical = gradient_check(test_function, x_test)

print(f"Function: f(x) = x³ - 2x² + 5x")
print(f"At x = {x_test.data.numpy()}")
print()
print(f"Analytical gradient (autograd):   {analytical.numpy()}")
print(f"Numerical gradient (finite diff): {numerical.numpy()}")
print(f"Difference: {(analytical - numerical).abs().numpy()}")
print(f"Relative error: {((analytical - numerical).abs() / (analytical.abs() + 1e-8)).numpy()}")
print()

if torch.allclose(analytical, numerical, rtol=1e-4, atol=1e-5):
    print("✓ GRADIENT CHECK PASSED! Autograd is correct.")
else:
    print("❌ GRADIENT CHECK FAILED! Check your implementation.")
print()

# =============================================================================
# PART 8: Advanced - Custom Gradient Function
# =============================================================================
print("8. ADVANCED: Custom Autograd Functions")
print("-" * 90)

class MyReLU(torch.autograd.Function):
    """
    Custom ReLU implementation with manual backward pass
    Demonstrates how to write custom differentiable operations
    """

    @staticmethod
    def forward(ctx, input):
        """Forward pass: ReLU(x) = max(0, x)"""
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass: derivative is 1 if x > 0, else 0"""
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# Test custom ReLU
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
y = MyReLU.apply(x)
loss = y.sum()
loss.backward()

print(f"Input:    {x.data.numpy()}")
print(f"ReLU(x):  {y.data.numpy()}")
print(f"Gradient: {x.grad.numpy()}")
print(f"  → Gradient is 1 where x > 0, else 0")
print()

# =============================================================================
# PART 9: Best Practices Summary
# =============================================================================
print("9. DEBUGGING CHECKLIST")
print("-" * 90)
print("""
Before training any neural network:
✓ 1.  Print tensor shapes at each layer
✓ 2.  Verify gradients with gradient checking (at least once)
✓ 3.  Use torch.autograd.set_detect_anomaly(True) during debugging
✓ 4.  Check for NaN/Inf after each training step:
      - if torch.isnan(loss): raise ValueError("NaN loss!")
✓ 5.  Avoid .detach(), .numpy(), .item() in computation graph
✓ 6.  Use x = x + 1, not x += 1 (avoid in-place)
✓ 7.  Add numerical stability (eps=1e-8) to log/div operations
✓ 8.  Initialize weights properly (Xavier/He initialization)
✓ 9.  Start with small learning rate (1e-4) and increase gradually
✓ 10. Monitor gradient norms: torch.nn.utils.clip_grad_norm_()

Common NaN Sources:
  • log(0), log(negative)
  • 1/0, x/0
  • sqrt(negative)
  • exp(very_large_number) → inf
  • Learning rate too high → exploding gradients

Quick NaN Debug:
  1. torch.autograd.set_detect_anomaly(True)
  2. Add print(f"Loss: {loss.item()}") after each forward
  3. Check: torch.isnan(model.weight.grad).any()
  4. Reduce learning rate by 10x
  5. Add gradient clipping
""")

print("=" * 90)
print("COMPLETE! You now understand autograd and can debug gradient issues.")
print("=" * 90)
```