SIRENs vs. Linear Regression

ML
Author

Nipun Batra

Published

January 21, 2024

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

# Remove all the warnings
# NOTE(review): this silences every Python warning for the whole session,
# including deprecation and numerical warnings; consider narrowing the filter.
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1
# CUDA debugging switch: kernel launches become synchronous so errors point at
# the call that caused them. Useful for debugging, but it slows GPU execution.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Use CUDA when available, otherwise fall back to the CPU. Every later cell
# moves its tensors to this `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Retina display
# IPython magic: render inline matplotlib figures at high resolution.
%config InlineBackend.figure_format = 'retina'

# einops is an optional dependency: install it on the fly if it is missing.
try:
    from einops import rearrange
except ImportError:
    %pip install einops
    from einops import rearrange
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 1
----> 1 import torch
      2 import torchvision
      3 import torchvision.transforms as transforms

ModuleNotFoundError: No module named 'torch'
# Download the sample image only if it is not already present locally.
if os.path.exists('dog.jpg'):
    print('dog.jpg exists')
else:
    # `!` runs a shell command from the notebook (requires wget on PATH).
    !wget https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg -O dog.jpg
dog.jpg exists
# Read in an image with torchvision.
# The resulting tensor is channels-first, (C, H, W), as the printed shape shows.
img = torchvision.io.read_image("dog.jpg")
print(img.shape)
torch.Size([3, 1365, 2048])
# imshow expects channels-last (H, W, C), so move the channel axis to the end.
plt.imshow(rearrange(img, 'c h w -> h w c').numpy())

from sklearn import preprocessing

# Fit a single min/max over all pixel values: every channel is flattened into
# one column, so all channels share the same scaling.
scaler_img = preprocessing.MinMaxScaler().fit(img.reshape(-1, 1))
scaler_img
MinMaxScaler()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
# Apply the fitted scaler and restore the original (C, H, W) shape.
img_scaled = scaler_img.transform(img.reshape(-1, 1)).reshape(img.shape)
img_scaled.shape

# Convert to a tensor on the compute device. The scaler returns float64, so this
# tensor is float64; later code casts to float32 where needed.
img_scaled = torch.tensor(img_scaled)
img_scaled = img_scaled.to(device)
img_scaled
tensor([[[0.3098, 0.3137, 0.3137,  ..., 0.2941, 0.2941, 0.2980],
         [0.3098, 0.3137, 0.3137,  ..., 0.2941, 0.2941, 0.2980],
         [0.3098, 0.3137, 0.3137,  ..., 0.2941, 0.2941, 0.2980],
         ...,
         [0.4745, 0.4745, 0.4784,  ..., 0.3804, 0.3765, 0.3765],
         [0.4745, 0.4745, 0.4784,  ..., 0.3804, 0.3804, 0.3765],
         [0.4745, 0.4745, 0.4784,  ..., 0.3843, 0.3804, 0.3804]],

        [[0.2039, 0.2078, 0.2078,  ..., 0.2157, 0.2157, 0.2118],
         [0.2039, 0.2078, 0.2078,  ..., 0.2157, 0.2157, 0.2118],
         [0.2039, 0.2078, 0.2078,  ..., 0.2157, 0.2157, 0.2118],
         ...,
         [0.4039, 0.4039, 0.4078,  ..., 0.3216, 0.3176, 0.3176],
         [0.4039, 0.4039, 0.4078,  ..., 0.3216, 0.3216, 0.3176],
         [0.4039, 0.4039, 0.4078,  ..., 0.3255, 0.3216, 0.3216]],

        [[0.1373, 0.1412, 0.1412,  ..., 0.1176, 0.1176, 0.1176],
         [0.1373, 0.1412, 0.1412,  ..., 0.1176, 0.1176, 0.1176],
         [0.1373, 0.1412, 0.1412,  ..., 0.1176, 0.1176, 0.1176],
         ...,
         [0.1451, 0.1451, 0.1490,  ..., 0.1686, 0.1647, 0.1647],
         [0.1451, 0.1451, 0.1490,  ..., 0.1686, 0.1686, 0.1647],
         [0.1451, 0.1451, 0.1490,  ..., 0.1725, 0.1686, 0.1686]]],
       device='cuda:0', dtype=torch.float64)
# Crop a 300x300 patch: arguments are (top, left, height, width).
crop = torchvision.transforms.functional.crop(img_scaled.cpu(), 600, 800, 300, 300)
crop.shape
torch.Size([3, 300, 300])
# Show the crop (channels-last for matplotlib).
plt.imshow(rearrange(crop, 'c h w -> h w c').cpu().numpy())

