Autograd & Computational Graphs: From Theory to Debugging

40 min

Theory & Concepts

Understanding Autograd: The Engine Behind Deep Learning

Automatic differentiation (autograd) is the cornerstone of modern deep learning. It enables neural networks to learn by automatically computing gradients of complex functions, with no manual calculus required.

💡 Why This Matters: Without autograd, training even a simple 3-layer neural network would require hours of manual derivative calculations. Autograd does this in milliseconds, making deep learning practical.

What is Autograd?

Automatic Differentiation is NOT:

  • ❌ Symbolic differentiation (like Wolfram Alpha does)
  • ❌ Numerical differentiation (finite differences)

It IS an algorithmic technique that applies the chain rule automatically, one recorded operation at a time.

Key Concept: Autograd builds a computational graph as your code executes, tracking every operation. Then it traverses this graph backwards to compute gradients efficiently.
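This recording can be inspected directly in PyTorch: every tensor produced by a tracked operation carries a grad_fn attribute pointing at the graph node that created it. A minimal sketch:

```python
import torch

# Autograd records a grad_fn node on every tracked operation's output;
# the chain of grad_fn objects is the computational graph itself.
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
a = x + y   # records an Add backward node
f = a * x   # records a Mul backward node

print(f.grad_fn)                 # the multiply node
print(f.grad_fn.next_functions)  # its inputs: the add node and x itself
```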

Computational Graphs Explained

Think of a computational graph as a blueprint of your computation:

Nodes: Variables and operations
Edges: Data flow between operations

Simple Example: f(x, y) = (x + y) × x

Input: x=2, y=3
 
Graph:
   x=2     y=3
     \     /
      (+)
       |
      a=5     x=2
        \     /
         (×)
          |
        f=10

Forward Pass: Compute f(x, y) = 10
Backward Pass: Compute ∂f/∂x and ∂f/∂y using the chain rule
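In PyTorch the two passes are a few lines. The gradients below follow from the product rule: ∂f/∂x = (x + y) + x = 7 and ∂f/∂y = x = 2.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

f = (x + y) * x       # forward pass: f = 5 * 2 = 10
f.backward()          # backward pass fills x.grad and y.grad

print(f.item())       # 10.0
print(x.grad.item())  # 7.0 = (x + y) + x
print(y.grad.item())  # 2.0 = x
```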

The Chain Rule: Foundation of Backpropagation

For composite functions, the chain rule states:

If z = f(y) and y = g(x), then:

dz/dx = (dz/dy) × (dy/dx)

Multi-variable version (used in neural networks):

∂L/∂w = Σᵢ (∂L/∂yᵢ) × (∂yᵢ/∂w)

Where:

  • L = Loss function (what we want to minimize)
  • w = Weight parameter
  • yᵢ = Intermediate values in the computation graph

ℹ️ Intuition: The gradient at a node is the sum, over all paths from that node to the output, of the product of the local derivatives along each path.
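For the earlier example f(x, y) = (x + y) × x at x=2, y=3, the input x reaches f along two paths (through a = x + y, and directly as the second factor). Summing the per-path products reproduces ∂f/∂x, worked out here in plain Python:

```python
# Path-sum form of the chain rule for f(x, y) = (x + y) * x at x=2, y=3.
x, y = 2.0, 3.0
a = x + y                  # a = 5

path_via_a = x * 1.0       # (∂f/∂a) * (∂a/∂x) = x * 1
path_direct = a            # ∂f/∂x with a held fixed = a
df_dx = path_via_a + path_direct

print(df_dx)               # 7.0, matching (x + y) + x
```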

PyTorch vs JAX: Two Approaches to Autograd

PyTorch: Define-by-Run (Dynamic Graphs)

  • Graph is built as code executes
  • Different graph for each forward pass
  • Easier to debug (feels like normal Python)
  • Perfect for variable-length sequences, conditional logic

JAX: JIT Compilation (Static Graphs)

  • Graph is traced once, then compiled
  • Same graph for all inputs (must be same shape)
  • Blazing fast execution (XLA compilation)
  • Better for production deployment

⚠️ Critical Difference: PyTorch rebuilds the graph every iteration. JAX traces it once. Choose based on your flexibility vs speed needs.
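A sketch of what define-by-run buys you: ordinary Python control flow can pick a different graph on every call. (Under JAX's jit, the same data-dependent `if` on a traced value would be a tracing error.)

```python
import torch

def f(x):
    # The branch taken decides which graph autograd records on this call.
    if x.sum() > 0:
        return (x ** 2).sum()
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0], requires_grad=True)
f(x).backward()
print(x.grad)  # tensor([2., 4.]) - the x**2 branch was taken, gradient 2x
```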

Common Pitfalls & Debugging Strategy

1. Detached Tensors (No Gradient Flow)

Problem: Operations that break the computational graph

python
x = torch.tensor([2.0], requires_grad=True)
y = x.detach() # Breaks the graph: y has no grad_fn
z = y ** 2
z.backward() # RuntimeError: z does not require grad; x.grad stays None

Fix: Avoid .detach(), .numpy(), or .item() in the middle of computation
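.detach() itself is not forbidden; it is fine for side uses such as logging or metrics, as long as the path from the loss back to the parameters stays attached. A minimal sketch:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x ** 2

logged = y.detach()   # safe: a detached copy used for logging only
loss = y.sum()        # the loss path still runs through the attached y
loss.backward()

print(logged.item())  # 4.0 - value recorded without touching the graph
print(x.grad.item())  # 4.0 - gradient still flows back to x
```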

2. In-Place Operations

Problem: Modifying tensors in-place can corrupt gradients

python
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
x += 1 # In-place modification
y.backward() # RuntimeError!

Fix: Use x = x + 1 instead of x += 1
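The out-of-place version leaves the original x intact, so the saved input that the backward of ** needs is still valid:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x ** 2        # backward of ** needs the original value of x
x_next = x + 1    # new tensor; x itself is untouched

y.backward()      # works: the saved input is still valid
print(x.grad.item())  # 4.0
print(x_next.item())  # 3.0
```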

3. Shape Mismatches

Problem: Broadcasting can hide shape errors until backprop

python
# Forward pass works (broadcasts automatically)
x = torch.randn(32, 10) # Batch of 32 samples
w = torch.randn(10, requires_grad=True) # Should be (10, 1)
y = x * w # Works but gives wrong shape (32, 10)
# Backward pass crashes or gives wrong gradients

Fix: Always verify tensor shapes with .shape before and after operations
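For the intended (10, 1) weight, a matrix multiply plus a couple of shape assertions catches this class of bug before backprop ever runs:

```python
import torch

x = torch.randn(32, 10)                     # batch of 32 samples
w = torch.randn(10, 1, requires_grad=True)  # column of weights

y = x @ w                           # (32, 10) @ (10, 1) -> (32, 1)
assert y.shape == (32, 1), y.shape  # fail fast on shape drift

y.sum().backward()
assert w.grad.shape == w.shape      # a gradient always matches its parameter
```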

4. NaN/Inf in Gradients

Most common causes:

  • Division by zero: 1 / (x - x)
  • Log of zero/negative: log(0) or log(-1)
  • Exploding gradients: Very large learning rates
  • Numerical overflow: exp(1000)

⚠️ Critical Debugging Tip: Use torch.autograd.set_detect_anomaly(True) to pinpoint the exact operation that produces NaN.
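The same switch is also available as a context manager, which keeps the (significant) slowdown scoped to the suspect region. In this sketch the forward value is finite, but the backward pass hits 0/0 and anomaly mode raises at the offending op:

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
caught = None
try:
    with torch.autograd.detect_anomaly():
        y = torch.exp(torch.log(x - 1))  # forward: exp(-inf) = 0, still finite
        y.backward()                     # backward of log computes 0/0 = nan
except RuntimeError as err:
    caught = err
    print("Anomaly detected:", err)      # the message names the NaN-producing op
```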

Gradient Checking: Verifying Your Autograd

Always verify autograd implementation with numerical gradients:

Finite Difference Approximation:

f'(x) ≈ [f(x + ε) - f(x - ε)] / (2ε)

Where ε is a small value (e.g., 1e-5)

If autograd gradient ≈ numerical gradient (within 1e-5), you're good!
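A sketch of the check in plain Python, for f(x) = x² where the hand-derived gradient is f'(x) = 2x:

```python
def f(x):
    return x ** 2

x, eps = 3.0, 1e-5
numerical = (f(x + eps) - f(x - eps)) / (2 * eps)  # centered finite difference
analytical = 2 * x                                 # hand-derived f'(x)

print(abs(numerical - analytical))  # well below 1e-5: the check passes
```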

When to Use Manual Gradients

Most of the time, use autograd. But manual gradients are needed for:

  • Custom CUDA kernels (low-level GPU operations)
  • Non-differentiable operations (argmax, sampling)
  • Memory-constrained scenarios (gradient checkpointing)
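For the non-differentiable case, one common workaround is a custom torch.autograd.Function with a surrogate backward. This straight-through sketch routes gradients through round() as if it were the identity:

```python
import torch

class RoundSTE(torch.autograd.Function):
    """round() has zero gradient almost everywhere; the straight-through
    estimator substitutes the identity in the backward pass."""
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # pretend the forward was the identity

x = torch.tensor([0.4, 1.6], requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)  # tensor([1., 1.]) - gradients pass straight through
```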

Mental Model: Autograd as a Recording System

Think of autograd as a video recorder:

  1. Press Record (requires_grad=True)
  2. Perform operations (forward pass) - everything is recorded
  3. Play backwards (.backward()) - gradients computed from recording
  4. Get the gradients (.grad) - extract what was computed

Summary

Key Takeaways:

  1. Autograd builds a computational graph during forward pass
  2. Backward pass applies chain rule automatically
  3. PyTorch = dynamic (flexible), JAX = static (fast)
  4. Common bugs: detached tensors, in-place ops, shape mismatches, NaN/Infs
  5. Use anomaly detection and gradient checking for debugging
  6. The chain rule is the mathematical foundation of everything

Remember: Master autograd debugging = Master deep learning implementation!

Lesson Content

Master automatic differentiation (autograd) and computational graphs in PyTorch and JAX. Learn practical debugging techniques for gradients, shape mismatches, and numerical instabilities (NaNs/Infs).

Code Example

python
# Autograd & Computational Graphs: Complete Practical Guide
# From basics to advanced debugging techniques
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple
print("="*90)
print("AUTOGRAD & COMPUTATIONAL GRAPHS: PRACTICAL DEBUGGING GUIDE")
print("="*90)
print()
# =============================================================================
# PART 1: Understanding Computational Graphs
# =============================================================================
print("1. COMPUTATIONAL GRAPH VISUALIZATION")
print("-" * 90)
# Simple example: f(x, y) = (x + y) * x
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
# Forward pass (graph is built automatically)
a = x + y # Intermediate node
f = a * x # Output node
print(f"Forward Pass:")
print(f" x = {x.item()}")
print(f" y = {y.item()}")
print(f" a = x + y = {a.item()}")
print(f" f = a * x = {f.item()}")
print()
# Backward pass (traverse graph backwards)
f.backward()
print(f"Backward Pass (Gradients):")
print(f" ∂f/∂x = {x.grad.item():.4f}")
print(f" ∂f/∂y = {y.grad.item():.4f}")
print()
# Verify with manual calculation
print("Manual Verification:")
print(f"  ∂f/∂x = ∂/∂x[(x+y)×x] = (x+y) + x = {(x.item() + y.item()) + x.item()}")
print(f"  ∂f/∂y = ∂/∂y[(x+y)×x] = x = {x.item()}")
print("  Autograd matches manual calculation!")
print()
# =============================================================================
# PART 2: Tracking Gradient Flow
# =============================================================================
print("2. GRADIENT FLOW TRACKING")
print("-" * 90)
def demonstrate_gradient_flow():
    """Show how gradients flow through operations"""
    # Create input with gradient tracking
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    # Complex computation
    y = x ** 2  # Element-wise square
    z = y.mean()  # Average
    print("Forward:")
    print(f"  x = {x.data.numpy()}")
    print(f"  y = x² = {y.data.numpy()}")
    print(f"  z = mean(y) = {z.item():.4f}")
    print()
    # Check what requires gradients
    print("Gradient Tracking Status:")
    print(f"  x.requires_grad = {x.requires_grad}")
    print(f"  y.requires_grad = {y.requires_grad}")
    print(f"  z.requires_grad = {z.requires_grad}")
    print()
    # Compute gradients
    z.backward()
    print("Gradients:")
    print(f"  dz/dx = {x.grad.numpy()}")
    print()
    # Manual verification
    # z = mean(x²) = (x₁² + x₂² + x₃²) / 3
    # dz/dxᵢ = 2xᵢ/3
    manual_grad = 2 * x.data.numpy() / 3
    print(f"Manual Calculation: dz/dx = 2x/3 = {manual_grad}")
    print(f"  Match: {np.allclose(x.grad.numpy(), manual_grad)}")
    print()
demonstrate_gradient_flow()
# =============================================================================
# PART 3: Common Bug #1 - Detached Tensors
# =============================================================================
print("3. DEBUGGING: DETACHED TENSORS (Broken Gradient Flow)")
print("-" * 90)
def bug_detached_tensor():
    """Demonstrate the detached tensor bug"""
    print("❌ WRONG: Using .detach() breaks gradient flow")
    x = torch.tensor([2.0], requires_grad=True)
    y = x.detach()  # Breaks the chain!
    z = y ** 2
    try:
        z.backward()  # z has no grad_fn, so backward() raises
    except RuntimeError as e:
        print(f"  RuntimeError: {str(e)[:80]}...")
    print(f"  x.grad = {x.grad}  # None - gradient didn't flow!")
    print()
    print("✓ CORRECT: Keep tensor attached")
    x = torch.tensor([2.0], requires_grad=True)
    y = x  # No detach
    z = y ** 2
    z.backward()
    print(f"  x.grad = {x.grad.item()}  # 4.0 - gradient flows correctly!")
    print()
bug_detached_tensor()
# =============================================================================
# PART 4: Common Bug #2 - In-Place Operations
# =============================================================================
print("4. DEBUGGING: IN-PLACE OPERATIONS")
print("-" * 90)
def bug_inplace_operation():
    """Demonstrate in-place operation bug"""
    print("❌ WRONG: In-place modification corrupts gradients")
    try:
        x = torch.tensor([2.0], requires_grad=True)
        y = x ** 2
        x += 1  # In-place modification
        y.backward()
    except RuntimeError as e:
        print(f"  RuntimeError: {str(e)[:80]}...")
    print()
    print("✓ CORRECT: Create a new tensor instead")
    x = torch.tensor([2.0], requires_grad=True)
    y = x ** 2
    x_new = x + 1  # Creates a new tensor; x is untouched
    y.backward()
    print(f"  x.grad = {x.grad.item()}  # Works!")
    print()
bug_inplace_operation()
# =============================================================================
# PART 5: Common Bug #3 - Shape Mismatches
# =============================================================================
print("5. DEBUGGING: SHAPE MISMATCHES")
print("-" * 90)
def debug_shape_mismatch():
    """Demonstrate shape debugging techniques"""
    print("❌ WRONG: Shape mismatch hidden by broadcasting")
    batch_size = 32
    features = 10
    x = torch.randn(batch_size, features, requires_grad=True)
    w = torch.randn(features, requires_grad=True)  # Wrong shape!
    # Forward pass works due to broadcasting
    y = x * w  # Shape becomes (32, 10) - probably not intended!
    loss = y.sum()
    print(f"  x.shape = {x.shape}")
    print(f"  w.shape = {w.shape}")
    print(f"  y.shape = {y.shape}  # Broadcasting happened!")
    print()
    loss.backward()
    print(f"  w.grad.shape = {w.grad.shape}  # Gradient shape might be wrong")
    print()
    print("✓ CORRECT: Explicit shapes plus a matrix multiply")
    x = torch.randn(batch_size, features, requires_grad=True)
    w = torch.randn(features, 1, requires_grad=True)  # Correct shape
    # Add an assertion for shape validation
    assert w.shape == (features, 1), f"Expected ({features}, 1), got {w.shape}"
    y = x @ w  # (32, 10) @ (10, 1) -> (32, 1); note x * w would raise here
    loss = y.sum()
    print(f"  x.shape = {x.shape}")
    print(f"  w.shape = {w.shape}")
    print(f"  y.shape = {y.shape}  # The intended (32, 1)")
    print()
    loss.backward()
    print(f"  w.grad.shape = {w.grad.shape}  # Gradient shape matches parameter")
    print()
debug_shape_mismatch()
# =============================================================================
# PART 6: Common Bug #4 - NaN/Inf Detection
# =============================================================================
print("6. DEBUGGING: NaN/Inf DETECTION & PREVENTION")
print("-" * 90)
def detect_nan_gradients():
    """Demonstrate NaN detection techniques"""
    print("❌ Operations that cause NaN/Inf:")
    # Example 1: Division by zero
    x = torch.tensor([1.0], requires_grad=True)
    y = 1.0 / (x - x)  # Division by zero!
    print(f"  1/(x-x) = {y.item()}")  # inf
    # Example 2: Log of zero
    x = torch.tensor([0.0], requires_grad=True)
    y = torch.log(x)  # log(0) = -inf
    print(f"  log(0) = {y.item()}")
    # Example 3: Square root of a negative
    x = torch.tensor([-1.0], requires_grad=True)
    y = torch.sqrt(x)  # sqrt(-1) = nan
    print(f"  sqrt(-1) = {y.item()}")
    print()
    print("✓ CORRECT: Using anomaly detection")
    # Enable anomaly detection (slower, but pinpoints the failing op)
    torch.autograd.set_detect_anomaly(True)
    try:
        x = torch.tensor([1.0], requires_grad=True)
        y = torch.exp(torch.log(x - 1))  # forward: exp(-inf) = 0, still finite
        y.backward()  # backward of log computes 0/0 = nan -> anomaly raised
    except RuntimeError as e:
        print(f"  Anomaly detected: {str(e)[:80]}...")
    torch.autograd.set_detect_anomaly(False)
    print()
    print("✓ PREVENTION: Numerical stability techniques")
    # Technique 1: Add epsilon for numerical stability
    x = torch.tensor([0.0], requires_grad=True)
    eps = 1e-8
    y = torch.log(x + eps)  # Prevents log(0)
    print(f"  log(x + ε) = {y.item():.4f}  # Stable!")
    # Technique 2: Clamp values to a safe range
    x = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
    x_safe = torch.clamp(x, min=eps)  # Ensure all values >= ε
    y = torch.log(x_safe)
    print(f"  log(clamp(x)) = {y.data.numpy()}  # All finite!")
    print()
detect_nan_gradients()
# =============================================================================
# PART 7: Gradient Checking (Numerical Verification)
# =============================================================================
print("7. GRADIENT CHECKING: Verifying Autograd Correctness")
print("-" * 90)
def gradient_check(func, x: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Verify autograd gradients against numerical gradients
    Args:
        func: Scalar-valued function to check
        x: Input tensor
        eps: Small perturbation for finite differences
    Returns:
        (analytical_grad, numerical_grad)
    """
    # Analytical gradient (autograd)
    x.requires_grad = True
    y = func(x)
    y.backward()
    analytical_grad = x.grad.clone()
    x.grad.zero_()
    # Numerical gradient (finite differences)
    numerical_grad = torch.zeros_like(x)
    for i in range(x.numel()):
        # f(x + eps)
        x_plus = x.clone().detach()
        x_plus.view(-1)[i] += eps
        f_plus = func(x_plus).item()
        # f(x - eps)
        x_minus = x.clone().detach()
        x_minus.view(-1)[i] -= eps
        f_minus = func(x_minus).item()
        # Centered finite difference
        numerical_grad.view(-1)[i] = (f_plus - f_minus) / (2 * eps)
    return analytical_grad, numerical_grad
# Test on a more complex function
def test_function(x):
    """f(x) = sum(x³ - 2x² + 5x)"""
    return (x**3 - 2*x**2 + 5*x).sum()
x_test = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
analytical, numerical = gradient_check(test_function, x_test)
print("Function: f(x) = x³ - 2x² + 5x")
print(f"At x = {x_test.data.numpy()}")
print()
print(f"Analytical gradient (autograd): {analytical.numpy()}")
print(f"Numerical gradient (finite diff): {numerical.numpy()}")
print(f"Difference: {(analytical - numerical).abs().numpy()}")
print(f"Relative error: {((analytical - numerical).abs() / (analytical.abs() + 1e-8)).numpy()}")
print()
if torch.allclose(analytical, numerical, rtol=1e-4, atol=1e-5):
    print("✓ GRADIENT CHECK PASSED! Autograd is correct.")
else:
    print("❌ GRADIENT CHECK FAILED! Check your implementation.")
print()
# =============================================================================
# PART 8: Advanced - Custom Gradient Function
# =============================================================================
print("8. ADVANCED: Custom Autograd Functions")
print("-" * 90)
class MyReLU(torch.autograd.Function):
    """
    Custom ReLU implementation with a manual backward pass.
    Demonstrates how to write custom differentiable operations.
    """
    @staticmethod
    def forward(ctx, input):
        """Forward pass: ReLU(x) = max(0, x)"""
        ctx.save_for_backward(input)
        return input.clamp(min=0)
    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass: derivative is 1 if x > 0, else 0"""
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input <= 0] = 0  # zero at x = 0 too, matching the docstring
        return grad_input
# Test custom ReLU
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
y = MyReLU.apply(x)
loss = y.sum()
loss.backward()
print(f"Input: {x.data.numpy()}")
print(f"ReLU(x): {y.data.numpy()}")
print(f"Gradient: {x.grad.numpy()}")
print(f" Gradient is 1 where x > 0, else 0")
print()
# =============================================================================
# PART 9: Best Practices Summary
# =============================================================================
print("9. DEBUGGING CHECKLIST")
print("-" * 90)
print("""
Before training any neural network:
 1. Print tensor shapes at each layer
 2. Verify gradients with gradient checking (at least once)
 3. Use torch.autograd.set_detect_anomaly(True) during debugging
 4. Check for NaN/Inf after each training step:
    - if torch.isnan(loss): raise ValueError("NaN loss!")
 5. Avoid .detach(), .numpy(), .item() inside the computation graph
 6. Use x = x + 1, not x += 1 (avoid in-place ops)
 7. Add numerical stability (eps=1e-8) to log/div operations
 8. Initialize weights properly (Xavier/He initialization)
 9. Start with a small learning rate (1e-4) and increase gradually
10. Monitor gradient norms: torch.nn.utils.clip_grad_norm_()

Common NaN sources:
 - log(0), log(negative)
 - 1/0, x/0
 - sqrt(negative)
 - exp(very_large_number) -> inf
 - Learning rate too high -> exploding gradients

Quick NaN debug:
 1. torch.autograd.set_detect_anomaly(True)
 2. Add print(f"Loss: {loss.item()}") after each forward pass
 3. Check: torch.isnan(model.weight.grad).any()
 4. Reduce the learning rate by 10x
 5. Add gradient clipping
""")
print("="*90)
print("COMPLETE! You now understand autograd and can debug gradient issues.")
print("="*90)