import numpy as np
from scipy.stats import multivariate_normal
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib
#matplotlib.rcParams['figure.figsize'] = (8,6)
from matplotlib import pyplot as plt
import GPy
In this notebook, we cover multi-output GPs. The presentation follows the excellent video lecture from GPSS (the Gaussian Process Summer School).
## ICM
\(u \sim GP(0, k)\)

We sample from \(u\) to get a sample \(u^1\) and construct both outputs from it:
\(f_1(x) = a^1_1 u^1(x)\)
\(f_2(x) = a^1_2 u^1(x)\)
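Since both outputs are scaled copies of the same latent sample, \(\text{cov}(f_i(x), f_j(x')) = a^1_i a^1_j \, k(x, x')\). Stacking the two outputs, the joint covariance is therefore \(B \otimes K\), where \(B = a a^\top\) is the (here rank-1) coregionalization matrix and \(K\) is the Gram matrix of \(k\). This intrinsic coregionalization model (ICM) structure is exactly what `np.kron(B, cov)` computes below.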
X = np.linspace(-3., 3., 50)
kernel = GPy.kern.RBF(input_dim=1, variance=1., lengthscale=2.)
kernel
  rbf.         |  value  |  constraints  |  priors
  variance     |    1.0  |      +ve      |
  lengthscale  |    2.0  |      +ve      |
def jitter(C, j=1e-6):
    # add a small value to the diagonal for numerical stability
    return C + np.eye(len(C))*j
cov = jitter(kernel.K(X.reshape(-1, 1)))
plt.imshow(cov)
mvn = multivariate_normal(cov=cov)
u1 = mvn.rvs(random_state=0)
plt.plot(X, u1)
a11 = 0.9
a12 = 0.7

a = np.array([a11, a12]).reshape(-1, 1)
a
array([[0.9],
[0.7]])
B = a@a.T
B
array([[0.81, 0.63],
[0.63, 0.49]])
cov_f = np.kron(B, cov)

plt.imshow(cov_f, cmap='Purples')
plt.colorbar()

f_sample = multivariate_normal(cov=jitter(cov_f)).rvs(size=500)
f1_samples, f2_samples = f_sample[:, :50], f_sample[:, 50:]
#plt.plot(X, u1, label="u1")
for i in range(2):
    plt.plot(X, f1_samples[i], color='r')
    plt.plot(X, f2_samples[i], color='g')

f1_samples[i]/f2_samples[i]
array([1.28521323, 1.2870487 , 1.28169798, 1.29387391, 1.28381124,
1.29063798, 1.28399272, 1.28787108, 1.27634933, 1.29367057,
1.19405718, 0.81421541, 1.29366628, 1.23932848, 1.28601429,
1.31178054, 1.27596873, 1.28139033, 1.28548127, 1.28874727,
1.288544 , 1.28851575, 1.27706874, 1.28929381, 1.27167387,
1.30216154, 1.28769528, 1.28397652, 1.2896767 , 1.29357874,
1.28743778, 1.28867757, 1.29135504, 1.28085954, 1.27832016,
1.29113682, 1.28346876, 1.28115477, 1.28579679, 1.28664088,
1.2836771 , 1.28690568, 1.28521466, 1.28474094, 1.28147929,
1.28752966, 1.28577663, 1.28154063, 1.28312776, 1.2869964 ])
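The ratio sits essentially constant at \(a^1_1 / a^1_2 = 0.9/0.7 \approx 1.286\): with a single latent function, the ICM outputs are scaled copies of each other. The small deviations come from the jitter term, which makes the joint covariance full-rank rather than exactly rank-deficient.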
## Learning in MOGP setting
f1_dataset = f1_samples[4]
f2_dataset = f2_samples[4]

plt.plot(X, f1_dataset, label='f1')
plt.plot(X, f2_dataset, label='f2')
plt.legend()
## What we want to learn:
# 1. GP kernel parameters
# 2. a11, a12
import jax
import jax.numpy as jnp
from jax.config import config
"jax_enable_x64", True)
config.update(import tensorflow_probability.substrates.jax as tfp
= jnp.hstack([f1_dataset, f2_dataset]) f
def sqexp(a, b, var=1.0, ls=4):
    # squared-exponential (RBF) kernel between two inputs
    diff = (a-b)/ls
    d = jnp.sum(diff ** 2)
    return var*jnp.exp(-0.5 * d)

def all_pairs(f):
    # vmap over both arguments so f is evaluated on every pair, giving the full Gram matrix
    f = jax.vmap(f, in_axes=(None, 0, None, None))
    f = jax.vmap(f, in_axes=(0, None, None, None))
    return f
kernel.K(X.reshape(-1, 1))
array([[1. , 0.99812754, 0.99253116, ..., 0.01592046, 0.01332383,
0.011109 ],
[0.99812754, 1. , 0.99812754, ..., 0.01895197, 0.01592046,
0.01332383],
[0.99253116, 0.99812754, 1. , ..., 0.02247631, 0.01895197,
0.01592046],
...,
[0.01592046, 0.01895197, 0.02247631, ..., 1. , 0.99812754,
0.99253116],
[0.01332383, 0.01592046, 0.01895197, ..., 0.99812754, 1. ,
0.99812754],
[0.011109 , 0.01332383, 0.01592046, ..., 0.99253116, 0.99812754,
1. ]])
np.allclose(np.array(all_pairs(sqexp)(X, X, 1.0, 2.0)), kernel.K(X.reshape(-1, 1)))
True
rank = 1
output_dim = 2
A = jax.random.normal(key=jax.random.PRNGKey(0), shape=(output_dim, rank))/10.0
A@A.T, A
(DeviceArray([[ 0.03298171, -0.01370936],
[-0.01370936, 0.00569851]], dtype=float64),
DeviceArray([[ 0.18160867],
[-0.07548848]], dtype=float64))
output_dim = 2
rank = 4
A = jax.random.normal(key=jax.random.PRNGKey(0), shape=(output_dim, rank))/2.0
A@A.T
DeviceArray([[ 1.24957827, -0.04698574],
[-0.04698574, 0.57577417]], dtype=float64)
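Parameterising the coregionalization matrix as \(B = A A^\top\) guarantees it is positive semi-definite for any real \(A\), so \(A\) can be optimised by plain, unconstrained gradient descent; the number of columns (`rank`) bounds the rank of \(B\).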
def covariance_f(var, ls, A):
    """
    A: (output_dim, rank)
    A can be generated as:
    A = jax.random.normal(key=jax.random.PRNGKey(0), shape=(output_dim, rank))
    """
    B = A@A.T
    cov = all_pairs(sqexp)(X, X, var, ls)
    cov_f = jitter(jnp.kron(B, cov))
    return cov_f
def cost(var, ls, A):
    # negative log-likelihood of the stacked data f under the MOGP prior
    cov_f = covariance_f(var, ls, A)
    dist = tfp.distributions.MultivariateNormalFullCovariance(loc=jnp.zeros_like(f), covariance_matrix=cov_f)
    return -dist.log_prob(f)
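As a sanity check, the same negative log-likelihood can be computed directly via a Cholesky factorisation, the numerically standard route. A minimal sketch (the name `cost_chol` is ours); it should agree with `cost` up to floating-point error:

from jax.scipy.linalg import cho_factor, cho_solve

def cost_chol(var, ls, A):
    cov_f = covariance_f(var, ls, A)
    L_chol, lower = cho_factor(cov_f, lower=True)   # cov_f = L_chol @ L_chol.T
    alpha = cho_solve((L_chol, lower), f)           # solves cov_f @ alpha = f
    n = f.shape[0]
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L_chol)))
    # -log N(f | 0, cov_f)
    return 0.5 * (f @ alpha + log_det + n * jnp.log(2.0 * jnp.pi))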
plt.imshow(covariance_f(1.0, 2.0, A), cmap='Purples')
plt.colorbar()

cost(1.0, 2.0, A)
DeviceArray(-431.60947116, dtype=float64)
cost(1.0, 1.0, A)
DeviceArray(-387.35267033, dtype=float64)
grads = jax.grad(cost, argnums=[0, 1, 2])(0.1, 1.0, A)
grads

var = 0.1
ls = 1.0
lr = 1e-3
for i in range(500):
    grads = jax.grad(cost, argnums=[0, 1, 2])(var, ls, A)
    var = var - lr*grads[0]
    ls = ls - lr*grads[1]
    A = A - lr*grads[2]
    if i%100==0:
        # note: this prints the cost at fixed var=1.0, ls=1.0, not at the current parameters
        print(i, cost(1.0, 1.0, A), var, ls)
0 -387.06097276826193 0.500429427376359 1.0913929924306696
100 -306.72979544101435 3.6414838350262055 2.363476650308803
200 -305.64842462218047 3.514293617054404 2.3873529546968477
300 -304.7976816183849 3.379382170959892 2.403204858135416
400 -304.0941499412901 3.236859846397818 2.4140771572105426
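The manual update loop above is deliberately bare-bones. The same fit can be run with an off-the-shelf optimiser; a sketch using optax's Adam (assuming optax is installed; `cost_params` and the `params` dict are our names):

import optax

def cost_params(p):
    return cost(p["var"], p["ls"], p["A"])

params = {"var": jnp.array(0.1), "ls": jnp.array(1.0), "A": A}
opt = optax.adam(learning_rate=1e-2)
opt_state = opt.init(params)
for i in range(500):
    grads_p = jax.grad(cost_params)(params)
    updates, opt_state = opt.update(grads_p, opt_state)
    params = optax.apply_updates(params, updates)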
C_learnt = covariance_f(var, ls, A)
plt.imshow(C_learnt, cmap='Purples')
plt.colorbar()

dist = tfp.distributions.MultivariateNormalFullCovariance(covariance_matrix=C_learnt)
samples_f1 = dist.sample(sample_shape=(10,), seed=jax.random.PRNGKey(0))
for s in samples_f1:
    plt.plot(X, s[:50], color='k')
plt.plot(X, f1_dataset)
## SLFM
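In SLFM (the semiparametric latent factor model), each output mixes \(Q\) independent latent GPs, each with its own kernel: \(f_i(x) = \sum_{q=1}^{Q} a^q_i u^q(x)\) with \(u^q \sim GP(0, k_q)\). The joint covariance is then a sum of Kronecker products, \(\sum_q B_q \otimes K_q\) with \(B_q = a_q a_q^\top\). Below we build it for \(Q = 2\).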
def covariance_f_SLFM(var1, ls1, A1, var2, ls2, A2):
    """
    Joint covariance as a sum of two Kronecker products, one per latent GP.
    """
    B1 = A1@A1.T
    B2 = A2@A2.T
    cov1 = all_pairs(sqexp)(X, X, var1, ls1)
    cov2 = all_pairs(sqexp)(X, X, var2, ls2)
    cov_f = jitter(jnp.kron(B1, cov1) + jnp.kron(B2, cov2))
    return cov_f
rank = 1
a1 = jax.random.normal(key=jax.random.PRNGKey(0), shape=(output_dim, rank))/2.0
a2 = jax.random.normal(key=jax.random.PRNGKey(1), shape=(output_dim, rank))/2.0  # different key so a2 differs from a1

C_SLFM = covariance_f_SLFM(1.0, 2.0, a1, 1.0, 4.0, a2)

plt.imshow(C_SLFM, cmap='Purples')
plt.colorbar()
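As with ICM, we can draw samples from the SLFM prior. A minimal sketch mirroring the earlier sampling cell (the names `dist_slfm` and `samples_slfm` are ours):

dist_slfm = tfp.distributions.MultivariateNormalFullCovariance(covariance_matrix=C_SLFM)
samples_slfm = dist_slfm.sample(sample_shape=(5,), seed=jax.random.PRNGKey(0))
for s in samples_slfm:
    plt.plot(X, s[:50], color='r')   # f1 samples
    plt.plot(X, s[50:], color='g')   # f2 samples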