# Multi-output Gaussian Process

Multi-output Gaussian Process
ml
Author

Nipun Batra

Published

October 25, 2022

In this notebook, we cover multi-output Gaussian processes (GPs). The presentation follows the excellent video from the Gaussian Process Summer School (GPSS).

import numpy as np
from scipy.stats import multivariate_normal
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib
#matplotlib.rcParams['figure.figsize'] = (8,6)
from matplotlib import pyplot as plt
import GPy

# ICM

$$u \sim GP (0, k)$$

Draw a sample $u^1$ from $u$; each output is then a scaled copy of that same sample:

$$f_1(x) = a^1_1 u^1(x)$$

$$f_2(x) = a^1_2 u^1(x)$$

# 50 evenly spaced 1-D inputs on [-3, 3].
X = np.linspace(-3.,3., 50)
# RBF kernel of the latent process u (variance 1, lengthscale 2).
kernel = GPy.kern.RBF(input_dim=1, variance=1., lengthscale=2.)
kernel
| rbf.        | value | constraints | priors |
|-------------|-------|-------------|--------|
| variance    | 1.0   | +ve         |        |
| lengthscale | 2.0   | +ve         |        |
def jitter(C, j=1e-6):
    """Return ``C`` with ``j`` added to each diagonal entry.

    C: square 2-D array (the covariance matrix); it is not modified.
    j: value added to the diagonal (default 1e-6).

    Returns ``C + j * I`` where ``I`` is the identity of size ``len(C)``.
    """
    return C + np.eye(len(C)) * j
# Kernel matrix of the latent process on X (inputs reshaped to a column),
# with a small jitter added on the diagonal.
cov = jitter(kernel.K(X.reshape(-1, 1)))
plt.imshow(cov)
<matplotlib.image.AxesImage at 0x198a04e80>

# Multivariate normal with covariance `cov` (no mean is passed);
# one seeded draw plays the role of the latent sample u^1.
mvn = multivariate_normal(cov=cov)
u1 = mvn.rvs(random_state=0)
plt.plot(X, u1)

# Mixing weights: f_1 = a11 * u^1 and f_2 = a12 * u^1.
a11 = 0.9
a12 = 0.7

# Stack the weights into a (2, 1) column vector.
a = np.array([a11, a12]).reshape(-1, 1)
a
array([[0.9],
[0.7]])
# Outer product a a^T: a 2x2, rank-1 matrix of covariances between the outputs.
B = a@a.T
B
array([[0.81, 0.63],
[0.63, 0.49]])
# Joint covariance of the stacked vector [f_1(X); f_2(X)]:
# block (p, q) of the 100x100 matrix equals B[p, q] * cov.
cov_f = np.kron(B, cov)
plt.imshow(cov_f, cmap='Purples')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x198bd9a30>

# Draw 500 joint samples of the stacked vector [f_1(X); f_2(X)].
f_sample = multivariate_normal(cov=jitter(cov_f)).rvs(size=500)
# The first len(X) columns belong to f_1, the remaining ones to f_2
# (same ordering as the blocks of np.kron(B, cov)).
n_points = len(X)
f1_samples, f2_samples = f_sample[:, :n_points], f_sample[:, n_points:]
#plt.plot(X, u1, label="u1")
# Plot the first two joint draws: f_1 in green, f_2 in red.
for i in range(2):
    plt.plot(X, f1_samples[i], color='g')
    plt.plot(X, f2_samples[i], color='r')

# `i` is left over from the plotting loop (i == 1). Both outputs are scaled
# copies of one latent draw, so the element-wise ratio should be close to
# a11 / a12 = 0.9 / 0.7 ~= 1.2857; small deviations are expected because
# jitter was added to cov_f before sampling.
f1_samples[i]/f2_samples[i]
array([1.28521323, 1.2870487 , 1.28169798, 1.29387391, 1.28381124,
1.29063798, 1.28399272, 1.28787108, 1.27634933, 1.29367057,
1.19405718, 0.81421541, 1.29366628, 1.23932848, 1.28601429,
1.31178054, 1.27596873, 1.28139033, 1.28548127, 1.28874727,
1.288544  , 1.28851575, 1.27706874, 1.28929381, 1.27167387,
1.30216154, 1.28769528, 1.28397652, 1.2896767 , 1.29357874,
1.28743778, 1.28867757, 1.29135504, 1.28085954, 1.27832016,
1.29113682, 1.28346876, 1.28115477, 1.28579679, 1.28664088,
1.2836771 , 1.28690568, 1.28521466, 1.28474094, 1.28147929,
1.28752966, 1.28577663, 1.28154063, 1.28312776, 1.2869964 ])
## Learning in MOGP setting

