Autoencoders in JAX

A programming introduction to Autoencoders in JAX

ML

Author: Nipun Batra
Published: November 4, 2022

Imports

import jax
import jax.numpy as jnp
import numpy as np
import optax

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'

import jax.random as random
import tensorflow_probability.substrates.jax as tfp

from flax import linen as nn
from typing import Any, Callable, Sequence

import seaborn as sns
import pandas as pd

from bayes_opt import BayesianOptimization

Create a simple 2d dataset

X = random.multivariate_normal(
    key=random.PRNGKey(0),
    shape=(100,),
    mean=jnp.array([1, 3]),
    cov=jnp.array([[1.0, -0.5], [-0.5, 2.0]]),
)
X.shape
(100, 2)
plt.scatter(X[:, 0], X[:, 1])
# plt.gca().set_aspect("equal")

class Encoder(nn.Module):
    bottleneck: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(5)(x)
        x = nn.selu(x)
        x = nn.Dense(features=self.bottleneck)(x)
        return x
class Decoder(nn.Module):
    out: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(5)(x)
        x = nn.selu(x)
        x = nn.Dense(features=self.out)(x)
        return x
enc = Encoder(bottleneck=1)

dec = Decoder(out=2)
params_enc = enc.init(random.PRNGKey(0), X)
X_bottlenecked = enc.apply(params_enc, X)
X_bottlenecked.shape
(100, 1)
print(enc.tabulate(random.PRNGKey(0), X))

print(dec.tabulate(random.PRNGKey(0), X_bottlenecked))
Encoder Summary

path      module    inputs            outputs           params
--------  --------  ----------------  ----------------  -------------------------------------------------
          Encoder   float32[100,2]    float32[100,1]
Dense_0   Dense     float32[100,2]    float32[100,5]    kernel: float32[2,5], bias: float32[5] -> 15 (60 B)
Dense_1   Dense     float32[100,5]    float32[100,1]    kernel: float32[5,1], bias: float32[1] -> 6 (24 B)

Total Parameters: 21 (84 B)


Decoder Summary

path      module    inputs            outputs           params
--------  --------  ----------------  ----------------  -------------------------------------------------
          Decoder   float32[100,1]    float32[100,2]
Dense_0   Dense     float32[100,1]    float32[100,5]    kernel: float32[1,5], bias: float32[5] -> 10 (40 B)
Dense_1   Dense     float32[100,5]    float32[100,2]    kernel: float32[5,2], bias: float32[2] -> 12 (48 B)

Total Parameters: 22 (88 B)




class AE(nn.Module):
    bottleneck: int
    out: int
    def setup(self):
        # Alternative to @nn.compact -> explicitly define modules
        # Better for later when we want to access the encoder and decoder explicitly
        self.encoder = Encoder(bottleneck=self.bottleneck)
        self.decoder = Decoder(out=self.out)

    def __call__(self, x):

        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat
bottleneck_size = 1
out_size = X.shape[1]
ae = AE(bottleneck_size, out_size)
ae
AE(
    # attributes
    bottleneck = 1
    out = 2
)
print(ae.tabulate(random.PRNGKey(0), X))
AE Summary

path             module    inputs            outputs           params
---------------  --------  ----------------  ----------------  -------------------------------------------------
                 AE        float32[100,2]    float32[100,2]
encoder          Encoder   float32[100,2]    float32[100,1]
encoder/Dense_0  Dense     float32[100,2]    float32[100,5]    kernel: float32[2,5], bias: float32[5] -> 15 (60 B)
encoder/Dense_1  Dense     float32[100,5]    float32[100,1]    kernel: float32[5,1], bias: float32[1] -> 6 (24 B)
decoder          Decoder   float32[100,1]    float32[100,2]
decoder/Dense_0  Dense     float32[100,1]    float32[100,5]    kernel: float32[1,5], bias: float32[5] -> 10 (40 B)
decoder/Dense_1  Dense     float32[100,5]    float32[100,2]    kernel: float32[5,2], bias: float32[2] -> 12 (48 B)

Total Parameters: 43 (172 B)




params = ae.init(random.PRNGKey(0), X)
params
FrozenDict({
    params: {
        encoder: {
            Dense_0: {
                kernel: DeviceArray([[ 0.17535934, -1.0953957 ,  0.69273657, -0.26352578,
                               0.63077825],
                             [ 0.36360174, -0.73782593, -0.5395247 , -0.41536337,
                              -0.30090812]], dtype=float32),
                bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
            },
            Dense_1: {
                kernel: DeviceArray([[-0.64744544],
                             [ 0.4855265 ],
                             [-0.82133824],
                             [ 0.62454295],
                             [ 0.6013553 ]], dtype=float32),
                bias: DeviceArray([0.], dtype=float32),
            },
        },
        decoder: {
            Dense_0: {
                kernel: DeviceArray([[-0.5305567 ,  1.1100855 , -0.31129056,  0.43152457,
                              -0.09589562]], dtype=float32),
                bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
            },
            Dense_1: {
                kernel: DeviceArray([[-0.76956064,  0.13031492],
                             [ 0.11736098,  0.47368795],
                             [-0.12549445, -0.31066778],
                             [-0.4392067 , -0.9067152 ],
                             [-0.86761785,  0.42325035]], dtype=float32),
                bias: DeviceArray([0., 0.], dtype=float32),
            },
        },
    },
})
X_hat = ae.apply(params, X)
X_hat.shape
(100, 2)
try:
    ae.encoder
except Exception:
    # Submodules defined in setup() are only bound during init/apply,
    # so ae.encoder is not accessible on the bare module here.
    # https://github.com/google/flax/discussions/2602
    pass
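One way to get at the encoder without re-instantiating it is Flax's Module.bind, which returns an interactively bound copy of the module whose setup-defined submodules can be called directly. A minimal sketch (not from the original notebook; double-check against the Flax docs if you rely on it):

# Sketch: bind the parameters to the module, then call submodules directly
bound_ae = ae.bind(params)
z = bound_ae.encoder(X)      # latent codes, shape (100, 1)
x_hat = bound_ae.decoder(z)  # reconstructions, shape (100, 2)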
# Encoded values/latent representation
encoded_1d = Encoder(1).apply({"params": params["params"]["encoder"]}, X).flatten()
encoded_1d
DeviceArray([-2.4718695, -2.1964364, -2.6823573, -2.4936147, -1.7122931,
             -1.8346143, -2.0767107, -1.8570523, -1.7632042, -2.067935 ,
             -2.2317708, -2.14561  , -1.0023856, -2.1458383, -2.3645976,
             -1.9418356, -2.7020268, -1.6407721, -1.8281609, -2.2202983,
             -2.517499 , -2.5888596, -2.0095935, -2.4470625, -2.18571  ,
             -1.9742887, -1.8921608, -2.245328 , -0.8897901, -2.5329056,
             -2.2861118, -1.5862433, -2.2295656, -2.496296 , -2.404385 ,
             -2.0180435, -1.8416756, -1.858724 , -2.0980945, -1.777173 ,
             -2.0027544, -2.1870096, -2.44952  , -1.7563678, -1.5761943,
             -2.3097022, -2.0295165, -2.9528203, -2.2042174, -1.9090188,
             -1.8868417, -2.4206855, -2.143362 , -1.880422 , -2.5127397,
             -2.1454868, -2.0043788, -2.570388 , -2.5082102, -2.3339696,
             -1.8621875, -2.4201612, -2.561397 , -2.0498512, -1.6772006,
             -1.6392376, -2.3855271, -1.8138398, -3.3776197, -2.3745804,
             -2.6683671, -1.8609927, -1.4205931, -1.8123009, -2.236284 ,
             -2.2161927, -2.5204146, -2.0504622, -2.1548996, -1.6896895,
             -1.3192847, -2.2909331, -2.1295016, -2.0703764, -1.9394028,
             -2.041992 , -1.8279521, -1.690125 , -2.7230937, -2.3157165,
             -1.7527001, -2.2544892, -2.6310122, -2.0703619, -2.2476096,
             -1.8941168, -1.5398859, -1.5742403, -2.375471 , -1.9361446],            dtype=float32)
def plot_2d_reconstruction(X, params, model, trained = False):
    X_hat = model.apply(params, X)
    plt.scatter(X[:, 0], X[:, 1], label="Original Data")
    plt.scatter(X_hat[:, 0], X_hat[:, 1], label="Reconstructed Data")
    if trained:
        plt.title("Trained")
    else:
        plt.title("Untrained")
plot_2d_reconstruction(X, params, ae, False)

Define the Loss function

\(\ell_2\) penalty
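Written out, the reconstruction penalty computed below (with the 1/2 convention from Bishop that the optax docstring mentions later) is

\[
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N} \frac{1}{2}\,\lVert x_i - \hat{x}_i \rVert_2^2
\]

Since this dataset has 2 features, this happens to coincide with the per-element mean squared error, which is why the manual, optax and sklearn versions below all agree.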

diff = X - X_hat
diff.shape
(100, 2)
diff[:5]
DeviceArray([[-0.46981597,  5.271835  ],
             [ 1.6502905 ,  3.6781619 ],
             [ 1.8507848 ,  5.0589485 ],
             [ 2.8690844 ,  4.5646677 ],
             [ 0.4905889 ,  2.8893166 ]], dtype=float32)
(diff**2).sum(axis=1).mean() / 2
DeviceArray(7.9555416, dtype=float32)
(diff**2).sum(axis=1)[:5]
DeviceArray([28.01297 , 16.252333, 29.018364, 29.067837,  8.588828], dtype=float32)
(jnp.linalg.norm(diff, ord=2, axis=1) ** 2).mean() / 2
DeviceArray(7.955541, dtype=float32)
from sklearn.metrics import mean_squared_error
mean_squared_error(X, X_hat)
7.9555407
print(2 * optax.l2_loss(X_hat, X).mean())

"""

Multiplying by two
Docstring says:
Calculates the L2 loss for a set of predictions.

Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning"
by Bishop, but not "The Elements of Statistical Learning" by Tibshirani.
"""
7.9555416
'\n\nMultiplying by two\nDocstring says:\nCalculates the L2 loss for a set of predictions.\n\nNote: the 0.5 term is standard in "Pattern Recognition and Machine Learning"\nby Bishop, but not "The Elements of Statistical Learning" by Tibshirani.\n'
@jax.jit
def loss(params, X):
    X_hat = ae.apply(params, X)
    return 2 * optax.l2_loss(X_hat, X).mean()
loss(params, X)
DeviceArray(7.9555416, dtype=float32)

Defining the train function

def train(
    X: jnp.array,
    optimizer: optax.GradientTransformation,
    model: nn.Module,
    key_param: jax.random.PRNGKey,
    n_iter: int=500,
    print_every: int=10
):
    loss_array  = np.zeros(n_iter)
    def loss(params, X):
        X_hat = model.apply(params, X)
        return 2 * optax.l2_loss(X_hat, X).mean()

    params = model.init(key_param, X)
    opt_state = optimizer.init(params)
    loss_grad_fn = jax.value_and_grad(loss)

    for i in range(n_iter):
        loss_val, grads = loss_grad_fn(params, X)
        loss_array[i] = loss_val.item()
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        if i % print_every == 0:
            print("Loss step {}: ".format(i), loss_val)
    return params, loss_array
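For larger models it usually pays to jax.jit the whole update step rather than just the loss. A minimal sketch of how the loop body above could be compiled end to end (an optional optimisation under the same model/optimizer arguments as train, not part of the original notebook):

def make_update_step(model, optimizer):
    @jax.jit
    def update_step(params, opt_state, X):
        def loss_fn(p):
            X_hat = model.apply(p, X)
            return 2 * optax.l2_loss(X_hat, X).mean()

        # One gradient step: loss + grads, optimizer update, apply updates
        loss_val, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_val

    return update_step

# usage: step = make_update_step(ae, optax.adam(0.1)); params, opt_state, loss_val = step(params, opt_state, X)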
optimized_params, loss_array = train(
    X, optax.adam(learning_rate=0.1), ae, jax.random.PRNGKey(0), n_iter=30
)
Loss step 0:  7.9555416
Loss step 10:  1.3104575
Loss step 20:  0.544944
plt.plot(loss_array)
plt.xlabel("Iterations")
_ = plt.ylabel("Reconstruction loss")

plot_2d_reconstruction(X, optimized_params, ae, True)

from sklearn import datasets
digits = datasets.load_digits()
X = jnp.array(digits["data"])
y = digits["target"]
X.shape
(1797, 64)
plt.imshow(X[1].reshape(8, 8), cmap="Greys")
y[1]
1

bn = 2
ae_digits = AE(bn, X.shape[1])

ae_digits
AE(
    # attributes
    bottleneck = 2
    out = 64
)
print(ae_digits.tabulate(random.PRNGKey(0), X))
AE Summary

path             module    inputs             outputs            params
---------------  --------  -----------------  -----------------  -----------------------------------------------------
                 AE        float32[1797,64]   float32[1797,64]
encoder          Encoder   float32[1797,64]   float32[1797,5]
encoder/Dense_0  Dense     float32[1797,64]   float32[1797,5]    kernel: float32[64,5], bias: float32[5] -> 325 (1.3 KB)
encoder/Dense_1  Dense     float32[1797,5]    float32[1797,2]    kernel: float32[5,2], bias: float32[2] -> 12 (48 B)
decoder          Decoder   float32[1797,2]    float32[1797,64]
decoder/Dense_0  Dense     float32[1797,2]    float32[1797,5]    kernel: float32[2,5], bias: float32[5] -> 15 (60 B)
decoder/Dense_1  Dense     float32[1797,5]    float32[1797,64]   kernel: float32[5,64], bias: float32[64] -> 384 (1.5 KB)

Total Parameters: 736 (2.9 KB)




params_digits = ae_digits.init(random.PRNGKey(0), X)
jax.tree_util.tree_map(lambda x: x.shape, params_digits)
FrozenDict({
    params: {
        decoder: {
            Dense_0: {
                bias: (5,),
                kernel: (2, 5),
            },
            Dense_1: {
                bias: (64,),
                kernel: (5, 64),
            },
        },
        encoder: {
            Dense_0: {
                bias: (5,),
                kernel: (64, 5),
            },
            Dense_1: {
                bias: (2,),
                kernel: (5, 2),
            },
        },
    },
})
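As a quick cross-check of the tabulate total, the parameter count can also be computed directly from the pytree (a small sketch, not in the original notebook):

# Sum the sizes of all parameter leaves; should match the tabulate total of 736
sum(p.size for p in jax.tree_util.tree_leaves(params_digits))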
def plot_encoding_2dim(encoder, params):
    assert encoder.bottleneck >= 2
    X_low = encoder.apply({"params": params["params"]["encoder"]}, X)
    df = pd.DataFrame(X_low)
    df["label"] = y
    sns.pairplot(df, hue="label", palette="bright")

Untrained encodings

plot_encoding_2dim(Encoder(bottleneck=bn), params_digits)

X_recon = ae_digits.apply(params_digits, X)
def plot_orig_recon(index=0):
    fig, ax = plt.subplots(sharex=True, ncols=2)
    ax[0].imshow(X[index].reshape(8, 8), cmap="Greys")
    ax[1].imshow(X_recon[index].reshape(8, 8), cmap="Greys")
    ax[0].set_title("Original")
    ax[1].set_title("Reconstructed")
plot_orig_recon(5)

optimized_params_digits, loss_array_digits = train(
    X, optax.adam(learning_rate=0.01), ae_digits, jax.random.PRNGKey(0), n_iter=1000
)
Loss step 0:  90.91908
Loss step 10:  62.609577
Loss step 20:  58.390884
Loss step 30:  53.54514
Loss step 40:  45.062607
Loss step 50:  33.541103
Loss step 60:  25.167671
Loss step 70:  21.107908
Loss step 80:  19.424128
Loss step 90:  18.734087
Loss step 100:  18.47802
Loss step 110:  18.390646
Loss step 120:  18.352455
Loss step 130:  18.333141
Loss step 140:  18.321236
Loss step 150:  18.311743
Loss step 160:  18.3032
Loss step 170:  18.295115
Loss step 180:  18.287226
Loss step 190:  18.279234
Loss step 200:  18.270723
Loss step 210:  18.26098
Loss step 220:  18.2499
Loss step 230:  18.237106
Loss step 240:  18.221647
Loss step 250:  18.20243
Loss step 260:  18.177717
Loss step 270:  18.14539
Loss step 280:  18.105865
Loss step 290:  18.058249
Loss step 300:  18.000141
Loss step 310:  17.931208
Loss step 320:  17.84967
Loss step 330:  17.755304
Loss step 340:  17.65073
Loss step 350:  17.537819
Loss step 360:  17.418528
Loss step 370:  17.293976
Loss step 380:  17.164043
Loss step 390:  17.029558
Loss step 400:  16.89464
Loss step 410:  16.760334
Loss step 420:  16.626553
Loss step 430:  16.493797
Loss step 440:  16.362513
Loss step 450:  16.234201
Loss step 460:  16.11052
Loss step 470:  15.992949
Loss step 480:  15.883502
Loss step 490:  15.783846
Loss step 500:  15.694724
Loss step 510:  15.615571
Loss step 520:  15.54589
Loss step 530:  15.483993
Loss step 540:  15.427973
Loss step 550:  15.376085
Loss step 560:  15.326871
Loss step 570:  15.280196
Loss step 580:  15.23521
Loss step 590:  15.191253
Loss step 600:  15.149132
Loss step 610:  15.109302
Loss step 620:  15.071858
Loss step 630:  15.037474
Loss step 640:  15.005837
Loss step 650:  14.977009
Loss step 660:  14.950782
Loss step 670:  14.927103
Loss step 680:  14.905551
Loss step 690:  14.885867
Loss step 700:  14.867877
Loss step 710:  14.851396
Loss step 720:  14.836317
Loss step 730:  14.8224125
Loss step 740:  14.809575
Loss step 750:  14.797547
Loss step 760:  14.786259
Loss step 770:  14.775562
Loss step 780:  14.76545
Loss step 790:  14.755904
Loss step 800:  14.746771
Loss step 810:  14.738021
Loss step 820:  14.729595
Loss step 830:  14.721415
Loss step 840:  14.713423
Loss step 850:  14.705618
Loss step 860:  14.697898
Loss step 870:  14.690201
Loss step 880:  14.682494
Loss step 890:  14.674812
Loss step 900:  14.667133
Loss step 910:  14.6593275
Loss step 920:  14.651322
Loss step 930:  14.643042
Loss step 940:  14.634569
Loss step 950:  14.625735
Loss step 960:  14.616413
Loss step 970:  14.6066065
Loss step 980:  14.596094
Loss step 990:  14.58464
plt.plot(loss_array_digits)

Trained encodings

plot_encoding_2dim(Encoder(bottleneck=bn), optimized_params_digits)

Reconstruction

X_recon = ae_digits.apply(optimized_params_digits, X)
plot_orig_recon(4)

X_reconstructed = ae_digits.apply(optimized_params_digits, X)  # reconstruct digits with the trained AE
errs = jnp.square(X - X_reconstructed).sum(axis=1)
err_df = pd.DataFrame({"error": errs, "label": y})
err_df.groupby("label").mean()
             error
label
0      1067.159668
1      1253.397217
2      1187.446655
3       730.839417
4       919.732239
5      1103.442505
6       913.172607
7      1309.424438
8       892.981750
9       891.891907

Convolutional AE

class ConvEncoder(nn.Module):
    bottleneck: int

    @nn.compact
    def __call__(self, x):
        n = x.shape[0]  # x is nx64
        x = x.reshape(n, 8, 8, 1)
        x = nn.Conv(features=4, kernel_size=(2, 2), strides=1, padding=0)(
            x
        )  # 8x8x1 -> 7x7x4
        x = nn.selu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))  # 7x7x4 -> 3x3x4
        x = nn.selu(x)
        x = x.reshape(n, -1)  # N x 3x3x4 -> N x 36
        x = nn.Dense(self.bottleneck)(x)
        return x
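To sanity-check the shape comments above: a 2x2 convolution with no padding maps 8x8 to 7x7, the 2x2 max-pool with stride 2 then floors 7/2 to 3, so flattening gives 3*3*4 = 36 features, which matches the float32[1797,36] input to Dense_0 in the summary below.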
ce = ConvEncoder(2)
#print(ce.tabulate(random.PRNGKey(0), X))
print(ce.tabulate(random.PRNGKey(0), X, console_kwargs={"width": 120}))
ConvEncoder Summary

path      module        inputs                 outputs                params
--------  ------------  ---------------------  ---------------------  --------------------------------------------------------
          ConvEncoder   float32[1797,64]       float32[1797,2]
Conv_0    Conv          float32[1797,8,8,1]    float32[1797,7,7,4]    kernel: float32[2,2,1,4], bias: float32[4] -> 20 (80 B)
Dense_0   Dense         float32[1797,36]       float32[1797,2]        kernel: float32[36,2], bias: float32[2] -> 74 (296 B)

Total Parameters: 94 (376 B)




class ConvDecoder(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(36)(x)  # Nx2 --> Nx36
        x = nn.selu(x)
        x = x.reshape(-1, 3, 3, 4)  # NX3X3X4
        x = nn.ConvTranspose(features=4, kernel_size=(2, 2), strides=(2, 2))(
            x
        )  # 3x3x4 -> 6x6X4
        x = nn.selu(x)
        x = nn.Conv(features=1, kernel_size=(1, 1), strides=1, padding=1)(
            x
        )  # 6x6x4 -> 8x8x1
        x = x.reshape(-1, 64)
        return x
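Going back up: the Dense layer maps the bottleneck to 36 features, which are reshaped to 3x3x4; the stride-2 transposed convolution doubles this to 6x6x4, and the final 1x1 convolution with padding 1 pads 6x6 out to 8x8x1, i.e. the 64 pixels of a digit.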
cd = ConvDecoder()
print(
    cd.tabulate(
        random.PRNGKey(0),
        jax.random.normal(key=jax.random.PRNGKey(0), shape=(1797, 2)),
        console_kwargs={"width": 120},
    )
)
ConvDecoder Summary

path             module          inputs                 outputs                params
---------------  --------------  ---------------------  ---------------------  --------------------------------------------------------
                 ConvDecoder     float32[1797,2]        float32[1797,64]
Dense_0          Dense           float32[1797,2]        float32[1797,36]       kernel: float32[2,36], bias: float32[36] -> 108 (432 B)
ConvTranspose_0  ConvTranspose   float32[1797,3,3,4]    float32[1797,6,6,4]    kernel: float32[2,2,4,4], bias: float32[4] -> 68 (272 B)
Conv_0           Conv            float32[1797,6,6,4]    float32[1797,8,8,1]    kernel: float32[1,1,4,1], bias: float32[1] -> 5 (20 B)

Total Parameters: 181 (724 B)




class ConvAE(nn.Module):
    bottleneck: int

    def setup(self):
        # Alternative to @nn.compact -> explicitly define modules
        # Better for later when we want to access the encoder and decoder explicitly
        self.encoder = ConvEncoder(bottleneck=self.bottleneck)
        self.decoder = ConvDecoder()

    def __call__(self, x):

        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat
cae = ConvAE(2)
print(
    cae.tabulate(
        random.PRNGKey(0),
        X,
        console_kwargs={"width": 120},
    )
)
ConvAE Summary

path                     module          inputs                 outputs                params
-----------------------  --------------  ---------------------  ---------------------  --------------------------------------------------------
                         ConvAE          float32[1797,64]       float32[1797,64]
encoder                  ConvEncoder     float32[1797,64]       float32[1797,2]
encoder/Conv_0           Conv            float32[1797,8,8,1]    float32[1797,7,7,4]    kernel: float32[2,2,1,4], bias: float32[4] -> 20 (80 B)
encoder/Dense_0          Dense           float32[1797,36]       float32[1797,2]        kernel: float32[36,2], bias: float32[2] -> 74 (296 B)
decoder                  ConvDecoder     float32[1797,2]        float32[1797,64]
decoder/Dense_0          Dense           float32[1797,2]        float32[1797,36]       kernel: float32[2,36], bias: float32[36] -> 108 (432 B)
decoder/ConvTranspose_0  ConvTranspose   float32[1797,3,3,4]    float32[1797,6,6,4]    kernel: float32[2,2,4,4], bias: float32[4] -> 68 (272 B)
decoder/Conv_0           Conv            float32[1797,6,6,4]    float32[1797,8,8,1]    kernel: float32[1,1,4,1], bias: float32[1] -> 5 (20 B)

Total Parameters: 275 (1.1 KB)




params = cae.init(random.PRNGKey(0), X)
plot_encoding_2dim(ConvEncoder(bottleneck=2), params)

optimized_params_digits_cae, loss_array_digits_cae = train(
    X, optax.adam(learning_rate=0.01), cae, jax.random.PRNGKey(0), n_iter=1000, print_every=50
)
Loss step 0:  61.916904
Loss step 50:  30.379993
Loss step 100:  27.855324
Loss step 150:  26.851124
Loss step 200:  25.77603
Loss step 250:  25.184359
Loss step 300:  24.772747
Loss step 350:  24.351847
Loss step 400:  24.091908
Loss step 450:  23.887573
Loss step 500:  23.72832
Loss step 550:  23.607725
Loss step 600:  23.514961
Loss step 650:  23.419945
Loss step 700:  23.363184
Loss step 750:  23.30127
Loss step 800:  23.258532
Loss step 850:  23.206999
Loss step 900:  23.162285
Loss step 950:  23.13027
plot_encoding_2dim(ConvEncoder(bottleneck=2), optimized_params_digits_cae)

BayesOpt for optimizing the latent dimension

def black_box_function(x, y):
    """Function with unknown internals we wish to maximize.

    This is just serving as an example, for all intents and
    purposes think of the internals of this function, i.e.: the process
    which generates its output values, as unknown.
    """
    x = int(x)
    y = int(y)
    return function_discrete(x, y)
def function_discrete(x, y):
    assert type(x) == int
    return -(x**2) - (y - 1) ** 2 + 1
pbounds = {"x": (2, 4), "y": (-3, 3)}
optimizer = BayesianOptimization(
    f=black_box_function,
    pbounds=pbounds,
    verbose=2,  # verbose = 1 prints only when a maximum is observed, verbose = 0 is silent
    random_state=1,
)
optimizer.maximize()
|   iter    |  target   |     x     |     y     |
-------------------------------------------------
| 1         | -3.0      | 2.834     | 1.322     |
| 2         | -7.0      | 2.0       | -1.186    |
| 3         | -12.0     | 2.294     | -2.446    |
| 4         | -4.0      | 2.373     | -0.9266   |
| 5         | -4.0      | 2.794     | 0.2329    |
| 6         | -15.0     | 4.0       | 1.331     |
| 7         | -4.0      | 2.348     | 0.8879    |
| 8         | -3.0      | 2.797     | 1.257     |
| 9         | -4.0      | 2.064     | 2.229     |
| 10        | -9.0      | 3.657     | -0.9428   |
| 11        | -7.0      | 2.901     | 3.0       |
| 12        | -4.0      | 2.0       | -0.1486   |
| 13        | -31.0     | 4.0       | -3.0      |
| 14        | -7.0      | 2.0       | 3.0       |
| 15        | -3.0      | 2.0       | 1.539     |
| 16        | -3.0      | 2.512     | 1.792     |
| 17        | -19.0     | 4.0       | 3.0       |
| 18        | -4.0      | 2.831     | -0.4655   |
| 19        | -4.0      | 2.402     | -0.3286   |
| 20        | -9.0      | 3.539     | 0.08748   |
| 21        | -7.0      | 2.841     | -1.217    |
| 22        | -4.0      | 2.764     | 2.245     |
| 23        | -4.0      | 2.0       | 0.4436    |
| 24        | -3.0      | 2.469     | 1.423     |
| 25        | -3.0      | 2.0       | 1.16      |
| 26        | -3.0      | 2.787     | 1.714     |
| 27        | -4.0      | 2.932     | 0.7853    |
| 28        | -3.0      | 2.647     | 1.526     |
| 29        | -3.0      | 2.148     | 1.373     |
| 30        | -3.0      | 2.212     | 1.795     |
=================================================
optimizer.max
{'target': -3.0, 'params': {'x': 2.8340440094051482, 'y': 1.3219469606529488}}
{k: int(v) for k, v in optimizer.max["params"].items()}
{'x': 2, 'y': 1}
function_discrete(2, 1)
-3
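This is indeed the best the discretised objective can do on the chosen box: with \(x \in \{2, 3, 4\}\) and \(y \in \{-3, \dots, 3\}\) after truncation, \(-x^2 - (y-1)^2 + 1\) is maximised at \(x = 2, y = 1\), giving \(-4 - 0 + 1 = -3\).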

Let us keep a separate validation set

def loss_model(params, X, model):
    X_hat = model.apply(params, X)
    diff = X - X_hat
    return (diff**2).sum(axis=1).mean() / X.shape[1]
from functools import partial

e = partial(loss_model, model=cae)
e(params, X)
DeviceArray(61.916904, dtype=float32)
def validation_loss_discrete(bn):
    assert type(bn) == int

    # Train the model on bn sized bottleneck
    cae = ConvAE(bn)
    loss_fn_concrete = jax.jit(partial(loss_model, model=cae))
    loss_grad_fn = jax.value_and_grad(loss_fn_concrete)
    tx = optax.adam(learning_rate=1e-2)
    params = cae.init(random.PRNGKey(0), X_train)
    opt_state = tx.init(params)
    print(f"--------Bottleneck of Size: {bn}-------------")
    for i in range(30):
        loss_val, grads = loss_grad_fn(params, X_train)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        if i % 5 == 0:
            print("Loss step {}: ".format(i), loss_val)
    print(f"--------End-------------")

    # Evaluate on validation dataset
    return loss_fn_concrete(params, X_validation)
X_train, X_validation = X[:1000], X[1000:]
validation_loss_discrete(2)
--------Bottleneck of Size: 2-------------
Loss step 0:  62.27715
Loss step 5:  58.5037
Loss step 10:  53.984245
Loss step 15:  49.513382
Loss step 20:  43.078316
Loss step 25:  38.30596
--------End-------------
DeviceArray(36.75615, dtype=float32)
def validation_loss_bb(bn):
    bn_int = int(bn)
    return -validation_loss_discrete(bn_int)
validation_loss_bb(2.5)
--------Bottleneck of Size: 2-------------
Loss step 0:  62.27715
Loss step 5:  58.5037
Loss step 10:  53.984245
Loss step 15:  49.513382
Loss step 20:  43.078316
Loss step 25:  38.30596
--------End-------------
DeviceArray(-36.75615, dtype=float32)
pbounds = {"bn": (1, 40)}
optimizer = BayesianOptimization(
    f=validation_loss_bb,
    pbounds=pbounds,
    verbose=2,  # verbose = 1 prints only when a maximum is observed, verbose = 0 is silent
    random_state=1,
)
optimizer.maximize(n_iter=8)
|   iter    |  target   |    bn     |
-------------------------------------
--------Bottleneck of Size: 17-------------
Loss at steps 0/5/10/15/20/25:  62.85297  52.85449  40.903214  35.32036  35.3193  33.33418
--------End-------------
| 1         | -32.36    | 17.26     |
--------Bottleneck of Size: 29-------------
Loss at steps 0/5/10/15/20/25:  64.064514  53.85875  47.26749  43.828564  41.847286  39.23966
--------End-------------
| 2         | -37.29    | 29.09     |
--------Bottleneck of Size: 1-------------
Loss at steps 0/5/10/15/20/25:  60.969757  58.92785  53.683678  49.58035  45.86102  44.17104
--------End-------------
| 3         | -42.48    | 1.004     |
--------Bottleneck of Size: 12-------------
Loss at steps 0/5/10/15/20/25:  63.704227  57.338806  49.537926  41.210827  38.469257  35.276833
--------End-------------
| 4         | -34.07    | 12.79     |
--------Bottleneck of Size: 6-------------
Loss at steps 0/5/10/15/20/25:  61.450924  55.82548  47.88899  40.131763  37.62544  35.873016
--------End-------------
| 5         | -34.2     | 6.723     |
--------Bottleneck of Size: 20-------------
Loss at steps 0/5/10/15/20/25:  61.81845  56.358246  51.92751  47.312576  42.146885  37.025486
--------End-------------
| 6         | -33.86    | 20.39     |
--------Bottleneck of Size: 40-------------
Loss at steps 0/5/10/15/20/25:  61.5667  49.598972  42.639145  39.22532  36.597954  34.528015
--------End-------------
| 7         | -32.67    | 40.0      |
--------Bottleneck of Size: 36-------------
Loss at steps 0/5/10/15/20/25:  62.303535  52.075367  44.435425  40.889286  39.280178  37.09512
--------End-------------
| 8         | -35.77    | 36.05     |
--------Bottleneck of Size: 9-------------
Loss at steps 0/5/10/15/20/25:  63.35566  52.45499  43.281902  37.028984  35.006325  33.583298
--------End-------------
| 9         | -33.01    | 9.596     |
--------Bottleneck of Size: 24-------------
Loss at steps 0/5/10/15/20/25:  62.888515  52.035835  42.154068  36.804348  34.53549  32.37921
--------End-------------
| 10        | -30.08    | 24.26     |
--------Bottleneck of Size: 25-------------
Loss at steps 0/5/10/15/20/25:  63.406757  50.291225  41.73214  38.421593  37.0491  34.847046
--------End-------------
| 11        | -33.89    | 25.81     |
--------Bottleneck of Size: 22-------------
Loss at steps 0/5/10/15/20/25:  62.303898  53.713398  47.806355  43.550034  42.033653  39.68766
--------End-------------
| 12        | -38.51    | 22.8      |
--------Bottleneck of Size: 24-------------
Loss at steps 0/5/10/15/20/25:  62.888515  52.035835  42.154068  36.804348  34.53549  32.37921
--------End-------------
| 13        | -30.08    | 24.3      |
=====================================
optimizer.max
{'target': -30.082199096679688, 'params': {'bn': 24.25939633195359}}

VAE
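
Up to constants and scaling, the loss implemented in loss_vae below is a beta-weighted VAE objective: a reconstruction term plus a KL term pulling the approximate posterior towards a standard normal prior,

\[
\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(z \mid x)}\big[\lVert x - \hat{x} \rVert^2\big] + \beta \, \mathrm{KL}\big(q_\phi(z \mid x) \,\Vert\, \mathcal{N}(0, I)\big)
\]

with a small \(\beta = 0.002\) in the code (my summary of the code below, not a formula taken from the original text).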

class VAE_Encoder(nn.Module):
    bottleneck: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(5)(x)
        x = nn.selu(x)
        mu = nn.Dense(features=self.bottleneck)(x)
        log_std = nn.Dense(features=self.bottleneck)(x)
        return mu, log_std
def reparameterize(mu, log_std, key=random.PRNGKey(0), samples=1):
    std = jnp.exp(log_std)
    eps = random.normal(key=key, shape=(samples,))
    return mu + eps * std
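This is the reparameterization trick: instead of sampling \(z \sim \mathcal{N}(\mu, \sigma^2)\) directly, sample \(\epsilon \sim \mathcal{N}(0, 1)\) and set \(z = \mu + \sigma\,\epsilon\) with \(\sigma = e^{\log\sigma}\), so gradients can flow through \(\mu\) and \(\log\sigma\).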
samples = reparameterize(2, jnp.log(1), samples=5000)
sns.kdeplot(samples)
plt.title(f"Mean:{jnp.mean(samples):0.2f}, stddev: {jnp.std(samples):0.2f}")
Text(0.5, 1.0, 'Mean:2.00, stddev: 1.00')

class VAE(nn.Module):
    bottleneck: int
    out: int

    def setup(self):
        # Alternative to @nn.compact -> explicitly define modules
        # Better for later when we want to access the encoder and decoder explicitly
        self.encoder = VAE_Encoder(bottleneck=self.bottleneck)
        self.decoder = Decoder(out=self.out)

    def __call__(self, x, rng=random.PRNGKey(0)):
        mu, log_std = self.encoder(x)
        z = reparameterize(mu, log_std, key=rng)
        x_hat = self.decoder(z)
        return x_hat, mu, log_std
vae = VAE(bottleneck=2, out=64)
params = vae.init(random.PRNGKey(10), X)
plt.imshow(vae.apply(params, X)[0][0].reshape(8, 8))

vae.apply(params, X, random.PRNGKey(10))[0][0].reshape(8, 8)
DeviceArray([[ -3999.399   ,   6091.6396  ,  -2634.2932  ,    307.47302 ,
                3932.0298  ,   1823.3352  ,   3852.157   ,   5576.5605  ],
             [ -8809.304   ,   5299.91    ,    286.5227  ,   1059.3925  ,
                -951.62537 ,  -6623.4824  ,  -1463.6239  ,  16223.624   ],
             [ -5279.1323  ,  -7333.815   ,    -71.1485  ,   5679.2773  ,
                1384.2794  ,   8326.92    ,  -1747.943   ,  -4802.341   ],
             [   403.3739  ,  13455.688   ,  -7414.195   ,   7299.713   ,
                1180.7408  ,   -328.49432 ,   6619.1357  ,    363.74713 ],
             [ -4376.3506  ,  -2045.3063  ,   2618.412   , -10890.402   ,
               -3035.3848  ,  -3574.7527  ,  -5057.2593  ,  -1859.8529  ],
             [   -53.99241 ,   2318.109   ,  -1323.9087  ,  -6801.4814  ,
               -7300.1553  ,    865.4169  ,  13349.937   ,    865.3773  ],
             [    37.275284,  -3962.8357  ,   1771.9886  ,  -7992.7188  ,
                4896.562   , -17371.383   ,   4737.3887  ,   7307.3384  ],
             [  -221.0234  ,  -5475.8447  ,   4189.172   ,  -1095.9471  ,
               -6452.915   ,   3767.8381  , -10514.758   ,  -2311.0862  ]],            dtype=float32)
vae_e = VAE_Encoder(2)
mu, log_sigma = vae_e.apply({"params": params["params"]["encoder"]}, X)
tfd = tfp.distributions
# q is the approximate posterior from the encoder, p the standard normal prior
q = tfd.Normal(loc=mu, scale=jnp.exp(log_sigma))
p = tfd.Normal(loc=0.0, scale=1.0)
tfd.kl_divergence(q, p).shape
tfd.kl_divergence(q, p).mean()
q.stddev()

Loss

@jax.jit
def loss_vae(params, X, rng=random.PRNGKey(0)):
    X_hat, mu, log_sigma = vae.apply(params, X, rng)
    q = tfd.Normal(loc=mu, scale=jnp.exp(log_sigma))
    p = tfd.Normal(loc=0.0, scale=1.0)
    kl_loss = tfd.kl_divergence(q, p).mean()

    diff = X - X_hat
    recon_loss = (diff**2).sum(axis=1).mean() / X.shape[1]

    return recon_loss + 0.0020 * kl_loss
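For reference, the KL term has a closed form for Gaussians, which is what tfd.kl_divergence evaluates here (per latent dimension, then averaged):

\[
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\Vert\,\mathcal{N}(0, 1)\big) = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - 1\right) - \log \sigma
\]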
loss_vae(params, X, random.PRNGKey(4))
import optax

learning_rate = 0.01
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(loss_vae)
rng = random.PRNGKey(0)
for i in range(2001):
    rng, key = random.split(rng)
    loss_val, grads = loss_grad_fn(params, X, rng)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 50 == 0:
        print("Loss step {}: ".format(i), loss_val)
X_recon, _, _ = vae.apply(params, X)
plot_orig_recon(8)
dec = Decoder(out=64)
N = 10
x_range = jnp.linspace(-2, 2, N)
fig, ax = plt.subplots(ncols=N, sharey=True, figsize=(20, 4))
for i in range(N):
    ax[i].imshow(
        dec.apply(
            {"params": params["params"]["decoder"]}, jnp.array([x_range[i], 0.0])
        ).reshape(8, 8),
        cmap="Greys",
    )
def plot_encoding_2dim_vae(encoder, params):
    assert encoder.bottleneck >= 2
    mu, log_sigma = encoder.apply({"params": params["params"]["encoder"]}, X)
    df = pd.DataFrame(mu)
    df["label"] = y
    sns.pairplot(df, hue="label", palette="bright")
vae_enc = VAE_Encoder(2)
mu, log_sigma = vae_enc.apply({"params": params["params"]["encoder"]}, X)
# plot_encoding_2dim_vae(VAE_Encoder(2), params)
plot_encoding_2dim_vae(vae_enc, params)

TODO

  • regular AE: Bayesopt for latent dimension
  • generation from regular AE
  • graph of reconstruction loss v/s latent dimension for regular AE
  • GIF for walking in latent space for VAE
  • Reconstruction as a factor of Recon + Beta X KL
  • Get the Encoder from AE object directly
  • Impact of MC samples
  • Reconstruction v/s Expected Log Likelihood (confirm the trend is same for both)
  • Clean up code so that it can be reused rather than copy-pasted
  • Sparse VAE
  • Add references
  • Add bib entry
  • Consider CNNs for more realistic datasets
  1. https://lilianweng.github.io/posts/2018-08-12-vae/
  2. https://theaisummer.com/jax-tensorflow-pytorch/
  3. https://dmol.pub/dl/VAE.html