Autodiff

Autodiff Helper
Author

Nipun Batra

Published

April 4, 2023

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Pin PyTorch's global random number generator so that random tensors drawn
# later (e.g. the torch.randn dummy data) come out the same on every run.
torch.random.manual_seed(0)

# Show the PyTorch release this notebook was executed with.
torch.__version__
'2.0.1'
### Derivatives using numerical differentiation

def f(x):
    """Quadratic test function f(x) = 3x^2 + 2x + 1, so f'(x) = 6x + 2."""
    quadratic_term = 3 * x ** 2
    linear_term = 2 * x
    return quadratic_term + linear_term + 1

def numerical_derivative_single_side(f, x, h=0.001):
    """Forward-difference estimate of f'(x).

    Args:
        f: callable of one argument.
        x: point at which to differentiate.
        h: step size (default 0.001).

    Returns:
        (f(x + h) - f(x)) / h, whose truncation error is O(h).
    """
    rise = f(x + h) - f(x)
    return rise / h

def numerical_derivative_double_side(f, x, h=0.001):
    """Central-difference estimate of f'(x).

    Args:
        f: callable of one argument.
        x: point at which to differentiate.
        h: half-width of the symmetric step (default 0.001).

    Returns:
        (f(x + h) - f(x - h)) / (2h), whose truncation error is O(h^2);
        it is exact (up to rounding) for quadratics.
    """
    forward_value = f(x + h)
    backward_value = f(x - h)
    return (forward_value - backward_value) / (2 * h)

# Evaluate at x = 2, where the exact derivative is f'(2) = 6*2 + 2 = 14.
# Use float64: with torch's default float32, f(x + h) and f(x) differ only in
# their last few significant digits when h = 1e-5, so rounding error swamps
# the estimate. In float32 both lines below printed 14.1143798828125, which
# hid the difference between the two formulas.
x = torch.tensor(2.0, dtype=torch.float64, requires_grad=False)
print(f'f\'(2) = {numerical_derivative_single_side(f, x, 0.00001)}')
print(f'f\'(2) = {numerical_derivative_double_side(f, x, 0.00001)}')
# Expected output in float64 (trailing digits are rounding noise):
#   one-sided: about 14.00003  (truncation error 3*h for this quadratic)
#   central:   about 14.0      (exact for a quadratic, up to rounding)
def f(theta_0, theta_1, theta_2, theta_3, theta_4):
    """Linear function of five parameters with weights 1, 2, 3, 4, 5.

    Its partial derivative with respect to theta_i is the constant i + 1.
    """
    # Accumulate left to right with out-of-place `+` (never `+=`), so that
    # tensor arguments are not modified in place.
    total = theta_0
    total = total + 2 * theta_1
    total = total + 3 * theta_2
    total = total + 4 * theta_3
    total = total + 5 * theta_4
    return total



# Five inputs for the 5-parameter f. float64 keeps rounding error far below
# the quantity being estimated: in float32 this cell displayed tensor(0.9918),
# although f is linear and df/dtheta_0 is exactly 1.
theta_0 = torch.tensor(1.0, dtype=torch.float64, requires_grad=False)
theta_1 = torch.tensor(2.0, dtype=torch.float64, requires_grad=False)
theta_2 = torch.tensor(3.0, dtype=torch.float64, requires_grad=False)
theta_3 = torch.tensor(4.0, dtype=torch.float64, requires_grad=False)
theta_4 = torch.tensor(5.0, dtype=torch.float64, requires_grad=False)


# Partial derivative w.r.t. theta_0: perturb theta_0 only and hold the rest
# fixed. Every further partial needs its own extra evaluation of f.
df_dtheta_0 = numerical_derivative_single_side(
    lambda theta_0: f(theta_0, theta_1, theta_2, theta_3, theta_4), theta_0, 0.0001
)
df_dtheta_0
# Expected display: tensor(1.0000, dtype=torch.float64)
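## The above method is expensive and impractical for a large number of parameters: each partial derivative needs its own extra evaluation of f, so n parameters cost n + 1 evaluations with one-sided differences.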
# Forward pass of a one-example logistic-regression loss, built one primitive
# op at a time so every intermediate node (and later its gradient) can be
# inspected:  L = -log(1 / (1 + exp(-(theta_0 + theta_1*x1 + theta_2*x2)))).
theta_0, theta_1, theta_2 = (
    torch.tensor(1.0, requires_grad=True),
    torch.tensor(1.0, requires_grad=True),
    torch.tensor(2.0, requires_grad=True),
)

# Inputs are plain constants: autograd does not track them.
x1, x2 = torch.tensor(1.0), torch.tensor(2.0)

f1 = theta_1*x1        # weighted input 1
f2 = theta_2*x2        # weighted input 2
f3 = f1 + f2           # weighted sum
f4 = f3 + theta_0      # add bias -> logit z
f5 = f4*-1             # -z
f6 = torch.exp(f5)     # exp(-z)
f7 = 1 + f6            # 1 + exp(-z)
f8 = 1/f7              # sigmoid(z)
f9 = torch.log(f8)     # log sigmoid(z)
L = f9*-1              # negative log-likelihood

all_nodes = dict(
    theta_0=theta_0, theta_1=theta_1, theta_2=theta_2,
    f1=f1, f2=f2, f3=f3, f4=f4, f5=f5, f6=f6, f7=f7, f8=f8, f9=f9, L=L,
)

# Ask autograd to keep .grad on every node, including the intermediate
# (non-leaf) ones, so their gradients can be printed after backward().
for node in all_nodes.values():
    node.retain_grad()

# Show the forward value of every node, in graph order.
for name, node in all_nodes.items():
    print(f"{name}: {node.item()}")
theta_0: 1.0
theta_1: 1.0
theta_2: 2.0
f1: 1.0
f2: 4.0
f3: 5.0
f4: 6.0
f5: -6.0
f6: 0.0024787522852420807
f7: 1.0024787187576294
f8: 0.9975274205207825
f9: -0.0024756414350122213
L: 0.0024756414350122213
# Backpropagate from the scalar loss L; every node had retain_grad() called.
L.backward()

# dL/d(node) for each node, listed in the same order as the forward values.
for name, node in all_nodes.items():
    grad_value = node.grad.item()
    print(f"{name}: {grad_value}")
theta_0: -0.00247262348420918
theta_1: -0.00247262348420918
theta_2: -0.00494524696841836
f1: -0.00247262348420918
f2: -0.00247262348420918
f3: -0.00247262348420918
f4: -0.00247262348420918
f5: 0.00247262348420918
f6: 0.9975274801254272
f7: 0.9975274801254272
f8: -1.0024787187576294
f9: -1.0
L: 1.0
# Chain-rule check for dL/df7: local derivative d(1/f7)/df7 = -1/f7**2, times
# the upstream gradient dL/df8 that autograd stored in f8.grad (used instead of
# a hand-typed rounded constant). This should equal the f7 gradient printed above.
(-1/(f7**2))*f8.grad
tensor(0.9975, grad_fn=<MulBackward0>)
# One step further back: d(exp(f5))/df5 = exp(f5), times the upstream gradient
# dL/df6 stored in f6.grad (used instead of a hard-coded 0.9975). This should
# equal the f5 gradient printed above.
torch.exp(f5)*f6.grad
tensor(0.0025, grad_fn=<MulBackward0>)
### Micrograd demo: https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
### Example to illustrate accumulation of gradients

# theta feeds two branches, so autograd sums the gradient flowing back along
# each path: dL/dtheta = x1 + x2 = 1 + 2 = 3.
theta = torch.tensor(1.0, requires_grad=True)

x1, x2 = torch.tensor(1.0), torch.tensor(2.0)

L1, L2 = theta*x1, theta*x2   # two uses of the same leaf tensor
L = L1 + L2
L.backward()
theta.grad
tensor(3.)

Why do we need to use torch.no_grad() in the test phase?

Why do we need to zero out the gradients in the training phase after each update?

class SimpleModel(nn.Module):
    """Single linear layer mapping 10 input features to 1 output.

    NOTE(review): SimpleModel is not defined anywhere in this notebook as
    exported, so `SimpleModel()` raised NameError. This stand-in matches the
    shapes of the dummy data below ((100, 10) inputs, (100, 1) targets); swap
    in the intended architecture if a different model was meant.
    """

    def __init__(self, in_features=10, out_features=1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Map (N, in_features) inputs to (N, out_features) outputs."""
        return self.linear(x)


model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Dummy regression data: 100 samples with 10 features and 1 target each
inputs = torch.randn((100, 10))
targets = torch.randn((100, 1))

# Full-batch gradient descent: the whole dataset is used in every epoch
for epoch in range(100):
    # Forward pass
    outputs = model(inputs)

    # Compute the loss
    loss = criterion(outputs, targets)

    # Zero the gradients: backward() adds into each parameter's .grad instead
    # of overwriting it, so without this every update would use the running
    # sum of all previous epochs' gradients.
    optimizer.zero_grad()

    # Backward pass
    loss.backward()

    # Update the weights
    optimizer.step()

    # Print the loss every 10 epochs (the loss computed before this epoch's step)
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')