# Use one joint draw (index 4) as the observed data for both outputs.
f1_dataset = f1_samples[4]
f2_dataset  = f2_samples[4]
plt.plot(X, f1_dataset, label='f1')
plt.plot(X, f2_dataset, label='f2')
plt.legend()
<matplotlib.legend.Legend at 0x198ceaeb0>

## What we want to learn:

# 1. GP kernel parameters
# 2. a11, a12
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
import tensorflow_probability.substrates.jax as tfp
# Stack the observations as [f1; f2] into one vector, matching the block
# order of the Kronecker-structured covariance.
f = jnp.hstack([f1_dataset, f2_dataset])
def sqexp(a, b, var=1.0, ls=4):
    """Squared-exponential (RBF) kernel value between two inputs.

    Computes ``var * exp(-0.5 * sum(((a - b) / ls) ** 2))``.

    a, b: the two inputs (scalars or arrays of matching shape).
    var: kernel variance, the value returned when ``a == b``.
    ls: lengthscale dividing the difference ``a - b``.

    Returns a scalar JAX array.
    """
    diff = (a - b) / ls
    d = jnp.sum(diff ** 2)
    return var * jnp.exp(-0.5 * d)

def all_pairs(f):
    """Lift a pairwise function so it is evaluated on every pair of inputs.

    f: callable taking four positional arguments ``(a, b, p, q)``.

    Returns a callable ``g(A, B, p, q)`` whose result satisfies
    ``out[i, j] == f(A[i], B[j], p, q)``; the last two arguments are passed
    through unbatched.
    """
    # Inner vmap batches over the second argument (columns of the result).
    over_second = jax.vmap(f, in_axes=(None, 0, None, None))
    # Outer vmap batches over the first argument (rows of the result).
    return jax.vmap(over_second, in_axes=(0, None, None, None))
kernel.K(X.reshape(-1, 1))
array([[1.        , 0.99812754, 0.99253116, ..., 0.01592046, 0.01332383,
0.011109  ],
[0.99812754, 1.        , 0.99812754, ..., 0.01895197, 0.01592046,
0.01332383],
[0.99253116, 0.99812754, 1.        , ..., 0.02247631, 0.01895197,
0.01592046],
...,
[0.01592046, 0.01895197, 0.02247631, ..., 1.        , 0.99812754,
0.99253116],
[0.01332383, 0.01592046, 0.01895197, ..., 0.99812754, 1.        ,
0.99812754],
[0.011109  , 0.01332383, 0.01592046, ..., 0.99253116, 0.99812754,
1.        ]])
# Sanity check: the JAX kernel matrix (var=1, ls=2) matches the one from GPy.
np.allclose(np.array(all_pairs(sqexp)(X, X, 1.0, 2.0)), kernel.K(X.reshape(-1, 1)))
True
# Example: random (output_dim, rank) mixing matrix with rank 1, drawn with a
# fixed PRNG key and scaled by 1/10; A @ A.T is the resulting 2x2 matrix.
rank = 1
output_dim = 2
A = jax.random.normal(key=jax.random.PRNGKey(0), shape=(output_dim,rank))/10.0
A@A.T, A
(DeviceArray([[ 0.03298171, -0.01370936],
[-0.01370936,  0.00569851]], dtype=float64),
DeviceArray([[ 0.18160867],
[-0.07548848]], dtype=float64))
# Same construction with rank 4 and scale 1/2. This rebinds `rank` and `A`,
# so this A is the one in scope for the cells that follow.
output_dim = 2
rank = 4
A = jax.random.normal(key=jax.random.PRNGKey(0), shape=(output_dim,rank))/2.0
A@A.T
DeviceArray([[ 1.24957827, -0.04698574],
[-0.04698574,  0.57577417]], dtype=float64)
def covariance_f(var, ls, A):
    """
    Joint covariance of all outputs evaluated on the module-level inputs X.

    var: variance of the squared-exponential kernel
    ls: lengthscale of the squared-exponential kernel
    A: (output_dim, rank)
    A can be generated as:
    A = jax.random.normal(key=jax.random.PRNGKey(0), shape=(output_dim,rank))

    Returns jitter(kron(A @ A.T, K)), where K is the kernel matrix on X;
    the result has shape (output_dim * len(X), output_dim * len(X)).
    """
    B = A @ A.T
    cov = all_pairs(sqexp)(X, X, var, ls)
    cov_f = jitter(jnp.kron(B, cov))
    return cov_f
