import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
# Remove all the warnings
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Retina display
%config InlineBackend.figure_format = 'retina'
try:
    from einops import rearrange
except ImportError:
    %pip install einops
    from einops import rearrange
# Hypernetwork
if os.path.exists('dog.jpg'):
    print('dog.jpg exists')
else:
    !wget https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg -O dog.jpg
dog.jpg exists
# Read in an image using torchvision
img = torchvision.io.read_image("dog.jpg")
print(img.shape)
torch.Size([3, 1365, 2048])
# Show the image
plt.imshow(rearrange(img, 'c h w -> h w c').numpy())
<matplotlib.image.AxesImage at 0x7fc2316addf0>
# Use sklearn to normalize the image and store the transform to be used later
from sklearn import preprocessing
scaler_img = preprocessing.MinMaxScaler().fit(img.reshape(-1, 1))
scaler_img
MinMaxScaler()
img_scaled = scaler_img.transform(img.reshape(-1, 1)).reshape(img.shape)
img_scaled.shape

img_scaled = torch.tensor(img_scaled)
img_scaled.shape

torch.Size([3, 1365, 2048])

img_scaled = img_scaled.to(device)
img_scaled
tensor([[[0.3098, 0.3137, 0.3137, ..., 0.2941, 0.2941, 0.2980],
[0.3098, 0.3137, 0.3137, ..., 0.2941, 0.2941, 0.2980],
[0.3098, 0.3137, 0.3137, ..., 0.2941, 0.2941, 0.2980],
...,
[0.4745, 0.4745, 0.4784, ..., 0.3804, 0.3765, 0.3765],
[0.4745, 0.4745, 0.4784, ..., 0.3804, 0.3804, 0.3765],
[0.4745, 0.4745, 0.4784, ..., 0.3843, 0.3804, 0.3804]],
[[0.2039, 0.2078, 0.2078, ..., 0.2157, 0.2157, 0.2118],
[0.2039, 0.2078, 0.2078, ..., 0.2157, 0.2157, 0.2118],
[0.2039, 0.2078, 0.2078, ..., 0.2157, 0.2157, 0.2118],
...,
[0.4039, 0.4039, 0.4078, ..., 0.3216, 0.3176, 0.3176],
[0.4039, 0.4039, 0.4078, ..., 0.3216, 0.3216, 0.3176],
[0.4039, 0.4039, 0.4078, ..., 0.3255, 0.3216, 0.3216]],
[[0.1373, 0.1412, 0.1412, ..., 0.1176, 0.1176, 0.1176],
[0.1373, 0.1412, 0.1412, ..., 0.1176, 0.1176, 0.1176],
[0.1373, 0.1412, 0.1412, ..., 0.1176, 0.1176, 0.1176],
...,
[0.1451, 0.1451, 0.1490, ..., 0.1686, 0.1647, 0.1647],
[0.1451, 0.1451, 0.1490, ..., 0.1686, 0.1686, 0.1647],
[0.1451, 0.1451, 0.1490, ..., 0.1725, 0.1686, 0.1686]]],
device='cuda:0', dtype=torch.float64)
crop = torchvision.transforms.functional.crop(img_scaled.cpu(), 600, 750, 400, 400)
crop.shape
torch.Size([3, 400, 400])
# Plot the crop using matplotlib, using einops rearrange to convert the image
# from (C, H, W) to (H, W, C)
plt.imshow(rearrange(crop, 'c h w -> h w c').cpu().numpy())
<matplotlib.image.AxesImage at 0x7fc4461d7700>
crop = crop.to(device)
# Get the dimensions of the image tensor
num_channels, height, width = crop.shape
print(num_channels, height, width)
3 400 400
Let us now write a function to generate the coordinate inputs. We want the width coordinate (the second column) to change fastest and the height coordinate to change slowest, which is equivalent to the lower bits of a counter changing before the higher bits.
num_channels, height, width = 2, 3, 4

# Create a 2D grid of (h, w) coordinates
w_coords = torch.arange(width).repeat(height, 1)
h_coords = torch.arange(height).repeat(width, 1).t()
w_coords = w_coords.reshape(-1)
h_coords = h_coords.reshape(-1)

# Combine the h and w coordinates into a single tensor
X = torch.stack([h_coords, w_coords], dim=1).float()
X
tensor([[0., 0.],
[0., 1.],
[0., 2.],
[0., 3.],
[1., 0.],
[1., 1.],
[1., 2.],
[1., 3.],
[2., 0.],
[2., 1.],
[2., 2.],
[2., 3.]])
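As a cross-check, the same ordering can be produced with `torch.meshgrid` using `indexing='ij'` (PyTorch 1.10+); this is a small sketch on the same toy 3×4 grid.

# Sketch: build the same (h, w) coordinate list with torch.meshgrid(indexing='ij')
hh, ww = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
X_meshgrid = torch.stack([hh.reshape(-1), ww.reshape(-1)], dim=1).float()
# The width (second) column changes fastest, matching X above
assert torch.equal(X_meshgrid, X)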
def create_coordinate_map(img):
    """
    img: torch.Tensor of shape (num_channels, height, width)
    return: tuple of torch.Tensor of shape (height * width, 2) and torch.Tensor of shape (height * width, num_channels)
    """
    num_channels, height, width = img.shape

    # Create a 2D grid of (h, w) coordinates
    # width values change faster than height values
    w_coords = torch.arange(width).repeat(height, 1)
    h_coords = torch.arange(height).repeat(width, 1).t()
    w_coords = w_coords.reshape(-1)
    h_coords = h_coords.reshape(-1)

    # Combine the h and w coordinates into a single tensor
    X = torch.stack([h_coords, w_coords], dim=1).float()

    # Move X to GPU if available
    X = X.to(device)

    # Reshape the image to (h * w, num_channels)
    Y = rearrange(img, 'c h w -> (h w) c').float()
    return X, Y
dog_X, dog_Y = create_coordinate_map(crop)

dog_X.shape, dog_Y.shape
(torch.Size([160000, 2]), torch.Size([160000, 3]))
dog_X
tensor([[ 0., 0.],
[ 0., 1.],
[ 0., 2.],
...,
[399., 397.],
[399., 398.],
[399., 399.]], device='cuda:0')
# MinMaxScaler from -1 to 1
scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(dog_X.cpu())

# Scale the X coordinates
dog_X_scaled = scaler_X.transform(dog_X.cpu())

# Move the scaled X coordinates to the GPU
dog_X_scaled = torch.tensor(dog_X_scaled).to(device)

# Set to dtype float32
dog_X_scaled = dog_X_scaled.float()

dog_X_scaled.shape, dog_Y.shape
(torch.Size([160000, 2]), torch.Size([160000, 3]))
dog_X[:2], dog_X_scaled[:2], dog_Y[:2]
(tensor([[0., 0.],
[0., 1.]], device='cuda:0'),
tensor([[-1.0000, -1.0000],
[-1.0000, -0.9950]], device='cuda:0'),
tensor([[0.7686, 0.6941, 0.4745],
[0.7686, 0.6941, 0.4745]], device='cuda:0'))
# Create an MLP with 4 hidden layers of s = 128 neurons each and a configurable
# activation (ReLU or sine). Input is (x, y) and output is (r, g, b) or (g) for grayscale.
s = 128

class NN(nn.Module):
    def _init_siren(self, activation_scale):
        # SIREN-style initialisation for sine activations
        self.fc1.weight.data.uniform_(-1/self.fc1.in_features, 1/self.fc1.in_features)
        for layers in [self.fc2, self.fc3, self.fc4, self.fc5]:
            layers.weight.data.uniform_(-np.sqrt(6/self.fc2.in_features)/activation_scale,
                                        np.sqrt(6/self.fc2.in_features)/activation_scale)

    def __init__(self, activation=torch.sin, n_out=1, activation_scale=1.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        self.fc1 = nn.Linear(2, s)
        self.fc2 = nn.Linear(s, s)
        self.fc3 = nn.Linear(s, s)
        self.fc4 = nn.Linear(s, s)
        self.fc5 = nn.Linear(s, n_out)  # grayscale image (1) or RGB (3)
        if self.activation == torch.sin:
            # init weights and biases for sine activation
            self._init_siren(activation_scale=self.activation_scale)

    def forward(self, x):
        x = self.activation(self.activation_scale*self.fc1(x))
        x = self.activation(self.activation_scale*self.fc2(x))
        x = self.activation(self.activation_scale*self.fc3(x))
        x = self.activation(self.activation_scale*self.fc4(x))
        return self.fc5(x)
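As a quick sanity check on the SIREN-style initialisation above (a small sketch): the first layer is drawn from U(-1/fan_in, 1/fan_in) and the later layers from U(-sqrt(6/fan_in)/w0, sqrt(6/fan_in)/w0), where w0 is the `activation_scale`.

# Sketch: verify the SIREN weight bounds for a freshly initialised sine network
siren_check = NN(activation=torch.sin, n_out=3, activation_scale=30.0)
bound_hidden = np.sqrt(6 / siren_check.fc2.in_features) / 30.0
print(siren_check.fc1.weight.abs().max().item() <= 1 / siren_check.fc1.in_features)  # True
print(siren_check.fc2.weight.abs().max().item() <= bound_hidden)                     # True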
# Shuffle data
# shuffled index
sh_index = torch.randperm(dog_X_scaled.shape[0])

# Shuffle the dataset
dog_X_sh = dog_X_scaled[sh_index]
dog_Y_sh = dog_Y[sh_index]
nns = {}
nns["dog"] = {}
nns["dog"]["relu"] = NN(activation=torch.relu, n_out=3).to(device)
nns["dog"]["sin"] = NN(activation=torch.sin, n_out=3, activation_scale=30.0).to(device)

nns["dog"]["relu"](dog_X_sh).shape, nns["dog"]["sin"](dog_X_sh).shape
(torch.Size([160000, 3]), torch.Size([160000, 3]))
n_iter = 2200
def train(net, lr, X, Y, epochs, verbose=True):
    """
    net: torch.nn.Module
    lr: float
    X: torch.Tensor of shape (num_samples, 2)
    Y: torch.Tensor of shape (num_samples, 3)
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = net(X)
        loss = criterion(outputs, Y)
        loss.backward()
        optimizer.step()
        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item():.6f}")
    return loss.item()
"dog"]["relu"], lr=3e-4, X=dog_X_sh, Y=dog_Y_sh, epochs=n_iter) train(nns[
Epoch 0 loss: 0.362823
Epoch 100 loss: 0.030849
Epoch 200 loss: 0.027654
Epoch 300 loss: 0.024760
Epoch 400 loss: 0.021445
Epoch 500 loss: 0.017849
Epoch 600 loss: 0.015446
Epoch 700 loss: 0.013979
Epoch 800 loss: 0.013087
Epoch 900 loss: 0.013066
Epoch 1000 loss: 0.011910
Epoch 1100 loss: 0.011552
Epoch 1200 loss: 0.011221
Epoch 1300 loss: 0.010859
Epoch 1400 loss: 0.010564
Epoch 1500 loss: 0.010345
Epoch 1600 loss: 0.010518
Epoch 1700 loss: 0.009886
Epoch 1800 loss: 0.009697
Epoch 1900 loss: 0.009528
Epoch 2000 loss: 0.009371
Epoch 2100 loss: 0.009229
0.009103643707931042
def plot_reconstructed_and_original_image(original_img, net, X, title=""):
    """
    original_img: torch.Tensor of shape (num_channels, height, width)
    net: torch.nn.Module
    X: torch.Tensor of shape (num_samples, 2)
    """
    num_channels, height, width = original_img.shape
    net.eval()
    with torch.no_grad():
        outputs = net(X)
        outputs = outputs.reshape(height, width, num_channels)
        #outputs = outputs.permute(1, 2, 0)

    fig = plt.figure(figsize=(6, 4))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])

    ax0.imshow(outputs.cpu())
    ax0.set_title("Reconstructed Image")

    ax1.imshow(original_img.cpu().permute(1, 2, 0))
    ax1.set_title("Original Image")

    for a in [ax0, ax1]:
        a.axis("off")

    fig.suptitle(title, y=0.9)
    plt.tight_layout()

plot_reconstructed_and_original_image(crop, nns["dog"]["relu"], dog_X_scaled, title="ReLU")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
imgs_dog_sin = train(nns["dog"]["sin"], lr=3e-4, X=dog_X_sh, Y=dog_Y_sh, epochs=n_iter)
Epoch 0 loss: 0.409131
Epoch 100 loss: 0.003556
Epoch 200 loss: 0.002283
Epoch 300 loss: 0.001715
Epoch 400 loss: 0.001379
Epoch 500 loss: 0.001163
Epoch 600 loss: 0.000990
Epoch 700 loss: 0.000902
Epoch 800 loss: 0.000778
Epoch 900 loss: 0.000713
Epoch 1000 loss: 0.000647
Epoch 1100 loss: 0.000597
Epoch 1200 loss: 0.000548
Epoch 1300 loss: 0.000509
Epoch 1400 loss: 0.000494
Epoch 1500 loss: 0.000448
Epoch 1600 loss: 0.000423
Epoch 1700 loss: 0.000401
Epoch 1800 loss: 0.000386
Epoch 1900 loss: 0.000363
Epoch 2000 loss: 0.000342
Epoch 2100 loss: 0.000335
"dog"]["sin"], dog_X_scaled, title="Sine") plot_reconstructed_and_original_image(crop, nns[
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
test_error_all_data = {}
test_error_all_data["sin"] = torch.nn.MSELoss()(nns["dog"]["sin"](dog_X_scaled), dog_Y).item()
test_error_all_data["relu"] = torch.nn.MSELoss()(nns["dog"]["relu"](dog_X_scaled), dog_Y).item()

test_error_all_data
{'sin': 0.00031630051671527326, 'relu': 0.009100575000047684}
context_lengths = [5, 10, 100, 1000, 10000, len(dog_X)]
# Now, reconstruct the image using a sine activation function with varying number of context points (subsampled from the original image)
test_nets_sirens = {}

print(dog_X_sh.shape, dog_Y_sh.shape)
loss_context = {"train": {}, "test": {}}
for num_context in context_lengths:
    print("="*50)
    print(f"Number of context points: {num_context}")
    test_nets_sirens[num_context] = NN(activation=torch.sin, n_out=3, activation_scale=30.0).to(device)
    loss_context["train"][num_context] = train(test_nets_sirens[num_context],
                                               lr=3e-4, X=dog_X_sh[:num_context],
                                               Y=dog_Y_sh[:num_context], epochs=n_iter, verbose=False)
    # Find the loss on the test set
    with torch.no_grad():
        loss_context["test"][num_context] = nn.MSELoss()(test_nets_sirens[num_context](dog_X_scaled), dog_Y).item()
    print(f"Test loss: {loss_context['test'][num_context]:.6f}")
torch.Size([160000, 2]) torch.Size([160000, 3])
==================================================
Number of context points: 5
Test loss: 0.226195
==================================================
Number of context points: 10
Test loss: 0.043449
==================================================
Number of context points: 100
Test loss: 0.029126
==================================================
Number of context points: 1000
Test loss: 0.013547
==================================================
Number of context points: 10000
Test loss: 0.003966
==================================================
Number of context points: 160000
Test loss: 0.000310
# Plot the test loss vs number of context points
series = pd.Series(loss_context["test"])
series.plot(kind='bar')
plt.xlabel("Number of context points")
plt.ylabel("Test loss")
Text(0, 0.5, 'Test loss')
# Plot the reconstructed image for each number of context points
for num_context in context_lengths:
    print(f"Number of context points: {num_context}")
    plot_reconstructed_and_original_image(crop, test_nets_sirens[num_context],
                                          dog_X_scaled, title=f"Number of context points: {num_context}")
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Number of context points: 5
Number of context points: 10
Number of context points: 100
Number of context points: 1000
Number of context points: 10000
Number of context points: 160000
### Comparison of network size (number of floats) vs. the image size
img_size = crop.shape[1]*crop.shape[2]
nw_size = torch.sum(torch.tensor([p.numel() for p in nns["dog"]["sin"].parameters()]))
print(f"Image size: {img_size}")
print(f"Network size: {nw_size}")
Image size: 160000
Network size: 50307
try:
    from tabulate import tabulate
except ImportError:
    %pip install tabulate
    from tabulate import tabulate
model = nns["dog"]["sin"]

table_data = []

total_params = 0
start = 0
start_end_mapping = {}
for name, param in model.named_parameters():
    param_count = torch.prod(torch.tensor(param.shape)).item()
    total_params += param_count
    end = total_params

    table_data.append([name, param.shape, param_count, start, end])
    start_end_mapping[name] = (start, end)
    start = end
print(tabulate(table_data, headers=["Layer Name", "Shape", "Parameter Count", "Start Index", "End Index"]))
print(f"Total number of parameters: {total_params}")
Layer Name Shape Parameter Count Start Index End Index
------------ ---------------------- ----------------- ------------- -----------
fc1.weight torch.Size([128, 2]) 256 0 256
fc1.bias torch.Size([128]) 128 256 384
fc2.weight torch.Size([128, 128]) 16384 384 16768
fc2.bias torch.Size([128]) 128 16768 16896
fc3.weight torch.Size([128, 128]) 16384 16896 33280
fc3.bias torch.Size([128]) 128 33280 33408
fc4.weight torch.Size([128, 128]) 16384 33408 49792
fc4.bias torch.Size([128]) 128 49792 49920
fc5.weight torch.Size([3, 128]) 384 49920 50304
fc5.bias torch.Size([3]) 3 50304 50307
Total number of parameters: 50307
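The 50307 total also follows directly from the layer shapes in the table above; as a quick arithmetic check:

# fc1: 2*128 weights + 128 biases; fc2-fc4: 128*128 + 128 each; fc5: 128*3 + 3
print((2*128 + 128) + 3*(128*128 + 128) + (128*3 + 3))  # 50307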
start_end_mapping
{'fc1.weight': (0, 256),
'fc1.bias': (256, 384),
'fc2.weight': (384, 16768),
'fc2.bias': (16768, 16896),
'fc3.weight': (16896, 33280),
'fc3.bias': (33280, 33408),
'fc4.weight': (33408, 49792),
'fc4.bias': (49792, 49920),
'fc5.weight': (49920, 50304),
'fc5.bias': (50304, 50307)}
Input: (x, y, R, G, B)

Output: our hypernetwork should produce one value per parameter of the main network, i.e. its output dimension equals the number of parameters.
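One plausible way to assemble such (x, y, R, G, B) inputs from the crop is to concatenate the scaled coordinates with the RGB targets; this is only a sketch (the name `hyper_inputs` is illustrative), and the cells below instead feed random 5-dimensional inputs to the hypernetwork.

# Sketch: concatenate scaled (x, y) coordinates with (R, G, B) values -> (160000, 5)
hyper_inputs = torch.cat([dog_X_scaled, dog_Y], dim=1)
print(hyper_inputs.shape)  # torch.Size([160000, 5])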
class HyperNet(nn.Module):
    def __init__(self, num_layers=5, num_neurons=256, activation=torch.sin, n_out=3):
        super().__init__()
        self.activation = activation
        self.n_out = total_params
        self.fc1 = nn.Linear(5, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, total_params)

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return self.fc3(x)
hp = HyperNet().to(device)

out_hp = hp(torch.rand(10, 5).to(device))
print(out_hp.shape)
weights_flattened = out_hp.mean(dim=0)
print(weights_flattened.shape)
torch.Size([10, 50307])
torch.Size([50307])
# Set the weights of the model using start_end_mapping
model = nns["dog"]["sin"]

for name, param in model.named_parameters():
    start, end = start_end_mapping[name]
    param.data = weights_flattened[start:end].reshape(param.shape)
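A quick consistency check on the bookkeeping (a small sketch): each slice of `weights_flattened` must contain exactly as many values as the parameter it fills.

# Sketch: every (start, end) slice should match the parameter's element count
for name, param in model.named_parameters():
    start, end = start_end_mapping[name]
    assert end - start == param.numel(), name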
### Using our homegrown library (Astra)

We use Astra for:

- Flattening and unflattening the weights
- Easily creating SIREN models
- Easily creating training loops
from astra.torch.utils import ravel_pytree
flat_weights, unravel_fn = ravel_pytree(dict(model.named_parameters()))
print(flat_weights.shape)
print(unravel_fn(flat_weights))
torch.Size([50307])
{'fc1.weight': tensor([[ 7.2303e-02, 1.8475e-01],
[-6.9180e-03, 1.2231e-02],
[-1.9431e-01, 1.1563e-01],
[ 3.7423e-02, 3.5960e-02],
[ 3.5826e-02, 9.0268e-02],
[ 7.2428e-02, 2.3853e-02],
[ 6.7613e-02, -1.1572e-01],
[-1.8519e-01, -2.0865e-02],
[ 2.3674e-01, -5.3692e-02],
[ 1.3188e-01, -4.3248e-02],
[-1.7791e-01, -3.5844e-01],
[ 1.1637e-01, 1.0213e-01],
[ 3.6082e-02, -1.2981e-01],
[-8.9860e-02, 3.0743e-01],
[ 9.0076e-02, -1.3024e-01],
[ 2.0144e-01, -1.3623e-01],
[ 1.0967e-01, 1.1245e-01],
[ 2.1241e-02, 3.6728e-02],
[-7.0298e-02, 1.7884e-01],
[-7.9913e-02, -9.6210e-03],
[-1.2231e-03, -6.7778e-02],
[-6.1415e-02, 1.7982e-01],
[ 8.7729e-02, 7.6696e-02],
[ 1.1200e-02, -1.8881e-01],
[-6.9393e-02, 1.3919e-01],
[ 1.9506e-01, 1.5109e-01],
[ 1.0525e-01, 1.6728e-02],
[-7.1199e-02, -5.0646e-02],
[-3.4556e-02, -3.8549e-02],
[-6.5598e-02, 1.4262e-01],
[-6.3033e-02, 4.6765e-02],
[ 1.6343e-01, -1.1612e-01],
[ 5.3650e-02, 2.0001e-02],
[ 4.4351e-02, -1.2483e-01],
[-6.9576e-02, 1.9289e-02],
[ 2.1541e-01, -1.4095e-01],
[-4.7398e-02, -1.5326e-01],
[-5.5390e-02, -7.6401e-03],
[ 1.1270e-01, -3.3292e-02],
[ 2.6212e-01, -1.9070e-02],
[-2.2660e-02, -5.4961e-02],
[ 1.2587e-02, 1.1614e-01],
[-1.7161e-02, -7.1058e-02],
[ 2.9235e-02, -1.1581e-01],
[-1.4140e-02, -6.8351e-02],
[ 1.7271e-02, -1.7049e-01],
[-5.7903e-02, 1.6249e-01],
[-2.7371e-01, -1.5921e-01],
[ 1.2445e-01, -3.7183e-02],
[ 1.4074e-02, -2.4972e-01],
[-1.4386e-01, 1.5713e-01],
[ 1.6494e-01, 9.7524e-02],
[-1.6054e-01, 6.0998e-02],
[ 2.8658e-02, 3.1216e-02],
[-2.8942e-01, 1.5718e-01],
[ 2.3916e-01, 1.0318e-01],
[-3.3868e-02, -3.0468e-01],
[ 6.7633e-03, 6.7953e-02],
[ 3.2195e-01, -9.9608e-02],
[ 8.3436e-02, 4.2797e-02],
[-1.0098e-01, 7.9213e-02],
[-1.1259e-01, 3.0771e-02],
[ 6.9695e-02, 5.9238e-02],
[ 2.1898e-01, -2.5084e-01],
[ 8.4120e-02, 3.9265e-02],
[-4.4006e-02, -5.9050e-03],
[-1.8446e-01, 1.7159e-01],
[ 1.2318e-02, -1.6082e-02],
[-1.2422e-01, -1.2124e-01],
[ 1.5859e-01, -6.5549e-02],
[-6.1395e-02, 1.4970e-01],
[-1.4709e-01, 3.1681e-02],
[ 3.6000e-02, -1.8523e-02],
[ 1.4831e-01, -1.5797e-01],
[ 1.8725e-01, 3.8700e-02],
[-1.8377e-01, -5.8808e-02],
[ 1.0143e-02, -1.4246e-01],
[-1.4042e-01, -5.2038e-05],
[ 2.3339e-01, 8.5529e-02],
[-1.0926e-01, 9.3398e-03],
[ 2.6608e-02, 4.7951e-02],
[-1.4424e-01, -2.1638e-01],
[-2.6549e-02, 2.2673e-01],
[ 6.6978e-02, 2.1663e-01],
[ 1.6812e-01, -1.4362e-01],
[ 1.8123e-01, -1.0973e-02],
[-9.9815e-03, -6.4446e-02],
[ 1.0362e-02, -1.1157e-01],
[ 1.9818e-02, 3.1765e-03],
[-4.1822e-02, 2.9222e-02],
[ 3.5510e-04, -9.5421e-02],
[ 4.0710e-02, -3.6539e-03],
[ 8.6873e-02, 2.0872e-01],
[-1.2687e-01, 1.7474e-01],
[ 1.4681e-01, 3.6409e-02],
[ 6.9398e-02, 4.4749e-02],
[-3.4701e-01, -3.8960e-02],
[ 2.0324e-01, 6.4354e-02],
[-2.0250e-01, -7.3486e-02],
[ 2.5662e-01, -6.7429e-02],
[ 9.0947e-03, 6.0853e-03],
[-1.3005e-01, -1.3669e-01],
[-1.4868e-01, -4.5554e-03],
[ 1.5069e-01, 8.6696e-02],
[ 4.6468e-02, 6.9869e-02],
[-1.2077e-01, -3.3872e-03],
[-5.2897e-03, -1.9553e-01],
[ 1.2461e-01, -4.0207e-03],
[-1.0743e-01, 1.4018e-01],
[ 2.1903e-01, -1.1876e-01],
[ 3.3978e-02, -5.6171e-03],
[-1.6130e-01, -7.7339e-02],
[ 1.7324e-01, -9.7195e-02],
[ 4.2971e-02, 1.0253e-01],
[-1.0369e-02, -4.1130e-02],
[-1.1557e-01, -1.0589e-01],
[ 1.8067e-01, 1.6982e-01],
[ 4.1407e-02, -2.6240e-02],
[ 6.1280e-02, 2.0380e-01],
[ 2.0553e-02, -1.8692e-01],
[ 5.4052e-02, 7.0528e-03],
[ 6.3718e-02, 1.0462e-01],
[ 6.9697e-02, -7.6473e-02],
[ 7.3403e-02, 7.8244e-02],
[-1.3281e-01, -7.3169e-04],
[ 1.2163e-01, -1.2142e-01],
[-1.1956e-01, 2.6188e-02],
[-2.0094e-01, -3.3400e-03]], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>), 'fc1.bias': tensor([ 1.3937e-01, 9.1031e-02, 1.1908e-01, -5.2545e-02, -9.3397e-02,
-2.5996e-03, 5.6113e-02, 8.7695e-02, 9.4380e-02, 5.8306e-02,
-5.8491e-02, -9.4788e-02, -1.0383e-01, 1.8825e-01, 6.4387e-02,
-6.8692e-02, -5.7511e-02, -1.0626e-01, 1.8366e-01, 9.3314e-02,
-1.5770e-01, -1.1771e-01, 1.8512e-02, -1.0166e-01, 1.1778e-01,
6.4702e-03, 2.0301e-01, -2.4718e-02, 1.1473e-02, -5.9337e-02,
4.5290e-02, -2.7956e-04, 7.4153e-02, 7.4455e-02, 1.1620e-01,
9.1790e-02, -2.6550e-01, 4.0917e-02, -4.9924e-02, 3.5913e-01,
2.3945e-01, -1.3163e-01, 7.0677e-02, 4.3314e-02, 5.4965e-02,
-3.8266e-02, 1.1958e-01, 1.9800e-02, -3.2567e-02, -7.0675e-02,
7.4658e-02, 1.3528e-01, 6.7624e-02, 1.3752e-01, 1.4267e-01,
8.1409e-02, 1.5277e-01, -9.4721e-02, 9.5535e-03, -2.7910e-02,
-6.1356e-02, 9.8481e-02, 1.7242e-01, -2.3863e-02, -8.4674e-02,
-5.6868e-03, 1.4529e-01, -1.4019e-01, -2.7921e-02, 2.1329e-01,
-4.0568e-02, -1.3940e-01, 3.9279e-02, 3.4861e-02, 1.6734e-01,
-2.2712e-02, 1.0885e-01, 1.1623e-01, 1.6150e-01, -1.8630e-02,
-7.1212e-02, -1.4402e-01, 1.5908e-01, -2.2806e-01, -4.3496e-02,
1.4644e-01, -2.2250e-02, -5.0974e-02, -1.5041e-01, -1.0052e-01,
-3.0135e-02, 7.5585e-02, -5.2927e-02, -1.2286e-01, 1.0296e-02,
1.0396e-01, -1.3909e-01, 1.1614e-02, 1.0974e-03, 2.1793e-01,
-2.5740e-02, -9.2893e-02, 4.9473e-02, -1.8746e-01, 4.2226e-02,
1.7077e-01, 1.5371e-01, 1.2752e-01, -7.0501e-02, -1.9778e-03,
4.9510e-02, 2.5706e-02, -5.7692e-02, 2.4992e-01, -1.6351e-01,
6.7651e-02, -7.1454e-02, -1.0196e-01, 6.8088e-02, -1.8465e-01,
5.1298e-03, 1.0726e-02, 6.1477e-03, 2.7650e-01, -1.9581e-02,
-1.2533e-01, 1.0733e-01, -3.6380e-02], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>), 'fc2.weight': tensor([[ 0.1626, 0.1576, 0.0090, ..., 0.0867, -0.0421, -0.0025],
[ 0.1144, -0.0025, -0.1852, ..., -0.1083, -0.1402, -0.0354],
[-0.1914, -0.0494, -0.0735, ..., -0.0817, 0.0393, 0.0877],
...,
[-0.0717, -0.0291, 0.0692, ..., -0.0523, -0.0864, -0.0602],
[ 0.1178, -0.1181, 0.0901, ..., 0.0402, 0.1107, 0.0868],
[ 0.1341, 0.0799, 0.1080, ..., 0.0148, 0.0278, 0.1365]],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>), 'fc2.bias': tensor([ 2.6667e-02, -4.5061e-02, -1.3379e-01, -4.2892e-02, -1.9645e-01,
-2.3817e-01, 1.0494e-02, 2.8197e-01, 1.5406e-02, -2.4376e-02,
1.7561e-01, -1.4642e-01, -1.0797e-01, 3.6716e-02, -2.2550e-02,
1.9913e-02, -3.9965e-02, 7.6919e-02, 1.1498e-01, -6.5787e-02,
-8.4971e-02, 1.2668e-01, 2.0365e-01, 1.7946e-01, -1.2953e-01,
1.9260e-01, -9.6302e-02, -7.5193e-02, 6.4110e-02, -1.9202e-02,
1.5212e-01, 1.6503e-01, 1.4749e-01, -7.5665e-02, -2.3061e-01,
1.2124e-01, 8.7380e-02, -6.8148e-02, 1.9631e-02, 4.9709e-02,
5.2993e-02, 1.6483e-01, 1.1459e-01, -3.6108e-02, 1.9317e-01,
-4.4186e-02, -2.4532e-02, -2.2075e-01, 5.7838e-03, 1.9950e-01,
9.2932e-02, -9.3798e-02, -3.0699e-02, 3.2803e-02, 1.3139e-01,
-1.0147e-01, -7.2373e-02, 4.8918e-02, 8.4615e-02, -1.0287e-01,
-4.7278e-03, 1.2601e-01, -5.6927e-02, 2.2257e-01, -6.5842e-02,
-1.2036e-01, 9.9820e-02, 2.4243e-02, -1.0320e-02, 8.8519e-02,
1.2967e-02, 4.3215e-02, -1.6222e-01, -4.6659e-02, -5.8839e-02,
2.0254e-02, 1.3470e-01, 2.6904e-01, 8.3832e-02, 1.4001e-01,
-5.6560e-02, -1.7141e-01, 3.0431e-02, 3.0826e-02, 1.1963e-01,
-2.8436e-02, -4.2731e-02, 7.2354e-02, 1.9034e-02, -7.6484e-02,
5.2238e-02, -1.0509e-01, 9.1314e-02, 4.5333e-02, 2.1825e-01,
-1.6899e-01, -2.5924e-03, -6.3096e-02, -1.6934e-01, -1.1294e-01,
-8.2351e-02, 1.5471e-01, -6.3658e-02, 1.9206e-02, 7.1285e-02,
-1.9484e-01, 1.1433e-01, 2.1561e-01, -1.2322e-01, 5.6644e-02,
7.1328e-02, -1.2183e-01, -8.9327e-02, 1.0203e-02, 8.1200e-04,
-7.3808e-02, -2.0901e-01, -4.5195e-02, -1.3261e-04, -5.7953e-02,
6.8481e-02, -1.2712e-01, -1.2984e-01, 1.4874e-02, 4.8432e-02,
7.5876e-02, -3.8336e-02, 1.6079e-02], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>), 'fc3.weight': tensor([[ 0.0396, -0.1677, 0.0536, ..., 0.0329, 0.1299, 0.0670],
[-0.2666, -0.1038, -0.2126, ..., 0.0804, -0.1728, 0.1918],
[ 0.1665, 0.0134, -0.0402, ..., -0.0687, 0.1080, 0.0771],
...,
[ 0.1816, -0.0331, 0.0454, ..., -0.0926, 0.1632, 0.0675],
[-0.0477, 0.2884, 0.0700, ..., -0.0688, -0.2281, 0.1430],
[-0.1463, 0.1891, -0.1439, ..., 0.0600, -0.0703, -0.1880]],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>), 'fc3.bias': tensor([-1.6453e-01, 7.5851e-02, 3.7721e-02, -9.4238e-03, -7.1286e-02,
2.2447e-01, 3.8425e-02, 1.0008e-01, -1.5634e-01, 8.8691e-02,
1.2365e-02, 1.4615e-01, -5.4554e-02, -4.8148e-02, -6.4282e-02,
1.0527e-02, -3.6634e-02, 1.5034e-02, -9.3311e-03, 1.7128e-01,
-1.8236e-01, 1.3478e-01, 1.7017e-01, -7.1820e-02, 1.4992e-01,
-4.0661e-01, 9.0499e-02, -3.6053e-01, 2.5526e-01, -8.1079e-02,
-7.7242e-02, 8.8858e-02, -1.4380e-01, 2.5731e-01, 8.9234e-03,
5.1684e-02, 4.9633e-02, 1.3959e-01, -1.2853e-01, -7.9164e-02,
6.9847e-02, 1.0135e-01, -1.1189e-01, -3.8928e-02, -2.0292e-02,
1.1789e-02, 1.4188e-01, -1.8961e-01, -2.3442e-02, -2.2014e-01,
-1.2734e-01, 2.7420e-03, 1.3705e-02, -2.0665e-01, -5.5958e-03,
-7.3088e-02, 2.3228e-02, -3.1873e-01, -4.0341e-02, -1.1505e-01,
3.2668e-02, 8.3553e-02, 8.1658e-03, -1.6962e-01, -1.2170e-01,
1.0409e-01, -2.6209e-01, -9.5062e-02, 8.3047e-02, 1.3676e-02,
-1.9911e-02, -2.3725e-01, -1.3274e-01, -9.1445e-03, 1.6600e-01,
-1.8830e-02, 9.9199e-03, 1.3318e-01, 8.1056e-02, -1.8402e-02,
-5.7780e-05, 1.6435e-01, 1.0902e-01, 3.4496e-02, -1.5712e-01,
-9.2133e-03, 1.7642e-01, 9.4327e-02, -7.3375e-02, -1.2195e-02,
4.0908e-02, -1.3142e-01, 6.7244e-02, 2.5039e-02, -4.3484e-02,
-8.8210e-02, -1.8980e-02, -2.9393e-02, 2.2965e-01, 6.2888e-02,
1.8778e-01, 2.0747e-01, -2.7634e-01, 8.7129e-02, 1.3946e-01,
-2.0265e-01, 2.0881e-01, 1.2091e-01, 3.0196e-02, 2.2240e-02,
1.3639e-01, 1.6089e-01, -8.9004e-02, -3.3855e-03, 5.0107e-02,
1.5593e-01, 6.7469e-02, 6.8105e-02, 8.7053e-02, 2.0969e-01,
1.1700e-01, -1.6770e-02, 4.2229e-02, 4.6535e-02, -9.8295e-02,
1.9084e-01, -4.1850e-02, -4.6306e-02], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>), 'fc4.weight': tensor([[ 0.2830, -0.1100, -0.1323, ..., -0.3271, 0.0813, -0.0061],
[ 0.2710, 0.3222, 0.0848, ..., 0.2513, -0.0644, -0.2402],
[-0.0105, 0.0584, 0.0086, ..., 0.0149, 0.0471, 0.1521],
...,
[ 0.1835, 0.0269, -0.0054, ..., -0.2250, 0.0190, -0.0436],
[ 0.0104, -0.1079, 0.0619, ..., -0.0457, -0.1928, 0.1140],
[ 0.0292, 0.1185, 0.0900, ..., 0.0677, 0.1507, -0.2381]],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>), 'fc4.bias': tensor([-1.6619e-01, -1.1489e-01, -1.0995e-02, -1.8886e-01, 8.1446e-02,
-1.7253e-01, 4.6207e-02, -1.2412e-01, 2.5603e-01, -2.4989e-01,
1.0106e-01, 1.4181e-01, 1.1908e-01, -1.3170e-01, -1.6666e-02,
6.9981e-02, -9.4864e-02, 2.9104e-02, 3.3937e-02, 8.8687e-02,
-2.1460e-01, 1.7662e-02, -7.4605e-02, -2.8766e-01, 1.4387e-03,
8.8849e-02, -2.6133e-01, 2.3944e-01, 4.2308e-01, -2.2386e-02,
-3.5002e-02, 1.5050e-01, -1.8755e-01, 4.0275e-02, 5.0206e-02,
2.8535e-02, 1.5281e-01, 1.5107e-01, 1.1838e-02, -1.9254e-01,
-1.6728e-01, 1.3601e-01, -9.4265e-02, -3.5805e-02, -1.3126e-01,
-7.3716e-02, 4.6302e-02, -5.5326e-02, 8.6940e-02, -2.5061e-01,
7.0305e-02, 7.3881e-02, -1.4083e-02, 1.8914e-01, 3.1738e-02,
-2.8057e-02, 1.0428e-01, -4.5615e-02, 7.6921e-03, -1.8016e-01,
8.4051e-05, -1.8171e-01, -1.7909e-02, -1.4178e-02, 5.4015e-02,
2.5439e-01, -2.7733e-01, 2.3340e-01, -1.2361e-01, -2.2032e-01,
-8.9306e-02, 6.3847e-02, -7.4009e-02, -4.2775e-02, -1.4546e-01,
5.9691e-02, -2.3814e-01, -4.3989e-02, -1.7065e-01, 2.3011e-01,
-1.9524e-02, -8.2356e-02, 1.6535e-01, 1.1677e-01, -1.7481e-01,
-2.4286e-02, -5.4737e-02, -9.8581e-03, 5.8233e-02, -2.0740e-02,
3.7968e-02, -5.8098e-02, 3.2286e-02, -4.9859e-02, 1.3212e-02,
-1.4157e-01, 3.6066e-02, -7.8502e-03, -1.1150e-01, 1.7404e-01,
2.9603e-02, -6.5617e-02, 5.7098e-02, 4.4060e-02, 1.1591e-01,
-1.1777e-01, -2.4960e-01, -2.0873e-01, -5.8239e-02, -1.9206e-01,
5.1293e-02, -1.4662e-01, -1.0294e-01, -1.4170e-01, 1.0213e-01,
-1.0966e-01, 6.4190e-02, -1.2977e-02, 1.3461e-01, 9.0853e-02,
1.3335e-01, -1.0664e-01, 2.0796e-01, -1.0372e-01, 1.2097e-02,
-1.0545e-01, 4.7369e-02, 4.7679e-02], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>), 'fc5.weight': tensor([[ 3.7005e-02, 1.7218e-02, -6.0654e-02, 1.9607e-01, 1.8402e-01,
7.0345e-02, 1.0801e-01, 2.0575e-02, -6.2460e-02, -2.6926e-02,
3.2469e-02, -5.7114e-02, 2.4663e-02, -8.2723e-02, 1.5335e-01,
2.5052e-01, 1.7047e-01, 3.2905e-02, 4.4214e-03, 1.0149e-01,
2.4933e-01, 1.0568e-01, -9.2186e-02, 1.0348e-01, 9.7047e-02,
2.4081e-02, 4.2537e-02, -1.3451e-01, -9.5488e-02, 2.5056e-02,
8.5907e-02, -3.5518e-01, -1.8330e-01, -2.0032e-01, 5.3350e-03,
8.3568e-02, -7.8669e-02, -4.0999e-02, -1.6546e-01, 7.6125e-02,
4.2919e-02, -1.1224e-01, 6.6079e-02, 7.5957e-02, 1.6435e-01,
-6.5331e-02, -1.4222e-01, -2.3253e-02, -2.4421e-02, -4.7380e-02,
-3.1845e-02, 1.4371e-01, 1.6564e-02, -9.1425e-02, -2.3566e-01,
2.4497e-02, -1.0897e-02, -6.5980e-02, -6.4120e-02, -4.8469e-02,
1.9976e-01, 7.1650e-02, -3.1568e-02, -7.8932e-02, -1.0762e-02,
-5.2840e-02, 1.5349e-01, -1.6643e-01, -1.1844e-01, 1.0926e-01,
-2.5987e-02, -2.4647e-01, 1.2376e-01, 1.1220e-02, -8.4029e-02,
-9.8139e-02, 5.3532e-02, -1.4702e-02, 1.4518e-01, 3.3054e-02,
-1.3509e-01, -1.8362e-01, 4.7140e-02, 7.5887e-03, 4.8157e-02,
1.6827e-01, 1.7298e-01, -1.3073e-01, 2.5464e-01, 8.3428e-02,
-1.0881e-01, 1.6493e-01, 2.8744e-03, 7.8934e-03, 8.0214e-02,
-4.2423e-02, 1.0871e-01, 6.9979e-02, -8.5941e-02, 1.5277e-01,
1.2681e-06, -1.6246e-01, -4.9869e-02, -7.0287e-02, -1.6674e-03,
-1.2628e-01, 7.2879e-02, -2.6658e-01, -4.3385e-02, -1.3041e-02,
1.3853e-01, 9.2652e-02, 1.3920e-01, 1.6977e-01, -8.1156e-02,
3.3611e-02, 1.9832e-02, 3.4340e-02, -1.0543e-02, -5.9265e-02,
-1.3928e-01, -1.9022e-02, 1.0770e-01, 6.6137e-02, -1.3821e-01,
-8.6274e-02, -1.2282e-01, -4.3384e-02],
[ 1.5446e-01, -1.2482e-01, -3.1733e-01, -5.0405e-02, -9.8563e-02,
1.4524e-02, -4.7723e-02, -2.1366e-01, 1.2520e-02, 9.3535e-02,
-1.7930e-02, 8.1626e-03, -7.2552e-02, -1.7400e-01, -4.8571e-03,
3.4715e-02, -1.6996e-01, 2.5612e-02, 1.7753e-02, 1.1491e-01,
3.2849e-02, 7.4169e-02, 2.3232e-02, 2.3961e-01, -1.4371e-01,
7.5394e-02, -5.6723e-03, 1.4591e-02, -1.0227e-01, -6.3025e-02,
-4.8187e-02, 3.7513e-02, -2.2861e-02, -2.5060e-02, -5.2416e-02,
3.6575e-02, 4.3657e-02, 3.9590e-02, 6.8889e-02, 7.8860e-02,
-6.1261e-02, 6.4804e-03, 6.4698e-02, -1.3405e-01, 4.9579e-02,
9.1552e-02, 2.5240e-01, -1.1347e-01, 4.2408e-02, 8.2981e-02,
1.1758e-01, 1.4038e-02, -5.8177e-02, -4.6107e-02, 9.8306e-02,
5.3295e-02, -1.2327e-02, -3.7325e-02, 2.1046e-01, 1.1949e-01,
-8.8206e-03, 9.5156e-02, 8.3061e-03, 8.0187e-02, 4.3940e-02,
1.4411e-01, -5.8862e-03, 1.5729e-02, -1.7983e-01, 5.2839e-02,
5.6460e-02, -1.3105e-01, 2.7025e-01, 4.4835e-02, -1.3793e-01,
4.7627e-02, -3.9222e-02, -6.4474e-02, 1.2132e-01, -5.4602e-02,
-4.1981e-02, -1.7641e-01, 1.1443e-01, 1.6024e-01, -4.2657e-02,
2.9715e-02, -6.2705e-03, 7.2440e-02, -3.1999e-01, -1.6657e-01,
-9.7144e-02, 7.7817e-02, -4.3598e-02, 1.5082e-03, 5.5090e-02,
-3.5226e-02, 7.4446e-02, -1.4492e-01, 1.9284e-01, -1.0586e-01,
9.7238e-02, -2.9331e-01, 1.6685e-01, 6.7393e-02, -3.8107e-02,
-5.1379e-03, -1.8144e-01, -1.0758e-01, -3.5332e-02, -9.5876e-02,
-3.3853e-02, 7.4512e-02, 7.1593e-02, 5.6300e-02, -2.2599e-01,
-2.6461e-02, -9.8212e-02, -1.3220e-01, 2.4089e-03, -2.6224e-01,
-4.1655e-02, -1.2883e-03, -1.4610e-01, 7.1842e-02, 1.6349e-01,
-3.0933e-01, 2.8165e-02, 4.9211e-02],
[ 1.6185e-02, -4.3435e-02, 2.9302e-01, 6.8762e-02, -1.0513e-01,
8.0041e-03, 3.1691e-02, 7.8240e-02, -3.0865e-01, -1.2082e-01,
-1.5014e-02, -6.4405e-02, 3.7052e-03, 1.5633e-01, -8.7142e-02,
2.7108e-01, -4.8148e-02, 1.9585e-01, -1.0857e-01, -1.6191e-02,
1.1271e-01, -4.2816e-02, 1.1897e-01, 2.1357e-01, 6.0726e-03,
-1.6058e-01, 1.0014e-01, 7.2932e-03, -3.3023e-02, -1.3890e-02,
2.3372e-01, -1.5602e-01, -1.3258e-01, 3.4258e-02, 4.1946e-03,
-4.5052e-02, 1.0765e-01, 8.8289e-02, 5.0133e-02, 1.0398e-01,
1.7352e-02, -1.1756e-01, 1.2870e-01, 9.9371e-02, 1.0367e-01,
-7.6240e-02, -6.7554e-02, 2.2970e-02, 1.8892e-01, -8.2191e-02,
1.6877e-01, -7.5892e-03, 4.1152e-02, 6.9047e-02, 1.0306e-01,
2.3320e-01, 1.2969e-01, -2.9587e-01, 2.2722e-01, -1.7229e-01,
-7.0078e-02, 6.9529e-02, 1.5175e-01, 4.4102e-02, -1.8311e-01,
4.8507e-03, -1.8906e-01, 1.3765e-01, 8.4732e-02, -8.0962e-02,
-6.6467e-02, -4.6670e-02, -1.5538e-02, -2.2635e-01, 2.1714e-01,
2.3420e-01, -1.2549e-01, 1.6654e-02, 7.8636e-02, -5.5591e-02,
1.5804e-01, -1.9898e-03, -1.6906e-01, -1.9567e-02, 3.1766e-02,
9.4322e-03, -7.3489e-02, -2.7695e-02, 1.2616e-01, 1.7542e-01,
1.9282e-01, -1.0990e-01, -2.4401e-01, 7.7423e-02, -1.5268e-01,
-3.8916e-03, 6.2967e-02, 1.1859e-01, 8.4724e-02, -1.8199e-01,
-7.2396e-02, -6.9568e-02, -7.0842e-02, -9.4516e-02, 2.1872e-01,
2.2665e-01, -1.6960e-01, -2.2812e-01, 1.2222e-01, 1.3715e-01,
-4.3735e-02, -7.3953e-02, -9.7945e-02, -2.6904e-02, 8.7878e-02,
-1.0421e-02, 1.6984e-02, 3.8456e-02, 1.7513e-01, -1.0124e-03,
3.8656e-02, 1.4228e-01, 9.2844e-03, 2.9087e-02, 5.2339e-02,
-1.9062e-01, -1.0616e-01, 1.0397e-01]], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>), 'fc5.bias': tensor([-0.0278, -0.0784, -0.0022], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>)}
unravel_fn(weights_flattened)
{'fc1.weight': tensor([[ 8.4120e-02, 3.9265e-02],
[-4.4006e-02, -5.9050e-03],
[-1.8446e-01, 1.7159e-01],
[ 1.2318e-02, -1.6082e-02],
[-1.2422e-01, -1.2124e-01],
[ 1.5859e-01, -6.5549e-02],
[-6.1395e-02, 1.4970e-01],
[-1.4709e-01, 3.1681e-02],
[ 3.6000e-02, -1.8523e-02],
[ 1.4831e-01, -1.5797e-01],
[ 1.8725e-01, 3.8700e-02],
[-1.8377e-01, -5.8808e-02],
[ 1.0143e-02, -1.4246e-01],
[-1.4042e-01, -5.2038e-05],
[ 2.3339e-01, 8.5529e-02],
[-1.0926e-01, 9.3398e-03],
[ 2.6608e-02, 4.7951e-02],
[-1.4424e-01, -2.1638e-01],
[-2.6549e-02, 2.2673e-01],
[ 6.6978e-02, 2.1663e-01],
[ 1.6812e-01, -1.4362e-01],
[ 1.8123e-01, -1.0973e-02],
[-9.9815e-03, -6.4446e-02],
[ 1.0362e-02, -1.1157e-01],
[ 1.9818e-02, 3.1765e-03],
[-4.1822e-02, 2.9222e-02],
[ 3.5510e-04, -9.5421e-02],
[ 4.0710e-02, -3.6539e-03],
[ 8.6873e-02, 2.0872e-01],
[-1.2687e-01, 1.7474e-01],
[ 1.4681e-01, 3.6409e-02],
[ 6.9398e-02, 4.4749e-02],
[-3.4701e-01, -3.8960e-02],
[ 2.0324e-01, 6.4354e-02],
[-2.0250e-01, -7.3486e-02],
[ 2.5662e-01, -6.7429e-02],
[ 9.0947e-03, 6.0853e-03],
[-1.3005e-01, -1.3669e-01],
[-1.4868e-01, -4.5554e-03],
[ 1.5069e-01, 8.6696e-02],
[ 4.6468e-02, 6.9869e-02],
[-1.2077e-01, -3.3872e-03],
[-5.2897e-03, -1.9553e-01],
[ 1.2461e-01, -4.0207e-03],
[-1.0743e-01, 1.4018e-01],
[ 2.1903e-01, -1.1876e-01],
[ 3.3978e-02, -5.6171e-03],
[-1.6130e-01, -7.7339e-02],
[ 1.7324e-01, -9.7195e-02],
[ 4.2971e-02, 1.0253e-01],
[-1.0369e-02, -4.1130e-02],
[-1.1557e-01, -1.0589e-01],
[ 1.8067e-01, 1.6982e-01],
[ 4.1407e-02, -2.6240e-02],
[ 6.1280e-02, 2.0380e-01],
[ 2.0553e-02, -1.8692e-01],
[ 5.4052e-02, 7.0528e-03],
[ 6.3718e-02, 1.0462e-01],
[ 6.9697e-02, -7.6473e-02],
[ 7.3403e-02, 7.8244e-02],
[-1.3281e-01, -7.3169e-04],
[ 1.2163e-01, -1.2142e-01],
[-1.1956e-01, 2.6188e-02],
[-2.0094e-01, -3.3400e-03],
[ 1.3937e-01, 9.1031e-02],
[ 1.1908e-01, -5.2545e-02],
[-9.3397e-02, -2.5996e-03],
[ 5.6113e-02, 8.7695e-02],
[ 9.4380e-02, 5.8306e-02],
[-5.8491e-02, -9.4788e-02],
[-1.0383e-01, 1.8825e-01],
[ 6.4387e-02, -6.8692e-02],
[-5.7511e-02, -1.0626e-01],
[ 1.8366e-01, 9.3314e-02],
[-1.5770e-01, -1.1771e-01],
[ 1.8512e-02, -1.0166e-01],
[ 1.1778e-01, 6.4702e-03],
[ 2.0301e-01, -2.4718e-02],
[ 1.1473e-02, -5.9337e-02],
[ 4.5290e-02, -2.7956e-04],
[ 7.4153e-02, 7.4455e-02],
[ 1.1620e-01, 9.1790e-02],
[-2.6550e-01, 4.0917e-02],
[-4.9924e-02, 3.5913e-01],
[ 2.3945e-01, -1.3163e-01],
[ 7.0677e-02, 4.3314e-02],
[ 5.4965e-02, -3.8266e-02],
[ 1.1958e-01, 1.9800e-02],
[-3.2567e-02, -7.0675e-02],
[ 7.4658e-02, 1.3528e-01],
[ 6.7624e-02, 1.3752e-01],
[ 1.4267e-01, 8.1409e-02],
[ 1.5277e-01, -9.4721e-02],
[ 9.5535e-03, -2.7910e-02],
[-6.1356e-02, 9.8481e-02],
[ 1.7242e-01, -2.3863e-02],
[-8.4674e-02, -5.6868e-03],
[ 1.4529e-01, -1.4019e-01],
[-2.7921e-02, 2.1329e-01],
[-4.0568e-02, -1.3940e-01],
[ 3.9279e-02, 3.4861e-02],
[ 1.6734e-01, -2.2712e-02],
[ 1.0885e-01, 1.1623e-01],
[ 1.6150e-01, -1.8630e-02],
[-7.1212e-02, -1.4402e-01],
[ 1.5908e-01, -2.2806e-01],
[-4.3496e-02, 1.4644e-01],
[-2.2250e-02, -5.0974e-02],
[-1.5041e-01, -1.0052e-01],
[-3.0135e-02, 7.5585e-02],
[-5.2927e-02, -1.2286e-01],
[ 1.0296e-02, 1.0396e-01],
[-1.3909e-01, 1.1614e-02],
[ 1.0974e-03, 2.1793e-01],
[-2.5740e-02, -9.2893e-02],
[ 4.9473e-02, -1.8746e-01],
[ 4.2226e-02, 1.7077e-01],
[ 1.5371e-01, 1.2752e-01],
[-7.0501e-02, -1.9778e-03],
[ 4.9510e-02, 2.5706e-02],
[-5.7692e-02, 2.4992e-01],
[-1.6351e-01, 6.7651e-02],
[-7.1454e-02, -1.0196e-01],
[ 6.8088e-02, -1.8465e-01],
[ 5.1298e-03, 1.0726e-02],
[ 6.1477e-03, 2.7650e-01],
[-1.9581e-02, -1.2533e-01],
[ 1.0733e-01, -3.6380e-02]], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>),
'fc1.bias': tensor([ 0.0723, 0.1847, -0.0069, 0.0122, -0.1943, 0.1156, 0.0374, 0.0360,
0.0358, 0.0903, 0.0724, 0.0239, 0.0676, -0.1157, -0.1852, -0.0209,
0.2367, -0.0537, 0.1319, -0.0432, -0.1779, -0.3584, 0.1164, 0.1021,
0.0361, -0.1298, -0.0899, 0.3074, 0.0901, -0.1302, 0.2014, -0.1362,
0.1097, 0.1125, 0.0212, 0.0367, -0.0703, 0.1788, -0.0799, -0.0096,
-0.0012, -0.0678, -0.0614, 0.1798, 0.0877, 0.0767, 0.0112, -0.1888,
-0.0694, 0.1392, 0.1951, 0.1511, 0.1052, 0.0167, -0.0712, -0.0506,
-0.0346, -0.0385, -0.0656, 0.1426, -0.0630, 0.0468, 0.1634, -0.1161,
0.0537, 0.0200, 0.0444, -0.1248, -0.0696, 0.0193, 0.2154, -0.1409,
-0.0474, -0.1533, -0.0554, -0.0076, 0.1127, -0.0333, 0.2621, -0.0191,
-0.0227, -0.0550, 0.0126, 0.1161, -0.0172, -0.0711, 0.0292, -0.1158,
-0.0141, -0.0684, 0.0173, -0.1705, -0.0579, 0.1625, -0.2737, -0.1592,
0.1245, -0.0372, 0.0141, -0.2497, -0.1439, 0.1571, 0.1649, 0.0975,
-0.1605, 0.0610, 0.0287, 0.0312, -0.2894, 0.1572, 0.2392, 0.1032,
-0.0339, -0.3047, 0.0068, 0.0680, 0.3220, -0.0996, 0.0834, 0.0428,
-0.1010, 0.0792, -0.1126, 0.0308, 0.0697, 0.0592, 0.2190, -0.2508],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
'fc2.weight': tensor([[ 0.1144, -0.0025, -0.1852, ..., -0.1083, -0.1402, -0.0354],
[-0.1914, -0.0494, -0.0735, ..., -0.0817, 0.0393, 0.0877],
[-0.2834, 0.0507, 0.0576, ..., 0.0827, -0.0266, 0.1160],
...,
[ 0.1178, -0.1181, 0.0901, ..., 0.0402, 0.1107, 0.0868],
[ 0.1341, 0.0799, 0.1080, ..., 0.0148, 0.0278, 0.1365],
[ 0.0267, -0.0451, -0.1338, ..., 0.0759, -0.0383, 0.0161]],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
'fc2.bias': tensor([ 0.1626, 0.1576, 0.0090, 0.2106, -0.2146, 0.0891, 0.0409, -0.0352,
0.1271, 0.1234, -0.2810, 0.0872, 0.0070, 0.0727, 0.1334, -0.0895,
0.0315, -0.0023, -0.0810, -0.1109, -0.2472, -0.1886, -0.1106, -0.0783,
0.0280, -0.0559, -0.1411, -0.1457, -0.1365, -0.1305, 0.0753, 0.2940,
0.0084, 0.0858, -0.0895, -0.1531, -0.2004, -0.0533, -0.0228, -0.0384,
-0.0674, 0.1116, 0.1309, -0.0194, -0.2964, 0.0495, 0.0392, 0.0084,
-0.3178, -0.0079, -0.0193, -0.0915, 0.0934, 0.0513, 0.1537, -0.1062,
-0.0009, 0.1223, 0.1528, -0.1958, -0.0226, 0.1976, 0.2536, -0.0666,
-0.0249, 0.0364, 0.1468, -0.1163, -0.1548, -0.2458, -0.0641, -0.2811,
0.1335, -0.0726, -0.0971, 0.0742, 0.0757, 0.1187, -0.0358, 0.1663,
-0.0367, -0.1590, -0.1435, 0.0068, -0.1081, -0.0447, 0.0111, 0.1300,
0.0330, -0.1142, -0.0675, 0.0358, -0.0160, 0.1685, -0.1855, 0.0587,
-0.0684, -0.0809, 0.2065, -0.0541, -0.0810, -0.2434, -0.0082, 0.1038,
-0.1389, 0.0092, -0.0100, 0.1909, -0.0241, 0.2466, 0.2520, -0.0404,
-0.0489, -0.0476, 0.0276, -0.1169, 0.0349, 0.1722, 0.1544, 0.1706,
0.1735, -0.1790, -0.1286, -0.0445, 0.0327, 0.0867, -0.0421, -0.0025],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
'fc3.weight': tensor([[-0.2666, -0.1038, -0.2126, ..., 0.0804, -0.1728, 0.1918],
[ 0.1665, 0.0134, -0.0402, ..., -0.0687, 0.1080, 0.0771],
[ 0.0048, -0.0303, 0.0350, ..., -0.0155, -0.0594, -0.0849],
...,
[-0.0477, 0.2884, 0.0700, ..., -0.0688, -0.2281, 0.1430],
[-0.1463, 0.1891, -0.1439, ..., 0.0600, -0.0703, -0.1880],
[-0.1645, 0.0759, 0.0377, ..., 0.1908, -0.0419, -0.0463]],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
'fc3.bias': tensor([ 0.0396, -0.1677, 0.0536, -0.0781, 0.0929, 0.1182, -0.0727, -0.2348,
-0.0364, 0.0144, -0.0160, -0.1700, -0.1743, 0.1426, 0.0120, 0.0539,
-0.0199, 0.0310, 0.0724, 0.1968, -0.0554, -0.1410, 0.0416, 0.0742,
0.0453, 0.0107, 0.0289, -0.0609, 0.1151, -0.0019, 0.1392, -0.1715,
-0.1035, -0.1238, -0.2124, 0.1473, -0.1585, 0.0812, -0.0525, -0.1321,
-0.0063, 0.1354, 0.1208, 0.1932, -0.0749, 0.0256, 0.3203, -0.1646,
-0.1285, -0.0764, 0.0970, -0.0485, -0.0683, 0.1397, 0.0028, 0.0345,
-0.0697, 0.0072, -0.0970, -0.1684, -0.1472, -0.0457, -0.0492, -0.0625,
-0.0167, -0.1218, 0.1252, 0.0084, 0.0421, 0.0914, 0.0664, 0.1713,
0.1204, 0.1472, -0.1081, 0.0435, -0.0677, -0.0565, 0.0609, 0.0308,
-0.0063, -0.0176, 0.0553, 0.0903, 0.1117, 0.0311, -0.0477, -0.0769,
0.0113, 0.0810, 0.0353, 0.1450, 0.0064, -0.1846, -0.0507, -0.0900,
-0.1156, 0.0205, 0.2363, -0.0141, -0.2917, -0.1116, -0.0367, 0.0775,
-0.1252, 0.1363, -0.1238, -0.1238, 0.0410, 0.1691, 0.0815, -0.1077,
0.0502, -0.1703, -0.1145, 0.0335, -0.0765, -0.1086, 0.0892, 0.2044,
0.1037, 0.1099, 0.2079, -0.0267, 0.1770, 0.0329, 0.1299, 0.0670],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
'fc4.weight': tensor([[ 0.2710, 0.3222, 0.0848, ..., 0.2513, -0.0644, -0.2402],
[-0.0105, 0.0584, 0.0086, ..., 0.0149, 0.0471, 0.1521],
[ 0.0950, -0.1262, -0.2527, ..., -0.0147, -0.0260, -0.0571],
...,
[ 0.0104, -0.1079, 0.0619, ..., -0.0457, -0.1928, 0.1140],
[ 0.0292, 0.1185, 0.0900, ..., 0.0677, 0.1507, -0.2381],
[-0.1662, -0.1149, -0.0110, ..., -0.1055, 0.0474, 0.0477]],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
'fc4.bias': tensor([ 0.2830, -0.1100, -0.1323, 0.0350, 0.0379, 0.1048, 0.0118, -0.0220,
-0.1217, 0.0384, 0.0729, 0.0511, 0.0749, 0.0740, -0.2084, 0.1948,
0.1429, -0.0998, 0.0815, -0.1628, 0.0771, 0.0368, 0.2225, -0.0697,
-0.0751, 0.0843, -0.1492, 0.0615, -0.0533, 0.0109, -0.1530, 0.0756,
-0.2430, 0.0396, 0.0362, -0.2492, -0.1522, 0.0886, -0.0118, -0.0884,
0.1067, 0.0107, 0.1849, 0.0924, 0.1142, -0.0098, 0.0068, -0.0713,
-0.0696, 0.0380, -0.1443, -0.1019, -0.0137, -0.1282, 0.0116, 0.2317,
-0.0611, 0.0513, -0.0315, 0.0154, 0.1059, 0.1193, 0.0808, -0.0066,
-0.2554, -0.0068, 0.1554, -0.1032, -0.0570, -0.0236, 0.0222, 0.1330,
-0.0821, -0.2539, -0.1511, -0.0738, -0.0669, -0.2200, 0.1576, -0.1148,
-0.0310, -0.0278, 0.0835, -0.0322, -0.0552, -0.0263, 0.1971, 0.3168,
-0.0859, 0.0482, -0.1642, 0.0771, -0.1242, -0.0020, 0.1226, -0.0466,
0.0177, -0.1199, 0.0924, 0.0138, -0.0817, -0.0337, -0.0027, 0.0997,
-0.0822, 0.0004, 0.0043, -0.0923, -0.0192, -0.0081, -0.1371, -0.0770,
-0.1474, 0.0268, 0.0800, -0.0389, 0.0587, 0.0421, 0.1361, 0.0970,
0.0347, -0.1867, -0.0402, -0.1091, -0.1063, -0.3271, 0.0813, -0.0061],
device='cuda:0', grad_fn=<ReshapeAliasBackward0>),
'fc5.weight': tensor([[ 1.9607e-01, 1.8402e-01, 7.0345e-02, 1.0801e-01, 2.0575e-02,
-6.2460e-02, -2.6926e-02, 3.2469e-02, -5.7114e-02, 2.4663e-02,
-8.2723e-02, 1.5335e-01, 2.5052e-01, 1.7047e-01, 3.2905e-02,
4.4214e-03, 1.0149e-01, 2.4933e-01, 1.0568e-01, -9.2186e-02,
1.0348e-01, 9.7047e-02, 2.4081e-02, 4.2537e-02, -1.3451e-01,
-9.5488e-02, 2.5056e-02, 8.5907e-02, -3.5518e-01, -1.8330e-01,
-2.0032e-01, 5.3350e-03, 8.3568e-02, -7.8669e-02, -4.0999e-02,
-1.6546e-01, 7.6125e-02, 4.2919e-02, -1.1224e-01, 6.6079e-02,
7.5957e-02, 1.6435e-01, -6.5331e-02, -1.4222e-01, -2.3253e-02,
-2.4421e-02, -4.7380e-02, -3.1845e-02, 1.4371e-01, 1.6564e-02,
-9.1425e-02, -2.3566e-01, 2.4497e-02, -1.0897e-02, -6.5980e-02,
-6.4120e-02, -4.8469e-02, 1.9976e-01, 7.1650e-02, -3.1568e-02,
-7.8932e-02, -1.0762e-02, -5.2840e-02, 1.5349e-01, -1.6643e-01,
-1.1844e-01, 1.0926e-01, -2.5987e-02, -2.4647e-01, 1.2376e-01,
1.1220e-02, -8.4029e-02, -9.8139e-02, 5.3532e-02, -1.4702e-02,
1.4518e-01, 3.3054e-02, -1.3509e-01, -1.8362e-01, 4.7140e-02,
7.5887e-03, 4.8157e-02, 1.6827e-01, 1.7298e-01, -1.3073e-01,
2.5464e-01, 8.3428e-02, -1.0881e-01, 1.6493e-01, 2.8744e-03,
7.8934e-03, 8.0214e-02, -4.2423e-02, 1.0871e-01, 6.9979e-02,
-8.5941e-02, 1.5277e-01, 1.2681e-06, -1.6246e-01, -4.9869e-02,
-7.0287e-02, -1.6674e-03, -1.2628e-01, 7.2879e-02, -2.6658e-01,
-4.3385e-02, -1.3041e-02, 1.3853e-01, 9.2652e-02, 1.3920e-01,
1.6977e-01, -8.1156e-02, 3.3611e-02, 1.9832e-02, 3.4340e-02,
-1.0543e-02, -5.9265e-02, -1.3928e-01, -1.9022e-02, 1.0770e-01,
6.6137e-02, -1.3821e-01, -8.6274e-02, -1.2282e-01, -4.3384e-02,
1.5446e-01, -1.2482e-01, -3.1733e-01],
[-5.0405e-02, -9.8563e-02, 1.4524e-02, -4.7723e-02, -2.1366e-01,
1.2520e-02, 9.3535e-02, -1.7930e-02, 8.1626e-03, -7.2552e-02,
-1.7400e-01, -4.8571e-03, 3.4715e-02, -1.6996e-01, 2.5612e-02,
1.7753e-02, 1.1491e-01, 3.2849e-02, 7.4169e-02, 2.3232e-02,
2.3961e-01, -1.4371e-01, 7.5394e-02, -5.6723e-03, 1.4591e-02,
-1.0227e-01, -6.3025e-02, -4.8187e-02, 3.7513e-02, -2.2861e-02,
-2.5060e-02, -5.2416e-02, 3.6575e-02, 4.3657e-02, 3.9590e-02,
6.8889e-02, 7.8860e-02, -6.1261e-02, 6.4804e-03, 6.4698e-02,
-1.3405e-01, 4.9579e-02, 9.1552e-02, 2.5240e-01, -1.1347e-01,
4.2408e-02, 8.2981e-02, 1.1758e-01, 1.4038e-02, -5.8177e-02,
-4.6107e-02, 9.8306e-02, 5.3295e-02, -1.2327e-02, -3.7325e-02,
2.1046e-01, 1.1949e-01, -8.8206e-03, 9.5156e-02, 8.3061e-03,
8.0187e-02, 4.3940e-02, 1.4411e-01, -5.8862e-03, 1.5729e-02,
-1.7983e-01, 5.2839e-02, 5.6460e-02, -1.3105e-01, 2.7025e-01,
4.4835e-02, -1.3793e-01, 4.7627e-02, -3.9222e-02, -6.4474e-02,
1.2132e-01, -5.4602e-02, -4.1981e-02, -1.7641e-01, 1.1443e-01,
1.6024e-01, -4.2657e-02, 2.9715e-02, -6.2705e-03, 7.2440e-02,
-3.1999e-01, -1.6657e-01, -9.7144e-02, 7.7817e-02, -4.3598e-02,
1.5082e-03, 5.5090e-02, -3.5226e-02, 7.4446e-02, -1.4492e-01,
1.9284e-01, -1.0586e-01, 9.7238e-02, -2.9331e-01, 1.6685e-01,
6.7393e-02, -3.8107e-02, -5.1379e-03, -1.8144e-01, -1.0758e-01,
-3.5332e-02, -9.5876e-02, -3.3853e-02, 7.4512e-02, 7.1593e-02,
5.6300e-02, -2.2599e-01, -2.6461e-02, -9.8212e-02, -1.3220e-01,
2.4089e-03, -2.6224e-01, -4.1655e-02, -1.2883e-03, -1.4610e-01,
7.1842e-02, 1.6349e-01, -3.0933e-01, 2.8165e-02, 4.9211e-02,
1.6185e-02, -4.3435e-02, 2.9302e-01],
[ 6.8762e-02, -1.0513e-01, 8.0041e-03, 3.1691e-02, 7.8240e-02,
-3.0865e-01, -1.2082e-01, -1.5014e-02, -6.4405e-02, 3.7052e-03,
1.5633e-01, -8.7142e-02, 2.7108e-01, -4.8148e-02, 1.9585e-01,
-1.0857e-01, -1.6191e-02, 1.1271e-01, -4.2816e-02, 1.1897e-01,
2.1357e-01, 6.0726e-03, -1.6058e-01, 1.0014e-01, 7.2932e-03,
-3.3023e-02, -1.3890e-02, 2.3372e-01, -1.5602e-01, -1.3258e-01,
3.4258e-02, 4.1946e-03, -4.5052e-02, 1.0765e-01, 8.8289e-02,
5.0133e-02, 1.0398e-01, 1.7352e-02, -1.1756e-01, 1.2870e-01,
9.9371e-02, 1.0367e-01, -7.6240e-02, -6.7554e-02, 2.2970e-02,
1.8892e-01, -8.2191e-02, 1.6877e-01, -7.5892e-03, 4.1152e-02,
6.9047e-02, 1.0306e-01, 2.3320e-01, 1.2969e-01, -2.9587e-01,
2.2722e-01, -1.7229e-01, -7.0078e-02, 6.9529e-02, 1.5175e-01,
4.4102e-02, -1.8311e-01, 4.8507e-03, -1.8906e-01, 1.3765e-01,
8.4732e-02, -8.0962e-02, -6.6467e-02, -4.6670e-02, -1.5538e-02,
-2.2635e-01, 2.1714e-01, 2.3420e-01, -1.2549e-01, 1.6654e-02,
7.8636e-02, -5.5591e-02, 1.5804e-01, -1.9898e-03, -1.6906e-01,
-1.9567e-02, 3.1766e-02, 9.4322e-03, -7.3489e-02, -2.7695e-02,
1.2616e-01, 1.7542e-01, 1.9282e-01, -1.0990e-01, -2.4401e-01,
7.7423e-02, -1.5268e-01, -3.8916e-03, 6.2967e-02, 1.1859e-01,
8.4724e-02, -1.8199e-01, -7.2396e-02, -6.9568e-02, -7.0842e-02,
-9.4516e-02, 2.1872e-01, 2.2665e-01, -1.6960e-01, -2.2812e-01,
1.2222e-01, 1.3715e-01, -4.3735e-02, -7.3953e-02, -9.7945e-02,
-2.6904e-02, 8.7878e-02, -1.0421e-02, 1.6984e-02, 3.8456e-02,
1.7513e-01, -1.0124e-03, 3.8656e-02, 1.4228e-01, 9.2844e-03,
2.9087e-02, 5.2339e-02, -1.9062e-01, -1.0616e-01, 1.0397e-01,
-2.7843e-02, -7.8383e-02, -2.1809e-03]], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>),
'fc5.bias': tensor([ 0.0370, 0.0172, -0.0607], device='cuda:0',
grad_fn=<ReshapeAliasBackward0>)}
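Since `unravel_fn` returns a dictionary keyed by parameter name (as in the printout above), the hypernetwork output can also be copied into the SIREN without the manual index bookkeeping. This is a minimal sketch of that alternative (`hyper_params` is an illustrative name).

# Sketch: load the unravelled hypernetwork output back into the SIREN parameters
hyper_params = unravel_fn(weights_flattened)
with torch.no_grad():
    for name, param in model.named_parameters():
        param.copy_(hyper_params[name])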