Imports
import jax
import jax.numpy as jnp
import numpy as np
import optax
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
import jax.random as random
import tensorflow_probability.substrates.jax as tfp
from flax import linen as nn
from typing import Any, Callable, Sequence
import seaborn as sns
import pandas as pd
from bayes_opt import BayesianOptimization
Create a simple 2d dataset
X = random.multivariate_normal(
    key=random.PRNGKey(0),
    shape=(100,),
    mean=jnp.array([1, 3]),
    cov=jnp.array([[1.0, -0.5], [-0.5, 2.0]]),
)
X.shape
(100, 2)
plt.scatter(X[:, 0], X[:, 1])
# plt.gca().set_aspect("equal")
class Encoder(nn.Module):
    bottleneck: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(5)(x)
        x = nn.selu(x)
        x = nn.Dense(features=self.bottleneck)(x)
        return x
class Decoder(nn.Module):
    out: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(5)(x)
        x = nn.selu(x)
        x = nn.Dense(features=self.out)(x)
        return x
enc = Encoder(bottleneck=1)
dec = Decoder(out=2)
params_enc = enc.init(random.PRNGKey(0), X)
X_bottlenecked = enc.apply(params_enc, X)
X_bottlenecked.shape
(100, 1)
print(enc.tabulate(random.PRNGKey(0), X))
print(dec.tabulate(random.PRNGKey(0), X_bottlenecked))
Encoder Summary
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ │ Encoder │ float32[100,2] │ float32[100,1] │ │
├─────────┼─────────┼────────────────┼────────────────┼──────────────────────┤
│ Dense_0 │ Dense │ float32[100,2] │ float32[100,5] │ bias: float32[5] │
│ │ │ │ │ kernel: float32[2,5] │
│ │ │ │ │ │
│ │ │ │ │ 15 (60 B) │
├─────────┼─────────┼────────────────┼────────────────┼──────────────────────┤
│ Dense_1 │ Dense │ float32[100,5] │ float32[100,1] │ bias: float32[1] │
│ │ │ │ │ kernel: float32[5,1] │
│ │ │ │ │ │
│ │ │ │ │ 6 (24 B) │
├─────────┼─────────┼────────────────┼────────────────┼──────────────────────┤
│ │ │ │ Total │ 21 (84 B) │
└─────────┴─────────┴────────────────┴────────────────┴──────────────────────┘
Total Parameters: 21 (84 B)
Decoder Summary
┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ │ Decoder │ float32[100,1] │ float32[100,2] │ │
├─────────┼─────────┼────────────────┼────────────────┼──────────────────────┤
│ Dense_0 │ Dense │ float32[100,1] │ float32[100,5] │ bias: float32[5] │
│ │ │ │ │ kernel: float32[1,5] │
│ │ │ │ │ │
│ │ │ │ │ 10 (40 B) │
├─────────┼─────────┼────────────────┼────────────────┼──────────────────────┤
│ Dense_1 │ Dense │ float32[100,5] │ float32[100,2] │ bias: float32[2] │
│ │ │ │ │ kernel: float32[5,2] │
│ │ │ │ │ │
│ │ │ │ │ 12 (48 B) │
├─────────┼─────────┼────────────────┼────────────────┼──────────────────────┤
│ │ │ │ Total │ 22 (88 B) │
└─────────┴─────────┴────────────────┴────────────────┴──────────────────────┘
Total Parameters: 22 (88 B)
class AE(nn.Module):
    bottleneck: int
    out: int

    def setup(self):
        # Alternative to @nn.compact -> explicitly define modules
        # Better for later when we want to access the encoder and decoder explicitly
        self.encoder = Encoder(bottleneck=self.bottleneck)
        self.decoder = Decoder(out=self.out)

    def __call__(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat
bottleneck_size = 1
out_size = X.shape[1]
ae = AE(bottleneck_size, out_size)
ae
AE(
# attributes
bottleneck = 1
out = 2
)
print(ae.tabulate(random.PRNGKey(0), X))
AE Summary
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│ │ AE │ float32[100,2] │ float32[100,2] │ │
├─────────────────┼─────────┼────────────────┼────────────────┼────────────────┤
│ encoder │ Encoder │ float32[100,2] │ float32[100,1] │ │
├─────────────────┼─────────┼────────────────┼────────────────┼────────────────┤
│ encoder/Dense_0 │ Dense │ float32[100,2] │ float32[100,5] │ bias: │
│ │ │ │ │ float32[5] │
│ │ │ │ │ kernel: │
│ │ │ │ │ float32[2,5] │
│ │ │ │ │ │
│ │ │ │ │ 15 (60 B) │
├─────────────────┼─────────┼────────────────┼────────────────┼────────────────┤
│ encoder/Dense_1 │ Dense │ float32[100,5] │ float32[100,1] │ bias: │
│ │ │ │ │ float32[1] │
│ │ │ │ │ kernel: │
│ │ │ │ │ float32[5,1] │
│ │ │ │ │ │
│ │ │ │ │ 6 (24 B) │
├─────────────────┼─────────┼────────────────┼────────────────┼────────────────┤
│ decoder │ Decoder │ float32[100,1] │ float32[100,2] │ │
├─────────────────┼─────────┼────────────────┼────────────────┼────────────────┤
│ decoder/Dense_0 │ Dense │ float32[100,1] │ float32[100,5] │ bias: │
│ │ │ │ │ float32[5] │
│ │ │ │ │ kernel: │
│ │ │ │ │ float32[1,5] │
│ │ │ │ │ │
│ │ │ │ │ 10 (40 B) │
├─────────────────┼─────────┼────────────────┼────────────────┼────────────────┤
│ decoder/Dense_1 │ Dense │ float32[100,5] │ float32[100,2] │ bias: │
│ │ │ │ │ float32[2] │
│ │ │ │ │ kernel: │
│ │ │ │ │ float32[5,2] │
│ │ │ │ │ │
│ │ │ │ │ 12 (48 B) │
├─────────────────┼─────────┼────────────────┼────────────────┼────────────────┤
│ │ │ │ Total │ 43 (172 B) │
└─────────────────┴─────────┴────────────────┴────────────────┴────────────────┘
Total Parameters: 43 (172 B)
params = ae.init(random.PRNGKey(0), X)
params
FrozenDict({
params: {
encoder: {
Dense_0: {
kernel: DeviceArray([[ 0.17535934, -1.0953957 , 0.69273657, -0.26352578,
0.63077825],
[ 0.36360174, -0.73782593, -0.5395247 , -0.41536337,
-0.30090812]], dtype=float32),
bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
kernel: DeviceArray([[-0.64744544],
[ 0.4855265 ],
[-0.82133824],
[ 0.62454295],
[ 0.6013553 ]], dtype=float32),
bias: DeviceArray([0.], dtype=float32),
},
},
decoder: {
Dense_0: {
kernel: DeviceArray([[-0.5305567 , 1.1100855 , -0.31129056, 0.43152457,
-0.09589562]], dtype=float32),
bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
kernel: DeviceArray([[-0.76956064, 0.13031492],
[ 0.11736098, 0.47368795],
[-0.12549445, -0.31066778],
[-0.4392067 , -0.9067152 ],
[-0.86761785, 0.42325035]], dtype=float32),
bias: DeviceArray([0., 0.], dtype=float32),
},
},
},
})
X_hat = ae.apply(params, X)
X_hat.shape
(100, 2)
try:
    ae.encoder
except:
    pass
# Trying to figure this out
# https://github.com/google/flax/discussions/2602
# Encoded values/latent representation
encoded_1d = Encoder(1).apply({"params": params["params"]["encoder"]}, X).flatten()
encoded_1d
DeviceArray([-2.4718695, -2.1964364, -2.6823573, -2.4936147, -1.7122931,
-1.8346143, -2.0767107, -1.8570523, -1.7632042, -2.067935 ,
-2.2317708, -2.14561 , -1.0023856, -2.1458383, -2.3645976,
-1.9418356, -2.7020268, -1.6407721, -1.8281609, -2.2202983,
-2.517499 , -2.5888596, -2.0095935, -2.4470625, -2.18571 ,
-1.9742887, -1.8921608, -2.245328 , -0.8897901, -2.5329056,
-2.2861118, -1.5862433, -2.2295656, -2.496296 , -2.404385 ,
-2.0180435, -1.8416756, -1.858724 , -2.0980945, -1.777173 ,
-2.0027544, -2.1870096, -2.44952 , -1.7563678, -1.5761943,
-2.3097022, -2.0295165, -2.9528203, -2.2042174, -1.9090188,
-1.8868417, -2.4206855, -2.143362 , -1.880422 , -2.5127397,
-2.1454868, -2.0043788, -2.570388 , -2.5082102, -2.3339696,
-1.8621875, -2.4201612, -2.561397 , -2.0498512, -1.6772006,
-1.6392376, -2.3855271, -1.8138398, -3.3776197, -2.3745804,
-2.6683671, -1.8609927, -1.4205931, -1.8123009, -2.236284 ,
-2.2161927, -2.5204146, -2.0504622, -2.1548996, -1.6896895,
-1.3192847, -2.2909331, -2.1295016, -2.0703764, -1.9394028,
-2.041992 , -1.8279521, -1.690125 , -2.7230937, -2.3157165,
-1.7527001, -2.2544892, -2.6310122, -2.0703619, -2.2476096,
-1.8941168, -1.5398859, -1.5742403, -2.375471 , -1.9361446], dtype=float32)
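As an aside (not from the original notebook): the Flax discussion linked above suggests that a bound module also exposes its sub-modules directly. A minimal sketch, assuming Flax's Module.bind API, that avoids re-instantiating Encoder(1):

# Sketch: bind the trained variables once, then call the encoder sub-module directly.
bound_ae = ae.bind(params)
encoded_1d_alt = bound_ae.encoder(X).flatten()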
def plot_2d_reconstruction(X, params, model, trained=False):
    X_hat = model.apply(params, X)
    plt.scatter(X[:, 0], X[:, 1], label="Original Data")
    plt.scatter(X_hat[:, 0], X_hat[:, 1], label="Reconstructed Data")
    if trained:
        plt.title("Trained")
    else:
        plt.title("Untrained")

plot_2d_reconstruction(X, params, ae, False)
Define the Loss function
\(\ell_2\) penalty
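For reference, the reconstruction objective explored below, writing \(\hat{x}_i\) for the reconstruction of \(x_i\), is
\[
\mathcal{L} = \frac{1}{2N}\sum_{i=1}^{N}\lVert x_i - \hat{x}_i\rVert_2^2,
\]
where the factor of \(1/2\) follows the Bishop convention mentioned in the optax docstring quoted below.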
diff = X - X_hat
diff.shape
(100, 2)
diff[:5]
DeviceArray([[-0.46981597, 5.271835 ],
[ 1.6502905 , 3.6781619 ],
[ 1.8507848 , 5.0589485 ],
[ 2.8690844 , 4.5646677 ],
[ 0.4905889 , 2.8893166 ]], dtype=float32)
(diff**2).sum(axis=1).mean() / 2
DeviceArray(7.9555416, dtype=float32)
(diff**2).sum(axis=1)[:5]
DeviceArray([28.01297 , 16.252333, 29.018364, 29.067837, 8.588828], dtype=float32)
(jnp.linalg.norm(diff, ord=2, axis=1) ** 2).mean() / 2
DeviceArray(7.955541, dtype=float32)
from sklearn.metrics import mean_squared_error
mean_squared_error(X, X_hat)
7.9555407
print(2 * optax.l2_loss(X_hat, X).mean())
"""
Multiplying by two
Docstring says:
Calculates the L2 loss for a set of predictions.
Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning"
by Bishop, but not "The Elements of Statistical Learning" by Tibshirani.
"""
7.9555416
'\n\nMultiplying by two\nDocstring says:\nCalculates the L2 loss for a set of predictions.\n\nNote: the 0.5 term is standard in "Pattern Recognition and Machine Learning"\nby Bishop, but not "The Elements of Statistical Learning" by Tibshirani.\n'
@jax.jit
def loss(params, X):
    X_hat = ae.apply(params, X)
    return 2 * optax.l2_loss(X_hat, X).mean()

loss(params, X)
DeviceArray(7.9555416, dtype=float32)
Defining the train function
def train(
    X: jnp.array,
    optimizer: optax._src.base.GradientTransformation,
    model: nn.Module,
    key_param: jax.random.PRNGKey,
    n_iter: int = 500,
    print_every: int = 10,
):
    loss_array = np.zeros(n_iter)

    def loss(params, X):
        X_hat = model.apply(params, X)
        return 2 * optax.l2_loss(X_hat, X).mean()

    params = model.init(key_param, X)
    opt_state = optimizer.init(params)
    loss_grad_fn = jax.value_and_grad(loss)

    for i in range(n_iter):
        loss_val, grads = loss_grad_fn(params, X)
        loss_array[i] = loss_val.item()
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        if i % print_every == 0:
            print("Loss step {}: ".format(i), loss_val)
    return params, loss_array
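As an aside (not from the original notebook), the per-iteration work can also be wrapped in a single jitted step for speed. A minimal sketch, assuming the `loss` defined above for `ae` and an optax optimizer `tx`:

# Sketch: jit the whole update step; names `tx` and `train_step` are illustrative.
tx = optax.adam(learning_rate=0.1)

@jax.jit
def train_step(params, opt_state, X):
    loss_val, grads = jax.value_and_grad(loss)(params, X)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val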
optimized_params, loss_array = train(
    X, optax.adam(learning_rate=0.1), ae, jax.random.PRNGKey(0), n_iter=30
)
Loss step 0: 7.9555416
Loss step 10: 1.3104575
Loss step 20: 0.544944
plt.plot(loss_array)
plt.xlabel("Iterations")
_ = plt.ylabel("Reconstruction loss")
plot_2d_reconstruction(X, optimized_params, ae, True)
from sklearn import datasets

digits = datasets.load_digits()
X = jnp.array(digits["data"])
y = digits["target"]
X.shape
(1797, 64)
plt.imshow(X[1].reshape(8, 8), cmap="Greys")
y[1]
1
bn = 2
ae_digits = AE(bn, X.shape[1])
ae_digits
AE(
# attributes
bottleneck = 2
out = 64
)
print(ae_digits.tabulate(random.PRNGKey(0), X))
AE Summary
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│ │ AE │ float32[1797,… │ float32[1797,6… │ │
├────────────────┼─────────┼────────────────┼─────────────────┼────────────────┤
│ encoder │ Encoder │ float32[1797,… │ float32[1797,2] │ │
├────────────────┼─────────┼────────────────┼─────────────────┼────────────────┤
│ encoder/Dense… │ Dense │ float32[1797,… │ float32[1797,5] │ bias: │
│ │ │ │ │ float32[5] │
│ │ │ │ │ kernel: │
│ │ │ │ │ float32[64,5] │
│ │ │ │ │ │
│ │ │ │ │ 325 (1.3 KB) │
├────────────────┼─────────┼────────────────┼─────────────────┼────────────────┤
│ encoder/Dense… │ Dense │ float32[1797,… │ float32[1797,2] │ bias: │
│ │ │ │ │ float32[2] │
│ │ │ │ │ kernel: │
│ │ │ │ │ float32[5,2] │
│ │ │ │ │ │
│ │ │ │ │ 12 (48 B) │
├────────────────┼─────────┼────────────────┼─────────────────┼────────────────┤
│ decoder │ Decoder │ float32[1797,… │ float32[1797,6… │ │
├────────────────┼─────────┼────────────────┼─────────────────┼────────────────┤
│ decoder/Dense… │ Dense │ float32[1797,… │ float32[1797,5] │ bias: │
│ │ │ │ │ float32[5] │
│ │ │ │ │ kernel: │
│ │ │ │ │ float32[2,5] │
│ │ │ │ │ │
│ │ │ │ │ 15 (60 B) │
├────────────────┼─────────┼────────────────┼─────────────────┼────────────────┤
│ decoder/Dense… │ Dense │ float32[1797,… │ float32[1797,6… │ bias: │
│ │ │ │ │ float32[64] │
│ │ │ │ │ kernel: │
│ │ │ │ │ float32[5,64] │
│ │ │ │ │ │
│ │ │ │ │ 384 (1.5 KB) │
├────────────────┼─────────┼────────────────┼─────────────────┼────────────────┤
│ │ │ │ Total │ 736 (2.9 KB) │
└────────────────┴─────────┴────────────────┴─────────────────┴────────────────┘
Total Parameters: 736 (2.9 KB)
params_digits = ae_digits.init(random.PRNGKey(0), X)
jax.tree_util.tree_map(lambda x: x.shape, params_digits)
FrozenDict({
params: {
decoder: {
Dense_0: {
bias: (5,),
kernel: (2, 5),
},
Dense_1: {
bias: (64,),
kernel: (5, 64),
},
},
encoder: {
Dense_0: {
bias: (5,),
kernel: (64, 5),
},
Dense_1: {
bias: (2,),
kernel: (5, 2),
},
},
},
})
def plot_encoding_2dim(encoder, params):
    assert encoder.bottleneck >= 2
    X_low = encoder.apply({"params": params["params"]["encoder"]}, X)
    df = pd.DataFrame(X_low)
    df["label"] = y
    sns.pairplot(df, hue="label", palette="bright")
Untrained encodings
plot_encoding_2dim(Encoder(bottleneck=bn), params_digits)
X_recon = ae_digits.apply(params_digits, X)
def plot_orig_recon(index=0):
    fig, ax = plt.subplots(sharex=True, ncols=2)
    ax[0].imshow(X[index].reshape(8, 8), cmap="Greys")
    ax[1].imshow(X_recon[index].reshape(8, 8), cmap="Greys")
    ax[0].set_title("Original")
    ax[1].set_title("Reconstructed")

plot_orig_recon(5)
optimized_params_digits, loss_array_digits = train(
    X, optax.adam(learning_rate=0.01), ae_digits, jax.random.PRNGKey(0), n_iter=1000
)
Loss step 0: 90.91908
Loss step 10: 62.609577
Loss step 20: 58.390884
Loss step 30: 53.54514
Loss step 40: 45.062607
Loss step 50: 33.541103
Loss step 60: 25.167671
Loss step 70: 21.107908
Loss step 80: 19.424128
Loss step 90: 18.734087
Loss step 100: 18.47802
Loss step 110: 18.390646
Loss step 120: 18.352455
Loss step 130: 18.333141
Loss step 140: 18.321236
Loss step 150: 18.311743
Loss step 160: 18.3032
Loss step 170: 18.295115
Loss step 180: 18.287226
Loss step 190: 18.279234
Loss step 200: 18.270723
Loss step 210: 18.26098
Loss step 220: 18.2499
Loss step 230: 18.237106
Loss step 240: 18.221647
Loss step 250: 18.20243
Loss step 260: 18.177717
Loss step 270: 18.14539
Loss step 280: 18.105865
Loss step 290: 18.058249
Loss step 300: 18.000141
Loss step 310: 17.931208
Loss step 320: 17.84967
Loss step 330: 17.755304
Loss step 340: 17.65073
Loss step 350: 17.537819
Loss step 360: 17.418528
Loss step 370: 17.293976
Loss step 380: 17.164043
Loss step 390: 17.029558
Loss step 400: 16.89464
Loss step 410: 16.760334
Loss step 420: 16.626553
Loss step 430: 16.493797
Loss step 440: 16.362513
Loss step 450: 16.234201
Loss step 460: 16.11052
Loss step 470: 15.992949
Loss step 480: 15.883502
Loss step 490: 15.783846
Loss step 500: 15.694724
Loss step 510: 15.615571
Loss step 520: 15.54589
Loss step 530: 15.483993
Loss step 540: 15.427973
Loss step 550: 15.376085
Loss step 560: 15.326871
Loss step 570: 15.280196
Loss step 580: 15.23521
Loss step 590: 15.191253
Loss step 600: 15.149132
Loss step 610: 15.109302
Loss step 620: 15.071858
Loss step 630: 15.037474
Loss step 640: 15.005837
Loss step 650: 14.977009
Loss step 660: 14.950782
Loss step 670: 14.927103
Loss step 680: 14.905551
Loss step 690: 14.885867
Loss step 700: 14.867877
Loss step 710: 14.851396
Loss step 720: 14.836317
Loss step 730: 14.8224125
Loss step 740: 14.809575
Loss step 750: 14.797547
Loss step 760: 14.786259
Loss step 770: 14.775562
Loss step 780: 14.76545
Loss step 790: 14.755904
Loss step 800: 14.746771
Loss step 810: 14.738021
Loss step 820: 14.729595
Loss step 830: 14.721415
Loss step 840: 14.713423
Loss step 850: 14.705618
Loss step 860: 14.697898
Loss step 870: 14.690201
Loss step 880: 14.682494
Loss step 890: 14.674812
Loss step 900: 14.667133
Loss step 910: 14.6593275
Loss step 920: 14.651322
Loss step 930: 14.643042
Loss step 940: 14.634569
Loss step 950: 14.625735
Loss step 960: 14.616413
Loss step 970: 14.6066065
Loss step 980: 14.596094
Loss step 990: 14.58464
plt.plot(loss_array_digits)
Trained encodings
plot_encoding_2dim(Encoder(bottleneck=bn), optimized_params_digits)
Reconstruction
X_recon = ae_digits.apply(optimized_params_digits, X)
plot_orig_recon(4)
X_reconstructed = ae_digits.apply(optimized_params_digits, X)  # trained digits AE (`ae`/`params` still refer to the 2d toy model here)
errs = jnp.square(X - X_reconstructed).sum(axis=1)
err_df = pd.DataFrame({"error": errs, "label": y})
err_df.groupby("label").mean()
| label | error |
|---|---|
| 0 | 1067.159668 |
| 1 | 1253.397217 |
| 2 | 1187.446655 |
| 3 | 730.839417 |
| 4 | 919.732239 |
| 5 | 1103.442505 |
| 6 | 913.172607 |
| 7 | 1309.424438 |
| 8 | 892.981750 |
| 9 | 891.891907 |
= pd.DataFrame({"error": errs, "label": y}) err_df
"label").mean() err_df.groupby(
error | |
---|---|
label | |
0 | 1067.159668 |
1 | 1253.397217 |
2 | 1187.446655 |
3 | 730.839417 |
4 | 919.732239 |
5 | 1103.442505 |
6 | 913.172607 |
7 | 1309.424438 |
8 | 892.981750 |
9 | 891.891907 |
Convolutional AE
class ConvEncoder(nn.Module):
    bottleneck: int

    @nn.compact
    def __call__(self, x):
        n = x.shape[0]  # x is n x 64
        x = x.reshape(n, 8, 8, 1)
        x = nn.Conv(features=4, kernel_size=(2, 2), strides=1, padding=0)(
            x
        )  # 8x8x1 -> 7x7x4
        x = nn.selu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))  # 7x7x4 -> 3x3x4
        x = nn.selu(x)
        x = x.reshape(n, -1)  # n x 3x3x4 -> n x 36
        x = nn.Dense(self.bottleneck)(x)
        return x
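For reference (not in the original notebook): with "VALID" padding the spatial output size of a convolution or pooling window is \(\lfloor (n_{\text{in}} - k)/s \rfloor + 1\), so the 2x2 convolution maps 8 -> 7 and the 2x2, stride-2 max-pool maps 7 -> 3, matching the 7x7x4 and 3x3x4 shapes reported by tabulate below.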
ce = ConvEncoder(2)
# print(ce.tabulate(random.PRNGKey(0), X))
print(ce.tabulate(random.PRNGKey(0), X, console_kwargs={"width": 120}))
ConvEncoder Summary
┏━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ │ ConvEncoder │ float32[1797,64] │ float32[1797,2] │ │
├─────────┼─────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ Conv_0 │ Conv │ float32[1797,8,8,1] │ float32[1797,7,7,4] │ bias: float32[4] │
│ │ │ │ │ kernel: float32[2,2,1,4] │
│ │ │ │ │ │
│ │ │ │ │ 20 (80 B) │
├─────────┼─────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ Dense_0 │ Dense │ float32[1797,36] │ float32[1797,2] │ bias: float32[2] │
│ │ │ │ │ kernel: float32[36,2] │
│ │ │ │ │ │
│ │ │ │ │ 74 (296 B) │
├─────────┼─────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ │ │ │ Total │ 94 (376 B) │
└─────────┴─────────────┴─────────────────────┴─────────────────────┴──────────────────────────┘
Total Parameters: 94 (376 B)
class ConvDecoder(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(36)(x)  # n x 2 -> n x 36
        x = nn.selu(x)
        x = x.reshape(-1, 3, 3, 4)  # n x 3 x 3 x 4
        x = nn.ConvTranspose(features=4, kernel_size=(2, 2), strides=(2, 2))(
            x
        )  # 3x3x4 -> 6x6x4
        x = nn.selu(x)
        x = nn.Conv(features=1, kernel_size=(1, 1), strides=1, padding=1)(
            x
        )  # 6x6x4 -> 8x8x1
        x = x.reshape(-1, 64)
        return x
cd = ConvDecoder()
print(
    cd.tabulate(
        random.PRNGKey(0),
        jax.random.normal(key=jax.random.PRNGKey(0), shape=(1797, 2)),
        console_kwargs={"width": 120},
    )
)
ConvDecoder Summary
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ │ ConvDecoder │ float32[1797,2] │ float32[1797,64] │ │
├─────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ Dense_0 │ Dense │ float32[1797,2] │ float32[1797,36] │ bias: float32[36] │
│ │ │ │ │ kernel: float32[2,36] │
│ │ │ │ │ │
│ │ │ │ │ 108 (432 B) │
├─────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ ConvTranspose_0 │ ConvTranspose │ float32[1797,3,3,4] │ float32[1797,6,6,4] │ bias: float32[4] │
│ │ │ │ │ kernel: float32[2,2,4,4] │
│ │ │ │ │ │
│ │ │ │ │ 68 (272 B) │
├─────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ Conv_0 │ Conv │ float32[1797,6,6,4] │ float32[1797,8,8,1] │ bias: float32[1] │
│ │ │ │ │ kernel: float32[1,1,4,1] │
│ │ │ │ │ │
│ │ │ │ │ 5 (20 B) │
├─────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ │ │ │ Total │ 181 (724 B) │
└─────────────────┴───────────────┴─────────────────────┴─────────────────────┴──────────────────────────┘
Total Parameters: 181 (724 B)
class ConvAE(nn.Module):
    bottleneck: int

    def setup(self):
        # Alternative to @nn.compact -> explicitly define modules
        # Better for later when we want to access the encoder and decoder explicitly
        self.encoder = ConvEncoder(bottleneck=self.bottleneck)
        self.decoder = ConvDecoder()

    def __call__(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat
cae = ConvAE(2)
print(
    cae.tabulate(
        random.PRNGKey(0),
        X,
        console_kwargs={"width": 120},
    )
)
ConvAE Summary
┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ │ ConvAE │ float32[1797,64] │ float32[1797,64] │ │
├─────────────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ encoder │ ConvEncoder │ float32[1797,64] │ float32[1797,2] │ │
├─────────────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ encoder/Conv_0 │ Conv │ float32[1797,8,8,1] │ float32[1797,7,7,4] │ bias: float32[4] │
│ │ │ │ │ kernel: float32[2,2,1,4] │
│ │ │ │ │ │
│ │ │ │ │ 20 (80 B) │
├─────────────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ encoder/Dense_0 │ Dense │ float32[1797,36] │ float32[1797,2] │ bias: float32[2] │
│ │ │ │ │ kernel: float32[36,2] │
│ │ │ │ │ │
│ │ │ │ │ 74 (296 B) │
├─────────────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ decoder │ ConvDecoder │ float32[1797,2] │ float32[1797,64] │ │
├─────────────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ decoder/Dense_0 │ Dense │ float32[1797,2] │ float32[1797,36] │ bias: float32[36] │
│ │ │ │ │ kernel: float32[2,36] │
│ │ │ │ │ │
│ │ │ │ │ 108 (432 B) │
├─────────────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ decoder/ConvTranspose_0 │ ConvTranspose │ float32[1797,3,3,4] │ float32[1797,6,6,4] │ bias: float32[4] │
│ │ │ │ │ kernel: float32[2,2,4,4] │
│ │ │ │ │ │
│ │ │ │ │ 68 (272 B) │
├─────────────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ decoder/Conv_0 │ Conv │ float32[1797,6,6,4] │ float32[1797,8,8,1] │ bias: float32[1] │
│ │ │ │ │ kernel: float32[1,1,4,1] │
│ │ │ │ │ │
│ │ │ │ │ 5 (20 B) │
├─────────────────────────┼───────────────┼─────────────────────┼─────────────────────┼──────────────────────────┤
│ │ │ │ Total │ 275 (1.1 KB) │
└─────────────────────────┴───────────────┴─────────────────────┴─────────────────────┴──────────────────────────┘
Total Parameters: 275 (1.1 KB)
params = cae.init(random.PRNGKey(0), X)
plot_encoding_2dim(ConvEncoder(bottleneck=2), params)
optimized_params_digits_cae, loss_array_digits_cae = train(
    X, optax.adam(learning_rate=0.01), cae, jax.random.PRNGKey(0), n_iter=1000, print_every=50
)
Loss step 0: 61.916904
Loss step 50: 30.379993
Loss step 100: 27.855324
Loss step 150: 26.851124
Loss step 200: 25.77603
Loss step 250: 25.184359
Loss step 300: 24.772747
Loss step 350: 24.351847
Loss step 400: 24.091908
Loss step 450: 23.887573
Loss step 500: 23.72832
Loss step 550: 23.607725
Loss step 600: 23.514961
Loss step 650: 23.419945
Loss step 700: 23.363184
Loss step 750: 23.30127
Loss step 800: 23.258532
Loss step 850: 23.206999
Loss step 900: 23.162285
Loss step 950: 23.13027
plot_encoding_2dim(ConvEncoder(bottleneck=2), optimized_params_digits_cae)
BayesOpt for optimizing the latent dimension
def black_box_function(x, y):
    """Function with unknown internals we wish to maximize.

    This is just serving as an example, for all intents and
    purposes think of the internals of this function, i.e.: the process
    which generates its output values, as unknown.
    """
    x = int(x)
    y = int(y)
    return function_discrete(x, y)


def function_discrete(x, y):
    assert type(x) == int
    return -(x**2) - (y - 1) ** 2 + 1
= {"x": (2, 4), "y": (-3, 3)} pbounds
= BayesianOptimization(
optimizer =black_box_function,
f=pbounds,
pbounds=2, # verbose = 1 prints only when a maximum is observed, verbose = 0 is silent
verbose=1,
random_state )
optimizer.maximize()
| iter | target | x | y |
-------------------------------------------------
| 1 | -3.0 | 2.834 | 1.322 |
| 2 | -7.0 | 2.0 | -1.186 |
| 3 | -12.0 | 2.294 | -2.446 |
| 4 | -4.0 | 2.373 | -0.9266 |
| 5 | -4.0 | 2.794 | 0.2329 |
| 6 | -15.0 | 4.0 | 1.331 |
| 7 | -4.0 | 2.348 | 0.8879 |
| 8 | -3.0 | 2.797 | 1.257 |
| 9 | -4.0 | 2.064 | 2.229 |
| 10 | -9.0 | 3.657 | -0.9428 |
| 11 | -7.0 | 2.901 | 3.0 |
| 12 | -4.0 | 2.0 | -0.1486 |
| 13 | -31.0 | 4.0 | -3.0 |
| 14 | -7.0 | 2.0 | 3.0 |
| 15 | -3.0 | 2.0 | 1.539 |
| 16 | -3.0 | 2.512 | 1.792 |
| 17 | -19.0 | 4.0 | 3.0 |
| 18 | -4.0 | 2.831 | -0.4655 |
| 19 | -4.0 | 2.402 | -0.3286 |
| 20 | -9.0 | 3.539 | 0.08748 |
| 21 | -7.0 | 2.841 | -1.217 |
| 22 | -4.0 | 2.764 | 2.245 |
| 23 | -4.0 | 2.0 | 0.4436 |
| 24 | -3.0 | 2.469 | 1.423 |
| 25 | -3.0 | 2.0 | 1.16 |
| 26 | -3.0 | 2.787 | 1.714 |
| 27 | -4.0 | 2.932 | 0.7853 |
| 28 | -3.0 | 2.647 | 1.526 |
| 29 | -3.0 | 2.148 | 1.373 |
| 30 | -3.0 | 2.212 | 1.795 |
=================================================
optimizer.max
{'target': -3.0, 'params': {'x': 2.8340440094051482, 'y': 1.3219469606529488}}
{k: int(v) for k, v in optimizer.max["params"].items()}
{'x': 2, 'y': 1}
function_discrete(2, 1)
-3
Let us keep a separate validation set
def loss_model(params, X, model):
    X_hat = model.apply(params, X)
    diff = X - X_hat
    return (diff**2).sum(axis=1).mean() / X.shape[1]

from functools import partial

e = partial(loss_model, model=cae)
e(params, X)
DeviceArray(61.916904, dtype=float32)
def validation_loss_discrete(bn):
    assert type(bn) == int
    # Train the model on a bn-sized bottleneck
    cae = ConvAE(bn)
    loss_fn_concrete = jax.jit(partial(loss_model, model=cae))
    loss_grad_fn = jax.value_and_grad(loss_fn_concrete)
    tx = optax.adam(learning_rate=1e-2)
    params = cae.init(random.PRNGKey(0), X_train)
    opt_state = tx.init(params)
    print(f"--------Bottleneck of Size: {bn}-------------")
    for i in range(30):
        loss_val, grads = loss_grad_fn(params, X_train)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        if i % 5 == 0:
            print("Loss step {}: ".format(i), loss_val)
    print(f"--------End-------------")
    # Evaluate on validation dataset
    return loss_fn_concrete(params, X_validation)

X_train, X_validation = X[:1000], X[1000:]
validation_loss_discrete(2)
--------Bottleneck of Size: 2-------------
Loss step 0: 62.27715
Loss step 5: 58.5037
Loss step 10: 53.984245
Loss step 15: 49.513382
Loss step 20: 43.078316
Loss step 25: 38.30596
--------End-------------
DeviceArray(36.75615, dtype=float32)
def validation_loss_bb(bn):
    bn_int = int(bn)
    return -validation_loss_discrete(bn_int)

validation_loss_bb(2.5)
--------Bottleneck of Size: 2-------------
Loss step 0: 62.27715
Loss step 5: 58.5037
Loss step 10: 53.984245
Loss step 15: 49.513382
Loss step 20: 43.078316
Loss step 25: 38.30596
--------End-------------
DeviceArray(-36.75615, dtype=float32)
= {"bn": (1, 40)}
pbounds = BayesianOptimization(
optimizer =validation_loss_bb,
f=pbounds,
pbounds=2, # verbose = 1 prints only when a maximum is observed, verbose = 0 is silent
verbose=1,
random_state )
=8) optimizer.maximize(n_iter
| iter | target | bn |
-------------------------------------
--------Bottleneck of Size: 17-------------
Loss step 0: 62.85297
Loss step 5: 52.85449
Loss step 10: 40.903214
Loss step 15: 35.32036
Loss step 20: 35.3193
Loss step 25: 33.33418
--------End-------------
| 1 | -32.36 | 17.26 |
--------Bottleneck of Size: 29-------------
Loss step 0: 64.064514
Loss step 5: 53.85875
Loss step 10: 47.26749
Loss step 15: 43.828564
Loss step 20: 41.847286
Loss step 25: 39.23966
--------End-------------
| 2 | -37.29 | 29.09 |
--------Bottleneck of Size: 1-------------
Loss step 0: 60.969757
Loss step 5: 58.92785
Loss step 10: 53.683678
Loss step 15: 49.58035
Loss step 20: 45.86102
Loss step 25: 44.17104
--------End-------------
| 3 | -42.48 | 1.004 |
--------Bottleneck of Size: 12-------------
Loss step 0: 63.704227
Loss step 5: 57.338806
Loss step 10: 49.537926
Loss step 15: 41.210827
Loss step 20: 38.469257
Loss step 25: 35.276833
--------End-------------
| 4 | -34.07 | 12.79 |
--------Bottleneck of Size: 6-------------
Loss step 0: 61.450924
Loss step 5: 55.82548
Loss step 10: 47.88899
Loss step 15: 40.131763
Loss step 20: 37.62544
Loss step 25: 35.873016
--------End-------------
| 5 | -34.2 | 6.723 |
--------Bottleneck of Size: 20-------------
Loss step 0: 61.81845
Loss step 5: 56.358246
Loss step 10: 51.92751
Loss step 15: 47.312576
Loss step 20: 42.146885
Loss step 25: 37.025486
--------End-------------
| 6 | -33.86 | 20.39 |
--------Bottleneck of Size: 40-------------
Loss step 0: 61.5667
Loss step 5: 49.598972
Loss step 10: 42.639145
Loss step 15: 39.22532
Loss step 20: 36.597954
Loss step 25: 34.528015
--------End-------------
| 7 | -32.67 | 40.0 |
--------Bottleneck of Size: 36-------------
Loss step 0: 62.303535
Loss step 5: 52.075367
Loss step 10: 44.435425
Loss step 15: 40.889286
Loss step 20: 39.280178
Loss step 25: 37.09512
--------End-------------
| 8 | -35.77 | 36.05 |
--------Bottleneck of Size: 9-------------
Loss step 0: 63.35566
Loss step 5: 52.45499
Loss step 10: 43.281902
Loss step 15: 37.028984
Loss step 20: 35.006325
Loss step 25: 33.583298
--------End-------------
| 9 | -33.01 | 9.596 |
--------Bottleneck of Size: 24-------------
Loss step 0: 62.888515
Loss step 5: 52.035835
Loss step 10: 42.154068
Loss step 15: 36.804348
Loss step 20: 34.53549
Loss step 25: 32.37921
--------End-------------
| 10 | -30.08 | 24.26 |
--------Bottleneck of Size: 25-------------
Loss step 0: 63.406757
Loss step 5: 50.291225
Loss step 10: 41.73214
Loss step 15: 38.421593
Loss step 20: 37.0491
Loss step 25: 34.847046
--------End-------------
| 11 | -33.89 | 25.81 |
--------Bottleneck of Size: 22-------------
Loss step 0: 62.303898
Loss step 5: 53.713398
Loss step 10: 47.806355
Loss step 15: 43.550034
Loss step 20: 42.033653
Loss step 25: 39.68766
--------End-------------
| 12 | -38.51 | 22.8 |
--------Bottleneck of Size: 24-------------
Loss step 0: 62.888515
Loss step 5: 52.035835
Loss step 10: 42.154068
Loss step 15: 36.804348
Loss step 20: 34.53549
Loss step 25: 32.37921
--------End-------------
| 13 | -30.08 | 24.3 |
=====================================
optimizer.max
{'target': -30.082199096679688, 'params': {'bn': 24.25939633195359}}
VAE
class VAE_Encoder(nn.Module):
    bottleneck: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(5)(x)
        x = nn.selu(x)
        mu = nn.Dense(features=self.bottleneck)(x)
        log_std = nn.Dense(features=self.bottleneck)(x)
        return mu, log_std
def reparameterize(mu, log_std, key=random.PRNGKey(0), samples=1):
    std = jnp.exp(log_std)
    eps = random.normal(key=key, shape=(samples,))
    return mu + eps * std

samples = reparameterize(2, jnp.log(1), samples=5000)
sns.kdeplot(samples)
plt.title(f"Mean:{jnp.mean(samples):0.2f}, stddev: {jnp.std(samples):0.2f}")
Text(0.5, 1.0, 'Mean:2.00, stddev: 1.00')
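The reparameterization trick writes a sample from \(\mathcal{N}(\mu, \sigma^2)\) as
\[
z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),
\]
so the randomness is isolated in \(\epsilon\) and gradients can flow through \(\mu\) and \(\sigma\), which is exactly what the kde plot above verifies for \(\mu = 2, \sigma = 1\).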
class VAE(nn.Module):
    bottleneck: int
    out: int

    def setup(self):
        # Alternative to @nn.compact -> explicitly define modules
        # Better for later when we want to access the encoder and decoder explicitly
        self.encoder = VAE_Encoder(bottleneck=self.bottleneck)
        self.decoder = Decoder(out=self.out)

    def __call__(self, x, rng=random.PRNGKey(0)):
        mu, log_std = self.encoder(x)
        z = reparameterize(mu, log_std, key=rng)
        x_hat = self.decoder(z)
        return x_hat, mu, log_std
vae = VAE(bottleneck=2, out=64)
params = vae.init(random.PRNGKey(10), X)
plt.imshow(vae.apply(params, X)[0][0].reshape(8, 8))
vae.apply(params, X, random.PRNGKey(10))[0][0].reshape(8, 8)
DeviceArray([[ -3999.399 , 6091.6396 , -2634.2932 , 307.47302 ,
3932.0298 , 1823.3352 , 3852.157 , 5576.5605 ],
[ -8809.304 , 5299.91 , 286.5227 , 1059.3925 ,
-951.62537 , -6623.4824 , -1463.6239 , 16223.624 ],
[ -5279.1323 , -7333.815 , -71.1485 , 5679.2773 ,
1384.2794 , 8326.92 , -1747.943 , -4802.341 ],
[ 403.3739 , 13455.688 , -7414.195 , 7299.713 ,
1180.7408 , -328.49432 , 6619.1357 , 363.74713 ],
[ -4376.3506 , -2045.3063 , 2618.412 , -10890.402 ,
-3035.3848 , -3574.7527 , -5057.2593 , -1859.8529 ],
[ -53.99241 , 2318.109 , -1323.9087 , -6801.4814 ,
-7300.1553 , 865.4169 , 13349.937 , 865.3773 ],
[ 37.275284, -3962.8357 , 1771.9886 , -7992.7188 ,
4896.562 , -17371.383 , 4737.3887 , 7307.3384 ],
[ -221.0234 , -5475.8447 , 4189.172 , -1095.9471 ,
-6452.915 , 3767.8381 , -10514.758 , -2311.0862 ]], dtype=float32)
vae_e = VAE_Encoder(2)
mu, log_sigma = vae_e.apply({"params": params["params"]["encoder"]}, X)
tfd = tfp.distributions
q = tfd.Normal(loc=mu, scale=jnp.exp(log_sigma))  # approximate posterior q(z|x)
p = tfd.Normal(loc=0.0, scale=1.0)  # standard normal prior p(z)
tfd.kl_divergence(q, p).shape
tfd.kl_divergence(q, p).mean()
q.stddev()
Loss
@jax.jit
def loss_vae(params, X, rng=random.PRNGKey(0)):
    X_hat, mu, log_sigma = vae.apply(params, X, rng)
    q = tfd.Normal(loc=mu, scale=jnp.exp(log_sigma))
    p = tfd.Normal(loc=0.0, scale=1.0)
    kl_loss = tfd.kl_divergence(q, p).mean()

    diff = X - X_hat
    recon_loss = (diff**2).sum(axis=1).mean() / X.shape[1]

    return recon_loss + 0.0020 * kl_loss

loss_vae(params, X, random.PRNGKey(4))
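For reference (not in the original notebook), the KL term that tfd.kl_divergence computes per latent dimension has the closed form
\[
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = \log\frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2},
\]
which the loss above averages over data points and latent dimensions before weighting it by 0.002 against the reconstruction term.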
import optax

learning_rate = 0.01
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(loss_vae)
rng = random.PRNGKey(0)  # the PRNG state must be initialised before the loop can split it
for i in range(2001):
    rng, key = random.split(rng)
    loss_val, grads = loss_grad_fn(params, X, rng)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 50 == 0:
        print("Loss step {}: ".format(i), loss_val)
X_recon, _, _ = vae.apply(params, X)
plot_orig_recon(8)
dec = Decoder(out=64)
N = 10
x_range = jnp.linspace(-2, 2, N)
fig, ax = plt.subplots(ncols=N, sharey=True, figsize=(20, 4))
for i in range(N):
    ax[i].imshow(
        dec.apply(
            {"params": params["params"]["decoder"]}, jnp.array([x_range[i], 0.0])
        ).reshape(8, 8),
        cmap="Greys",
    )
def plot_encoding_2dim_vae(encoder, params):
    assert encoder.bottleneck >= 2
    mu, log_sigma = encoder.apply({"params": params["params"]["encoder"]}, X)
    df = pd.DataFrame(mu)
    df["label"] = y
    sns.pairplot(df, hue="label", palette="bright")

vae_enc = VAE_Encoder(2)
mu, log_sigma = vae_enc.apply({"params": params["params"]["encoder"]}, X)
# plot_encoding_2dim_vae(VAE_Encoder(2), params)
plot_encoding_2dim_vae(vae_enc, params)
TODO
- regular AE: Bayesopt for latent dimension
- generation from regular AE
- graph of reconstruction loss v/s latent dimension for regular AE
- GIF for walking in latent space for VAE
- Reconstruction quality as a function of the Recon + beta * KL weighting
- Get the Encoder from AE object directly
- Impact of MC samples
- Reconstruction v/s Expected Log Likelihood (confirm the trend is same for both)
- Cleanup code so that can be reused rather than copy pasting
- Sparse VAE
- Add references
- Add bib entry
- Consider CNNs for more realistic datasets
- https://lilianweng.github.io/posts/2018-08-12-vae/
- https://theaisummer.com/jax-tensorflow-pytorch/
- https://dmol.pub/dl/VAE.html