def cost(var, ls, A):
    """Negative log-density of the stacked observations ``f``.

    Builds a zero-mean multivariate normal whose covariance is
    ``covariance_f(var, ls, A)`` and returns ``-log p(f)``, where ``f`` is
    the module-level observation vector. Lower is better.
    """
    cov_f = covariance_f(var, ls, A)
    dist = tfp.distributions.MultivariateNormalFullCovariance(
        loc=jnp.zeros_like(f), covariance_matrix=cov_f
    )
    return -dist.log_prob(f)
plt.imshow(covariance_f(1.0, 2.0, A), cmap='Purples')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x19c218a60>

cost(1.0, 2.0, A)
DeviceArray(-431.60947116, dtype=float64)
cost(1.0, 1.0, A)
DeviceArray(-387.35267033, dtype=float64)
# Gradient of the negative log-likelihood w.r.t. var, ls and A.
grad_cost = jax.grad(cost, argnums=[0, 1, 2])

# Starting values for the kernel hyper-parameters; A keeps its current value.
var = 0.1
ls = 1.0

lr = 1e-3
for i in range(500):
    # Plain gradient-descent step, evaluated at the *current* parameters.
    # NOTE(review): var and ls are unconstrained here, so nothing stops a
    # step from pushing them to non-positive values.
    grads = grad_cost(var, ls, A)
    var = var - lr * grads[0]
    ls = ls - lr * grads[1]
    A = A - lr * grads[2]
    if i % 100 == 0:
        # Report the cost at the parameters just updated.
        print(i, cost(var, ls, A), var, ls)
0 -387.06097276826193 0.500429427376359 1.0913929924306696
100 -306.72979544101435 3.6414838350262055 2.363476650308803
200 -305.64842462218047 3.514293617054404 2.3873529546968477
300 -304.7976816183849 3.379382170959892 2.403204858135416
400 -304.0941499412901 3.236859846397818 2.4140771572105426
C_learnt = covariance_f(var, ls, A)
plt.imshow(C_learnt, cmap='Purples')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x1a1485fd0>

# Multivariate normal with the learnt joint covariance (no mean is passed).
dist = tfp.distributions.MultivariateNormalFullCovariance(covariance_matrix=C_learnt)
# Each draw is a stacked vector [f1; f2]; only the f1 part is plotted below.
samples_f1 = dist.sample(sample_shape=(10, ), seed = jax.random.PRNGKey(0))
for s in samples_f1:
    plt.plot(X, s[:len(X)], color='k')
# Overlay the observed f1 data for comparison.
plt.plot(X, f1_dataset)

# SLFM

def covariance_f_SLFM(var1, ls1, A1, var2, ls2, A2):
    """
    Joint covariance of all outputs on the module-level inputs X, built as
    the sum of two Kronecker-structured terms, one per latent kernel.

    var1, ls1: variance and lengthscale of the first squared-exponential kernel
    A1: (output_dim, rank) mixing matrix paired with the first kernel
    var2, ls2: variance and lengthscale of the second kernel
    A2: (output_dim, rank) mixing matrix paired with the second kernel

    Returns jitter(kron(A1 @ A1.T, K1) + kron(A2 @ A2.T, K2)).
    """
    B1 = A1 @ A1.T
    B2 = A2 @ A2.T
    cov1 = all_pairs(sqexp)(X, X, var1, ls1)
    # The second term must use its own hyper-parameters (var2, ls2).
    cov2 = all_pairs(sqexp)(X, X, var2, ls2)
    cov_f = jitter(jnp.kron(B1, cov1) + jnp.kron(B2, cov2))
    return cov_f
rank = 1
# Use two different PRNG keys so the two mixing matrices are not identical.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
a1 = jax.random.normal(key=key1, shape=(output_dim, rank)) / 2.0
a2 = jax.random.normal(key=key2, shape=(output_dim, rank)) / 2.0

# covariance_f_SLFM forms A @ A.T itself, so pass the mixing matrices directly.
C_SLFM = covariance_f_SLFM(1.0, 2.0, a1, 1.0, 4.0, a2)
plt.imshow(C_SLFM, cmap='Purples')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x1a1613ee0>