crop = crop.to(device)
# Get the dimensions of the image tensor
num_channels, height, width = crop.shape
print(num_channels, height, width)
3 300 300
# Toy dimensions to illustrate the coordinate grid on a tiny example.
# NOTE: this overwrites the crop's dimensions computed in the previous cell.
num_channels, height, width = 2, 3, 4


# Create a 2D grid of (row, col) coordinates. w_coords varies fastest, so the
# flattened order is row-major, matching a '(h w)' flatten of the image.
w_coords = torch.arange(width).repeat(height, 1)
h_coords = torch.arange(height).repeat(width, 1).t()
w_coords = w_coords.reshape(-1)
h_coords = h_coords.reshape(-1)

# Combine the row and column coordinates into a single (height * width, 2) tensor
X = torch.stack([h_coords, w_coords], dim=1).float()
X.shape
torch.Size([12, 2])
def create_coordinate_map(img):
    """
    Pair every pixel's (row, col) position with its channel values.

    img: torch.Tensor of shape (num_channels, height, width)

    return: tuple (X, Y)
        X: float tensor of shape (height * width, 2) holding (row, col) pairs
           in row-major order (col varies fastest), moved to `device`.
        Y: float tensor of shape (height * width, num_channels) holding the
           pixel values in the same row-major order.
    """
    _, rows, cols = img.shape

    # 'ij' indexing gives grid_r[i, j] == i and grid_c[i, j] == j, so flattening
    # both grids yields the same row-major order as the '(h w)' flatten of Y.
    grid_r, grid_c = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing="ij")
    coords = torch.stack((grid_r.flatten(), grid_c.flatten()), dim=-1)
    X = coords.float().to(device)

    Y = rearrange(img, 'c h w -> (h w) c').float()
    return X, Y
# Build (coordinate, RGB) training pairs for every pixel of the crop.
dog_X, dog_Y = create_coordinate_map(crop)

dog_X.shape, dog_Y.shape
(torch.Size([90000, 2]), torch.Size([90000, 3]))
# MinMaxScaler from -1 to 1
# Each coordinate column (row, col) is rescaled independently to [-1, 1].
scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(dog_X.cpu())

# Scale the X coordinates
dog_X_scaled = scaler_X.transform(dog_X.cpu())

# Move the scaled X coordinates to `device` (the GPU if one is available, else the CPU)
dog_X_scaled = torch.tensor(dog_X_scaled).to(device)

# Set to dtype float32
dog_X_scaled = dog_X_scaled.float()
class LinearModel(nn.Module):
    """A single affine layer mapping in_features inputs to out_features outputs.

    in_features: number of input features per sample.
    out_features: number of outputs per sample (3 for the RGB targets here).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Stored under the attribute name `linear` so the state_dict keys are
        # `linear.weight` and `linear.bias`.
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the affine map to a batch whose last dimension is in_features."""
        layer = self.linear
        return layer(x)
    
