Basic Imports

from silence_tensorflow import silence_tensorflow
silence_tensorflow()

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import functools
import seaborn as sns
import tensorflow_probability as tfp
import pandas as pd

tfd = tfp.distributions
tfl = tfp.layers
tfb = tfp.bijectors

sns.reset_defaults()
sns.set_context(context="talk", font_scale=1)
%matplotlib inline
%config InlineBackend.figure_format='retina'
np.random.seed(0)
tf.random.set_seed(0)
def lr(x, stddv_datapoints):
    num_datapoints, data_dim = x.shape
    b = yield tfd.Normal(
        loc=0.0,
        scale=2.0,
        name="b",
    )
    w = yield tfd.Normal(
        loc=tf.zeros([data_dim]), scale=2.0 * tf.ones([data_dim]), name="w"
    )

    y = yield tfd.Normal(
        loc=tf.linalg.matvec(x, w) + b, scale=stddv_datapoints, name="y"
    )
x = tf.linspace(-5.0, 5.0, 100)
x = tf.expand_dims(x, 1)
stddv_datapoints = 1

concrete_lr_model = functools.partial(lr, x=x, stddv_datapoints=stddv_datapoints)

model = tfd.JointDistributionCoroutineAutoBatched(concrete_lr_model)
model
<tfp.distributions.JointDistributionCoroutineAutoBatched 'JointDistributionCoroutineAutoBatched' batch_shape=[] event_shape=StructTuple(
  b=[],
  w=[1],
  y=[100]
) dtype=StructTuple(
  b=float32,
  w=float32,
  y=float32
)>
actual_b, actual_w, y_train = model.sample()
plt.scatter(x, y_train, s=30, alpha=0.6)
plt.plot(x, tf.linalg.matvec(x, actual_w) + actual_b, color="k")
sns.despine()
trace_fn = lambda traceable_quantities: {
    "loss": traceable_quantities.loss,
    "w": w,
    "b": b,
}
data_dim = 1
w = tf.Variable(tf.zeros_like(actual_w))

b = tf.Variable(tf.zeros_like(actual_b))

target_log_prob_fn = lambda w, b: model.log_prob((b, w, y_train))
target_log_prob_fn
<function __main__.<lambda>(w, b)>
trace = tfp.math.minimize(
    lambda: -target_log_prob_fn(w, b),
    optimizer=tf.optimizers.Adam(learning_rate=0.05),
    trace_fn=trace_fn,
    num_steps=200,
)
w, b, actual_w, actual_b
(<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([2.1811483], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0149531>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.1337605], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=3.0221252>)
plt.plot(trace["w"], label="w")
plt.plot(trace["b"], label="b")
plt.legend()
sns.despine()
qw_mean = tf.Variable(tf.random.normal([data_dim]))
qb_mean = tf.Variable(tf.random.normal([1]))
qw_stddv = tfp.util.TransformedVariable(
    1e-4 * tf.ones([data_dim]), bijector=tfb.Softplus()
)
qb_stddv = tfp.util.TransformedVariable(1e-4 * tf.ones([1]), bijector=tfb.Softplus())
def factored_normal_variational_model():
    qw = yield tfd.Normal(loc=qw_mean, scale=qw_stddv, name="qw")
    qb = yield tfd.Normal(loc=qb_mean, scale=qb_stddv, name="qb")


surrogate_posterior = tfd.JointDistributionCoroutineAutoBatched(
    factored_normal_variational_model
)

losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.05),
    num_steps=200,
)
/Users/nipun/miniforge3/lib/python3.9/site-packages/tensorflow_probability/python/internal/vectorization_util.py:87: UserWarning: Saw Tensor seed Tensor("seed:0", shape=(2,), dtype=int32), implying stateless sampling. Autovectorized functions that use stateless sampling may be quite slow because the current implementation falls back to an explicit loop. This will be fixed in the future. For now, you will likely see better performance from stateful sampling, which you can invoke by passing a Python `int` seed.
  warnings.warn(
/Users/nipun/miniforge3/lib/python3.9/site-packages/tensorflow_probability/python/internal/vectorization_util.py:87: UserWarning: Saw Tensor seed Tensor("seed:0", shape=(2,), dtype=int32), implying stateless sampling. Autovectorized functions that use stateless sampling may be quite slow because the current implementation falls back to an explicit loop. This will be fixed in the future. For now, you will likely see better performance from stateful sampling, which you can invoke by passing a Python `int` seed.
  warnings.warn(
qw_mean, qw_stddv, qb_mean, qb_stddv
(<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([2.1905935], dtype=float32)>,
 <TransformedVariable: name=softplus, dtype=float32, shape=[1], fn="softplus", numpy=array([0.04352505], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([3.0260112], dtype=float32)>,
 <TransformedVariable: name=softplus, dtype=float32, shape=[1], fn="softplus", numpy=array([0.09258726], dtype=float32)>)
s_qw, s_qb = surrogate_posterior.sample(500)
ys = tf.linalg.matvec(x, s_qw) + s_qb
x.shape, ys.shape
(TensorShape([100, 1]), TensorShape([500, 100]))
plt.plot(x, ys.numpy().T, color='k', alpha=0.05);
plt.scatter(x, y_train, s=30, alpha=0.6)
sns.despine()

TODO

  1. How to replace x in lr function from x_train to x_test?