import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
# Suppress all warnings
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
TL;DR: A sine activation function works much better than ReLU when fitting an MLP to reconstruct an image from its pixel coordinates.
Animation of the training process
!wget https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg -O dog.jpg
--2023-04-27 17:21:53-- https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg
Resolving segment-anything.com (segment-anything.com)... 108.138.128.23, 108.138.128.8, 108.138.128.34, ...
Connecting to segment-anything.com (segment-anything.com)|108.138.128.23|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 221810 (217K) [image/jpeg]
Saving to: ‘dog.jpg’
dog.jpg 100%[===================>] 216.61K 400KB/s in 0.5s
2023-04-27 17:21:55 (400 KB/s) - ‘dog.jpg’ saved [221810/221810]
# Read in the image with torchvision
img = torchvision.io.read_image("dog.jpg")
plt.imshow(img.permute(1, 2, 0))
# Normalize the image
img = img / 255.0
img.shape
torch.Size([3, 1365, 2048])
# Take a fixed 400x400 crop of the image
crop = torchvision.transforms.functional.crop(img, 600, 750, 400, 400)

# Plot the crop
plt.imshow(crop.permute(1, 2, 0))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Get the dimensions of the image tensor
num_channels, height, width = crop.shape
# Create a 2D grid of (x,y) coordinates
x_coords = torch.arange(width).repeat(height, 1)
y_coords = torch.arange(height).repeat(width, 1).t()
x_coords = x_coords.reshape(-1)
y_coords = y_coords.reshape(-1)
# Combine the x and y coordinates into a single tensor
X = torch.stack([x_coords, y_coords], dim=1).float()
# Move X to GPU if available
X = X.to(device)
num_xy = height * width
num_xy
160000
X.shape, X
(torch.Size([160000, 2]),
tensor([[ 0., 0.],
[ 1., 0.],
[ 2., 0.],
...,
[397., 399.],
[398., 399.],
[399., 399.]], device='cuda:0'))
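As a small aside (not in the original), the same coordinate grid can be built with torch.meshgrid; this sketch assumes a PyTorch version that supports the indexing keyword.

# Alternative construction of the same (x, y) grid using torch.meshgrid.
# With indexing='xy', xs[i, j] == j and ys[i, j] == i, so the flattened
# result matches the X built above row for row.
xs, ys = torch.meshgrid(torch.arange(width), torch.arange(height), indexing='xy')
X_alt = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=1).float().to(device)
assert torch.equal(X_alt, X)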
# Extract pixel values from image tensor
pixel_values = crop.reshape(num_channels, -1).float().to(device)

# Transpose the pixel values to be (num_xy, num_channels)
pixel_values = pixel_values.transpose(0, 1)
y = pixel_values.to(device)
# Create an MLP with 5 hidden layers of 256 neurons each and ReLU activations.
# Input is (x, y) and output is (r, g, b)
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, 256)
        self.fc5 = nn.Linear(256, 256)
        self.fc6 = nn.Linear(256, 3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        return self.fc6(x)
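As a quick illustrative check (not part of the original notebook), the network maps a batch of (x, y) coordinates to a batch of (r, g, b) predictions:

# Illustrative: (N, 2) coordinates in, (N, 3) colour predictions out.
print(MLP()(torch.zeros(4, 2)).shape)  # torch.Size([4, 3])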
# Training loop function to train the model
# X: (num_xy, 2) tensor of (x, y) coordinates
# y: (num_xy, 3) tensor of (r, g, b) pixel values
# model: MLP model
# lr: learning rate
# epochs: number of epochs to train for
# bs: batch size
# print_every: print loss every print_every epochs
# Logs losses
# Saves the prediction from the model every print_every epochs
def train(X, y, model, lr=0.01, epochs=1000, bs=1000, print_every=100):
    losses = []
    imgs = []
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        # Get a random batch of (x, y) coordinates
        idxs = torch.randperm(num_xy)[:bs]
        batch_X = X[idxs]
        batch_y = y[idxs]

        # Predict the (r, g, b) values
        pred_y = model(batch_X)

        # Compute the loss
        loss = criterion(pred_y, batch_y)

        # Zero gradients, perform a backward pass, and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Print loss every print_every epochs
        if epoch % print_every == 0:
            print(f"Epoch {epoch} loss: {loss.item()}")
            with torch.no_grad():
                # Predict the (r, g, b) values for the full image
                pred_y = model(X)

                # Reshape the predictions to be (3, height, width)
                pred_y = pred_y.transpose(0, 1).reshape(num_channels, height, width)
                imgs.append(pred_y.permute(1, 2, 0).detach().cpu())
    return losses, imgs
m1 = MLP()
m1 = m1.to(device)
losses_mlp, imgs = train(X, y, m1, lr=0.001, epochs=4000, bs=2000, print_every=100)
Epoch 0 loss: 1.5234602689743042
Epoch 100 loss: 0.0640626773238182
Epoch 200 loss: 0.04388527199625969
Epoch 300 loss: 0.03277464583516121
Epoch 400 loss: 0.03183111175894737
Epoch 500 loss: 0.02485758438706398
Epoch 600 loss: 0.023289738222956657
Epoch 700 loss: 0.024606380611658096
Epoch 800 loss: 0.023782318457961082
Epoch 900 loss: 0.026350615546107292
Epoch 1000 loss: 0.025088826194405556
Epoch 1100 loss: 0.023389095440506935
Epoch 1200 loss: 0.02370390295982361
Epoch 1300 loss: 0.023111725226044655
Epoch 1400 loss: 0.023864751681685448
Epoch 1500 loss: 0.021725382655858994
Epoch 1600 loss: 0.021787280216813087
Epoch 1700 loss: 0.021760988980531693
Epoch 1800 loss: 0.021614212542772293
Epoch 1900 loss: 0.020562106743454933
Epoch 2000 loss: 0.019880816340446472
Epoch 2100 loss: 0.01901845820248127
Epoch 2200 loss: 0.018372364342212677
Epoch 2300 loss: 0.01828525774180889
Epoch 2400 loss: 0.018451901152729988
Epoch 2500 loss: 0.01738181710243225
Epoch 2600 loss: 0.01698809117078781
Epoch 2700 loss: 0.01643018051981926
Epoch 2800 loss: 0.01669265516102314
Epoch 2900 loss: 0.01664060726761818
Epoch 3000 loss: 0.01606595516204834
Epoch 3100 loss: 0.01667209528386593
Epoch 3200 loss: 0.015133237466216087
Epoch 3300 loss: 0.014814447611570358
Epoch 3400 loss: 0.01538220327347517
Epoch 3500 loss: 0.01484852284193039
Epoch 3600 loss: 0.01589234732091427
Epoch 3700 loss: 0.014897373504936695
Epoch 3800 loss: 0.014240250922739506
Epoch 3900 loss: 0.015261288732290268
def plot_image(model, name=None):
    # Predict the (r, g, b) values
    pred_y = model(X)

    # Reshape the predictions to be (3, height, width)
    pred_y = pred_y.transpose(0, 1).reshape(num_channels, height, width)

    # Plot the image
    plt.imshow(pred_y.permute(1, 2, 0).detach().cpu())
    if name:
        plt.savefig(name)

plot_image(m1, "mlp_dog.png")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Create the animation from imgs and save it as a gif
import imageio
imageio.mimsave('mlp.gif', imgs, fps=10)
(~40 repeated imageio warnings, one per gif frame: "Lossy conversion from float32 to uint8. Range [...]. Convert image to uint8 prior to saving to suppress this warning.")
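A side note, not part of the original run: converting each frame to uint8 before saving avoids those warnings. A minimal sketch:

# Clip each float frame to [0, 1] and convert it to uint8 before writing the gif.
frames_uint8 = [(frame.clamp(0, 1) * 255).to(torch.uint8).numpy() for frame in imgs]
imageio.mimsave('mlp.gif', frames_uint8, fps=10)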
# Create an MLP with 5 hidden layers of 256 neurons each and sine activations.
# Input is (x, y) and output is (r, g, b)
class MLP_sin(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, 256)
        self.fc5 = nn.Linear(256, 256)
        self.fc6 = nn.Linear(256, 3)

    def forward(self, x):
        x = torch.sin(self.fc1(x))
        x = torch.sin(self.fc2(x))
        x = torch.sin(self.fc3(x))
        x = torch.sin(self.fc4(x))
        x = torch.sin(self.fc5(x))
        return self.fc6(x)
m2 = MLP_sin()
m2 = m2.to(device)
losses_mlp_sin, imgs = train(X, y, m2, lr=0.001, epochs=4000, bs=1000, print_every=100)
Epoch 0 loss: 0.40150442719459534
Epoch 100 loss: 0.03298206627368927
Epoch 200 loss: 0.033279214054346085
Epoch 300 loss: 0.03175220638513565
Epoch 400 loss: 0.03205806389451027
Epoch 500 loss: 0.03196191042661667
Epoch 600 loss: 0.02972976118326187
Epoch 700 loss: 0.029925711452960968
Epoch 800 loss: 0.02968132309615612
Epoch 900 loss: 0.028653116896748543
Epoch 1000 loss: 0.02474542148411274
Epoch 1100 loss: 0.020879685878753662
Epoch 1200 loss: 0.019819265231490135
Epoch 1300 loss: 0.016965048387646675
Epoch 1400 loss: 0.013934656977653503
Epoch 1500 loss: 0.011689499020576477
Epoch 1600 loss: 0.010081701911985874
Epoch 1700 loss: 0.007140354719012976
Epoch 1800 loss: 0.006480662152171135
Epoch 1900 loss: 0.005266484338790178
Epoch 2000 loss: 0.004757172428071499
Epoch 2100 loss: 0.003453798359259963
Epoch 2200 loss: 0.0032651633955538273
Epoch 2300 loss: 0.0028410402592271566
Epoch 2400 loss: 0.0026403532829135656
Epoch 2500 loss: 0.0019292739452794194
Epoch 2600 loss: 0.0021367412991821766
Epoch 2700 loss: 0.0020427301060408354
Epoch 2800 loss: 0.0017756932647898793
Epoch 2900 loss: 0.0016549285501241684
Epoch 3000 loss: 0.0016728530172258615
Epoch 3100 loss: 0.001471961266361177
Epoch 3200 loss: 0.0014844941906630993
Epoch 3300 loss: 0.0014798615593463182
Epoch 3400 loss: 0.0012664658715948462
Epoch 3500 loss: 0.0012708695139735937
Epoch 3600 loss: 0.0012460555881261826
Epoch 3700 loss: 0.0012855605455115438
Epoch 3800 loss: 0.001190435141324997
Epoch 3900 loss: 0.0011714434949681163
"mlp_sin_dog.png") plot_image(m2,
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
imageio.mimsave('mlp_sin.gif', imgs, fps=10)
(the same "Lossy conversion from float32 to uint8" warning is printed once per frame of mlp_sin.gif.)
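Before moving on to audio, a side-by-side plot of the two loss curves makes the TL;DR concrete (a small addition; losses_mlp and losses_mlp_sin are the lists returned by train above):

# Compare the ReLU and sine training curves on a log scale.
plt.figure()
plt.plot(losses_mlp, label="ReLU MLP")
plt.plot(losses_mlp_sin, label="Sine MLP")
plt.yscale("log")
plt.xlabel("Epoch")
plt.ylabel("MSE loss")
plt.legend()
plt.show()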
# Audio
!wget https://www.vincentsitzmann.com/siren/img/audio/gt_bach.wav
--2023-04-28 14:24:10-- https://www.vincentsitzmann.com/siren/img/audio/gt_bach.wav
Resolving www.vincentsitzmann.com (www.vincentsitzmann.com)... 185.199.111.153, 185.199.108.153, 185.199.110.153, ...
Connecting to www.vincentsitzmann.com (www.vincentsitzmann.com)|185.199.111.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1232886 (1.2M) [audio/wav]
Saving to: ‘gt_bach.wav.3’
gt_bach.wav.3 100%[===================>] 1.17M --.-KB/s in 0.06s
2023-04-28 14:24:10 (19.7 MB/s) - ‘gt_bach.wav.3’ saved [1232886/1232886]
# Clear the CUDA cache
torch.cuda.empty_cache()
from IPython.display import Audio
Audio('gt_bach.wav')
# Read the audio file
import torchaudio
audio, sr = torchaudio.load('gt_bach.wav')
sr
44100
audio.shape
audio = audio[0]
audio = audio.to(device)
# use last 2 seconds of audio
audio = audio[-2 * sr:]
X = torch.arange(0, len(audio)).unsqueeze(1).float().to(device)
# Rescale X between -10 and 10
X = X / X.max() * 20 - 10
X.min(), X.max()
(tensor(-10., device='cuda:0'), tensor(10., device='cuda:0'))
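A quick sanity check on that rescaling (an aside, not in the original): since the raw sample indices start at 0, dividing by the maximum maps them onto [0, 1], and * 20 - 10 then maps them onto [-10, 10].

# Illustrative check of the rescaling formula on the raw indices.
n = len(audio)
idx = torch.arange(0, n, dtype=torch.float32)
scaled = idx / idx.max() * 20 - 10
assert scaled[0] == -10 and scaled[-1] == 10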
X.shape, audio.shape, X
(torch.Size([88200, 1]),
torch.Size([88200]),
tensor([[-10.0000],
[ -9.9998],
[ -9.9995],
...,
[ 9.9995],
[ 9.9998],
[ 10.0000]], device='cuda:0'))
Audio(audio.cpu(), rate=sr)
class SinActivation(torch.nn.Module):
    def __init__(self):
        super(SinActivation, self).__init__()

    def forward(self, x):
        return torch.sin(x)

class SinActivation30(torch.nn.Module):
    def __init__(self):
        super(SinActivation30, self).__init__()

    def forward(self, x):
        return torch.sin(30 * x)
import torch.nn as nn
def create_mlp(n, m, f):
    """
    n: number of hidden layers
    m: number of neurons in each hidden layer
    f: activation function
    ---
    Weight initialization (done here, when the layers are created):
    uniform between -1/input_dim and 1/input_dim for the first layer
    (the factor of 30 is applied by SinActivation30), and between
    -sqrt(6/input_dim) and sqrt(6/input_dim) for the rest.
    """
    layers = []
    layer1 = nn.Linear(1, m)
    torch.nn.init.uniform_(layer1.weight, a=-1/1, b=1/1)
    #torch.nn.init.uniform_(layer1.bias, a=-1/1, b=1/1)
    layers.append(layer1)
    layers.append(SinActivation30())
    for i in range(n):
        layer_i = nn.Linear(m, m)
        # Uniform distribution between -sqrt(6/input_dim) and sqrt(6/input_dim)
        torch.nn.init.uniform_(layer_i.weight, a=-np.sqrt(6/m), b=np.sqrt(6/m))
        torch.nn.init.uniform_(layer_i.bias, a=-np.sqrt(6/m), b=np.sqrt(6/m))
        layers.append(layer_i)
        layers.append(f)
    layers.append(nn.Linear(m, 1))
    return nn.Sequential(*layers)
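A quick shape check of the factory (illustrative only, not part of the original run):

# Illustrative sanity check: a 3-hidden-layer, 64-unit network maps (N, 1) -> (N, 1).
tiny = create_mlp(3, 64, SinActivation())
out = tiny(torch.linspace(-10, 10, 8).unsqueeze(1))
print(out.shape)  # torch.Size([8, 1])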
mlp_audio_sin_5_256 = create_mlp(5, 256, SinActivation()).to(device)
#mlp_audio_sin_8_512 = create_mlp(8, 512, SinActivation()).to(device)
#mlp_audio_sin_3_128 = create_mlp(3, 128, SinActivation()).to(device)
mlp_audio_sin_5_128
NameError: name 'mlp_audio_sin_5_128' is not defined
def train_audio(X, y, model, lr=0.01, epochs=1000, bs=1000, print_every=100):
    losses = []
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        num_rows = X.shape[0]
        idx = torch.randperm(num_rows)[:bs]
        batch_X = X[idx]
        batch_y = y[idx]
        pred_y = model(batch_X)

        # Compute the loss.
        # Note: pred_y has shape (bs, 1) while batch_y (the raw audio) has shape
        # (bs,), so MSELoss broadcasts them to a (bs, bs) matrix; this is what
        # triggers the CUDA out-of-memory errors further down.
        loss = criterion(pred_y, batch_y)

        # Zero gradients, perform a backward pass, and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Print loss every print_every epochs
        if epoch % print_every == 0:
            print(f"Epoch {epoch} loss: {loss.item()}")
    return losses
#losses_mlp_sin_3_128 = train_audio(X, audio, mlp_audio_sin_3_128, lr=0.0001,
#                                   epochs=5000, bs=len(X)//2, print_every=100)
losses_mlp_sin_5_256 = train_audio(X, audio, mlp_audio_sin_5_256, lr=0.0001,
                                   epochs=5000, bs=len(X)//2, print_every=100)
Epoch 0 loss: 0.210729718208313
OutOfMemoryError: CUDA out of memory. Tried to allocate 7.25 GiB (GPU 0; 79.18 GiB total capacity; 63.06 GiB already allocated; 7.88 MiB free; 74.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
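That 7.25 GiB allocation is essentially a (44100, 44100) float32 matrix, i.e. the broadcasted MSE noted in train_audio above (bs = len(X)//2 = 44100). A sketch of the fix, not what was run in this notebook: give the target an explicit channel dimension so the loss stays element-wise.

# Hypothetical fix: make the target shape (num_samples, 1) so it matches the
# model output and MSELoss no longer broadcasts to a (bs, bs) matrix.
y_audio = audio.unsqueeze(1)
# losses_mlp_sin_5_256 = train_audio(X, y_audio, mlp_audio_sin_5_256, lr=0.0001,
#                                    epochs=5000, bs=len(X)//2, print_every=100)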
X
tensor([[-1.7320],
[-1.7320],
[-1.7319],
...,
[ 1.7319],
[ 1.7320],
[ 1.7320]], device='cuda:0')
import time
a = time.time()
losses_mlp_sin_8_512 = train_audio(X, audio, mlp_audio_sin_8_512,
                                   lr=0.0001, epochs=10, bs=len(X), print_every=1)
b = time.time()
print(b - a)
OutOfMemoryError: CUDA out of memory. Tried to allocate 28.98 GiB (GPU 0; 79.18 GiB total capacity; 33.40 GiB already allocated; 14.51 GiB free; 59.74 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
mlp_audio_sin_8_512 = torch.compile(mlp_audio_sin_8_512)
a = time.time()
losses_mlp_sin_8_512 = train_audio(X, audio, mlp_audio_sin_8_512,
                                   lr=0.0001, epochs=10, bs=len(X), print_every=1)
b = time.time()
print(b - a)
NameError: name 'time' is not defined
# Plot the reconstruction
with torch.no_grad():
    #pred_y_5_256 = mlp_audio_sin_5_256(X)
    #pred_y_8_512 = mlp_audio_sin_8_512(X)
    pred_y_3_128 = mlp_audio_sin_3_128(X)

plt.plot(audio.cpu().numpy(), label="Ground truth")
#plt.plot(pred_y_5_256.cpu().numpy(), label="MLP 5 layers 256 neurons")
plt.plot(pred_y_3_128.cpu().numpy(), label="MLP 3 layers 128 neurons")
plt.legend()
import pandas as pd
df = pd.DataFrame({"GT audio": audio.cpu().numpy(),
                   "MLP 5 layers 256 neurons": pred_y_5_256.cpu().numpy().flatten(),
                   "MLP 8 layers 512 neurons": pred_y_8_512.cpu().numpy().flatten()})
df.describe()
|       | GT audio     | MLP 5 layers 256 neurons | MLP 8 layers 512 neurons |
|-------|--------------|--------------------------|--------------------------|
| count | 88200.000000 | 88200.000000             | 88200.000000             |
| mean  | 0.000127     | -0.013929                | -0.010819                |
| std   | 0.208728     | 0.025773                 | 0.156109                 |
| min   | -0.868308    | -0.083747                | -0.710084                |
| 25%   | -0.130095    | -0.030821                | -0.116540                |
| 50%   | -0.002093    | -0.011080                | -0.010339                |
| 75%   | 0.130701     | 0.002974                 | 0.094733                 |
| max   | 1.000000     | 0.051832                 | 0.658187                 |
audio.shape, pred_y_8_512.shape
(torch.Size([88200]), torch.Size([88200, 1]))
# Play the reconstruction
Audio(pred_y_8_512.cpu().T, rate=sr)
TODO
- Show the gradient of the reconstructed image for different activation functions
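One possible way to approach that TODO, sketched below under the assumption that the image models m1 and m2 are still in memory; note that X was later reassigned to the audio sample indices, so the (x, y) grid from the image section would have to be rebuilt first. Because the networks take coordinates as input, the spatial gradient of the reconstruction can be taken with autograd with respect to those inputs.

# Sketch (not run in this notebook): gradient magnitude of a coordinate MLP's
# reconstruction, computed via autograd w.r.t. the (x, y) inputs.
# `coords` is assumed to be the (num_xy, 2) image grid built earlier.
def gradient_magnitude(model, coords):
    coords = coords.clone().requires_grad_(True)
    pred = model(coords)                               # (num_xy, 3)
    grad = torch.autograd.grad(pred.sum(), coords)[0]  # (num_xy, 2)
    return grad.norm(dim=1).reshape(height, width).detach().cpu()

# plt.imshow(gradient_magnitude(m2, X))  # sine MLP; repeat with m1 for ReLU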