# Baseline: a plain linear map from the 2 scaled (row, col) coordinates to 3 outputs.
net = LinearModel(2, 3)
net.to(device)
LinearModel(
  (linear): Linear(in_features=2, out_features=3, bias=True)
)
def train(net, lr, X, Y, epochs, verbose=True):
    """
    Fit `net` to (X, Y) with full-batch Adam on mean-squared error.

    net: torch.nn.Module mapping X to predictions shaped like Y
    lr: float, Adam learning rate
    X: torch.Tensor of shape (num_samples, num_features). Raw coordinates,
       polynomial features and random Fourier features are all used here.
    Y: torch.Tensor of shape (num_samples, num_outputs), e.g. 3 for RGB
    epochs: int, number of full-batch optimisation steps (must be >= 1)
    verbose: if True, print the loss every 100 epochs

    return: float, the loss of the final epoch (computed before that epoch's
        optimizer step)

    raises: ValueError if epochs < 1, since there would be no loss to return.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # A previous evaluation (e.g. a plotting helper calling net.eval()) may have
    # left the module in eval mode; switch back before optimising.
    net.train()

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = net(X)
        loss = criterion(outputs, Y)
        loss.backward()
        optimizer.step()
        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item():.6f}")
    return loss.item()
# Full-batch Adam for 1000 epochs at lr=0.01 on the scaled coordinates.
train(net, 0.01, dog_X_scaled, dog_Y, 1000)
Epoch 0 loss: 0.504670
Epoch 100 loss: 0.046783
Epoch 200 loss: 0.036839
Epoch 300 loss: 0.036823
Epoch 400 loss: 0.036823
Epoch 500 loss: 0.036823
Epoch 600 loss: 0.036823
Epoch 700 loss: 0.036823
Epoch 800 loss: 0.036823
Epoch 900 loss: 0.036823
0.03682254999876022
def plot_reconstructed_and_original_image(original_img, net, X, title=""):
    """
    Show the network's reconstruction next to the original image.

    original_img: torch.Tensor of shape (num_channels, height, width), expected
        to hold float values in [0, 1] for display
    net: torch.nn.Module mapping X to (height * width, num_channels) outputs
    X: torch.Tensor of shape (height * width, num_features), the same features
        the network was trained on, in row-major pixel order
    title: figure super-title
    """
    num_channels, height, width = original_img.shape

    # Evaluate without tracking gradients, then restore the module's previous
    # train/eval mode so later training is not silently run in eval mode.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            outputs = net(X)
    finally:
        net.train(was_training)

    # Rows of `outputs` are in (h w) order, so they reshape directly to
    # channels-last (H, W, C) as imshow expects. Clamp to the displayable range
    # explicitly instead of relying on imshow, which clips float RGB data and
    # logs a "Clipping input data" message.
    reconstructed = outputs.reshape(height, width, num_channels).clamp(0, 1)

    fig = plt.figure(figsize=(6, 4))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])

    ax0.imshow(reconstructed.cpu().numpy())
    ax0.set_title("Reconstructed Image")

    ax1.imshow(original_img.cpu().permute(1, 2, 0).numpy())
    ax1.set_title("Original Image")

    for ax in (ax0, ax1):
        ax.axis("off")

    fig.suptitle(title, y=0.9)
    plt.tight_layout()
# Compare the linear-on-coordinates reconstruction with the original crop.
plot_reconstructed_and_original_image(crop, net, dog_X_scaled, title="Reconstructed Image")

# Use polynomial features of degree "d"

def poly_features(X, degree):
    """
    Expand the first two columns of X into all monomials up to `degree`.

    X: torch.Tensor of shape (num_samples, >= 2); only columns 0 and 1 are used
    degree: int, maximum total degree of the monomials

    return: float32 torch.Tensor on `device` of shape
        (num_samples, (degree + 1) * (degree + 2) // 2). The count includes the
        constant (bias) column, so degree=50 gives 1326 columns. The column
        order follows sklearn's PolynomialFeatures: 1, x1, x2, x1^2, x1*x2, x2^2, ...

    raises: ValueError if X has fewer than two columns.
    """
    if X.ndim != 2 or X.shape[1] < 2:
        raise ValueError(f"X must have shape (num_samples, >= 2), got {tuple(X.shape)}")

    # Keep only the two coordinate columns. Detach and move them to host
    # memory, because sklearn operates on NumPy arrays.
    coords = X[:, :2].detach().cpu().numpy()
    poly = preprocessing.PolynomialFeatures(degree=degree)
    expanded = poly.fit_transform(coords)
    return torch.tensor(expanded, dtype=torch.float32).to(device)
# Expand the 2 scaled coordinates into all monomials up to degree 50:
# 51 * 52 / 2 = 1326 columns, including the constant term.
dog_X_scaled_poly = poly_features(dog_X_scaled, 50)
dog_X_scaled_poly.dtype, dog_X_scaled_poly.shape, dog_Y.shape, dog_Y.dtype
(torch.float32,
 torch.Size([90000, 1326]),
 torch.Size([90000, 3]),
 torch.float32)
# Linear regression on the polynomial features (one weight per monomial per channel).
net = LinearModel(dog_X_scaled_poly.shape[1], 3)
net.to(device)

train(net, 0.005, dog_X_scaled_poly, dog_Y, 1500)
Epoch 0 loss: 0.353235
Epoch 100 loss: 0.028444
Epoch 200 loss: 0.025136
Epoch 300 loss: 0.024183
Epoch 400 loss: 0.023526
Epoch 500 loss: 0.023012
Epoch 600 loss: 0.022591
Epoch 700 loss: 0.022229
Epoch 800 loss: 0.021912
Epoch 900 loss: 0.021658
Epoch 1000 loss: 0.021389
Epoch 1100 loss: 0.021166
Epoch 1200 loss: 0.020970
Epoch 1300 loss: 0.020785
Epoch 1400 loss: 0.020631
0.020467281341552734
# Reconstruction from polynomial features; predictions outside [0, 1] may be clipped for display.
plot_reconstructed_and_original_image(crop, net, dog_X_scaled_poly, title="Reconstructed Image with Polynomial Features")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# create RFF features
def create_rff_features(X, num_features, sigma, random_state=None):
    """
    Map X to random Fourier features that approximate an RBF kernel.

    X: torch.Tensor of shape (num_samples, d)
    num_features: int, number of random features (output columns)
    sigma: float > 0, kernel lengthscale. It is converted to sklearn's
        gamma = 1 / (2 * sigma**2), i.e. the kernel exp(-||x - y||^2 / (2 sigma^2)).
    random_state: optional seed (or RandomState) forwarded to RBFSampler so the
        random projection is reproducible. The default, None, draws a fresh
        projection on every call.

    return: float32 torch.Tensor of shape (num_samples, num_features) on `device`

    raises: ValueError if sigma is not positive.

    NOTE: the output is dense; num_samples * num_features * 4 bytes must fit in
    memory, and sklearn's intermediate array may need more.
    """
    from sklearn.kernel_approximation import RBFSampler

    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    rff = RBFSampler(
        n_components=num_features,
        gamma=1 / (2 * sigma**2),
        random_state=random_state,
    )
    # sklearn works on NumPy arrays; detach first so tensors that require
    # grad can also be converted.
    X_np = X.detach().cpu().numpy()
    features = rff.fit_transform(X_np)
    return torch.tensor(features, dtype=torch.float32).to(device)
# Random Fourier features approximating an RBF kernel with lengthscale sigma=0.008.
# NOTE(review): 90000 x 37500 float32 is about 13.5 GB for the final tensor alone;
# make sure the host and the device have room. The sampler is unseeded here, so
# the features (and results) differ between runs.
X_rff = create_rff_features(dog_X_scaled, 37500, 0.008)
X_rff.shape
torch.Size([90000, 37500])
# Linear regression on the random Fourier features.
net = LinearModel(X_rff.shape[1], 3)
net.to(device)

train(net, 0.005, X_rff, dog_Y, 2500)
Epoch 0 loss: 0.375324
Epoch 100 loss: 0.047630
Epoch 200 loss: 0.009158
Epoch 300 loss: 0.003941
Epoch 400 loss: 0.002057
Epoch 500 loss: 0.001118
Epoch 600 loss: 0.000640
Epoch 700 loss: 0.000404
Epoch 800 loss: 0.000292
Epoch 900 loss: 0.000241
Epoch 1000 loss: 0.000218
Epoch 1100 loss: 0.000208
Epoch 1200 loss: 0.000204
Epoch 1300 loss: 0.000201
Epoch 1400 loss: 0.000200
Epoch 1500 loss: 0.000199
Epoch 1600 loss: 0.000198
Epoch 1700 loss: 0.000197
Epoch 1800 loss: 0.000196
Epoch 1900 loss: 0.000195
Epoch 2000 loss: 0.000195
Epoch 2100 loss: 0.000194
Epoch 2200 loss: 0.000194
Epoch 2300 loss: 0.000193
Epoch 2400 loss: 0.000193
0.00019248582248110324
# Reconstruction from random Fourier features.
plot_reconstructed_and_original_image(crop, net, X_rff, title="Reconstructed Image with RFF Features")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

# Demo: stepping by 1/scale gives `scale` samples per unit, here 0, 0.5, ..., 99.5.
w = 100
scale=2
torch.arange(0, w, 1/scale)
tensor([ 0.0000,  0.5000,  1.0000,  1.5000,  2.0000,  2.5000,  3.0000,  3.5000,
         4.0000,  4.5000,  5.0000,  5.5000,  6.0000,  6.5000,  7.0000,  7.5000,
         8.0000,  8.5000,  9.0000,  9.5000, 10.0000, 10.5000, 11.0000, 11.5000,
        12.0000, 12.5000, 13.0000, 13.5000, 14.0000, 14.5000, 15.0000, 15.5000,
        16.0000, 16.5000, 17.0000, 17.5000, 18.0000, 18.5000, 19.0000, 19.5000,
        20.0000, 20.5000, 21.0000, 21.5000, 22.0000, 22.5000, 23.0000, 23.5000,
        24.0000, 24.5000, 25.0000, 25.5000, 26.0000, 26.5000, 27.0000, 27.5000,
        28.0000, 28.5000, 29.0000, 29.5000, 30.0000, 30.5000, 31.0000, 31.5000,
        32.0000, 32.5000, 33.0000, 33.5000, 34.0000, 34.5000, 35.0000, 35.5000,
        36.0000, 36.5000, 37.0000, 37.5000, 38.0000, 38.5000, 39.0000, 39.5000,
        40.0000, 40.5000, 41.0000, 41.5000, 42.0000, 42.5000, 43.0000, 43.5000,
        44.0000, 44.5000, 45.0000, 45.5000, 46.0000, 46.5000, 47.0000, 47.5000,
        48.0000, 48.5000, 49.0000, 49.5000, 50.0000, 50.5000, 51.0000, 51.5000,
        52.0000, 52.5000, 53.0000, 53.5000, 54.0000, 54.5000, 55.0000, 55.5000,
        56.0000, 56.5000, 57.0000, 57.5000, 58.0000, 58.5000, 59.0000, 59.5000,
        60.0000, 60.5000, 61.0000, 61.5000, 62.0000, 62.5000, 63.0000, 63.5000,
        64.0000, 64.5000, 65.0000, 65.5000, 66.0000, 66.5000, 67.0000, 67.5000,
        68.0000, 68.5000, 69.0000, 69.5000, 70.0000, 70.5000, 71.0000, 71.5000,
        72.0000, 72.5000, 73.0000, 73.5000, 74.0000, 74.5000, 75.0000, 75.5000,
        76.0000, 76.5000, 77.0000, 77.5000, 78.0000, 78.5000, 79.0000, 79.5000,
        80.0000, 80.5000, 81.0000, 81.5000, 82.0000, 82.5000, 83.0000, 83.5000,
        84.0000, 84.5000, 85.0000, 85.5000, 86.0000, 86.5000, 87.0000, 87.5000,
        88.0000, 88.5000, 89.0000, 89.5000, 90.0000, 90.5000, 91.0000, 91.5000,
        92.0000, 92.5000, 93.0000, 93.5000, 94.0000, 94.5000, 95.0000, 95.5000,
        96.0000, 96.5000, 97.0000, 97.5000, 98.0000, 98.5000, 99.0000, 99.5000])
def create_coordinate_map(img, scale=1):
    """
    Build a (row, col) coordinate grid for an image, optionally supersampled.

    img: torch.Tensor of shape (num_channels, height, width)
    scale: positive sampling density per pixel along each axis. Coordinates are
        taken at steps of 1/scale, so scale=2 gives 0, 0.5, 1, ... along each
        axis. The grid has len(torch.arange(0, height, 1/scale)) rows and
        len(torch.arange(0, width, 1/scale)) columns.

    return: tuple (X, Y)
        X: float tensor of shape (n_rows * n_cols, 2) with (row, col) pairs in
           row-major order (col varies fastest), moved to `device`. For
           scale=1 this is (height * width, 2).
        Y: float tensor of shape (height * width, num_channels) with the
           original pixel values. Y is always at the native resolution, so for
           scale != 1 its rows do not correspond one-to-one with the rows of X.

    raises: ValueError if scale is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    num_channels, height, width = img.shape

    # 1-D sample positions along each axis. Building the grid from these two
    # vectors keeps the row and column counts consistent. The previous approach
    # repeated them int(size * scale) times, which disagrees with arange's
    # ceil-based length whenever size * scale is not an integer.
    h_vals = torch.arange(0, height, 1 / scale)
    w_vals = torch.arange(0, width, 1 / scale)

    # 'ij' indexing: h_grid[i, j] == h_vals[i] and w_grid[i, j] == w_vals[j],
    # so width values change faster than height values after flattening.
    h_grid, w_grid = torch.meshgrid(h_vals, w_vals, indexing="ij")

    # Combine the row and column coordinates into a single tensor
    X = torch.stack([h_grid.reshape(-1), w_grid.reshape(-1)], dim=1).float()

    # Move X to the compute device (GPU if available)
    X = X.to(device)

    # Reshape the image to (h * w, num_channels)
    Y = rearrange(img, 'c h w -> (h w) c').float()
    return X, Y
# scale=2 gives (2 * 300) * (2 * 300) = 360000 query coordinates; scale=1 gives one per pixel.
create_coordinate_map(crop, scale=2)[0].shape
torch.Size([360000, 2])
create_coordinate_map(crop, scale=1)[0].shape
torch.Size([90000, 2])