import jax.numpy as jnp
from jax import random, jit, vmap, grad, jacfwd, jacrev, hessian, value_and_grad
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
Gradient Descent
# Simple 2D quadratic function
def f(theta_0, theta_1):
    return theta_0**2 + theta_1**2
# Plot surface and contour plots for f using jax.vmap
def create_plot(f):
    theta_0 = jnp.linspace(-2, 2, 100)
    theta_1 = jnp.linspace(-2, 2, 100)
    theta_0, theta_1 = jnp.meshgrid(theta_0, theta_1)
    f_vmap = jnp.vectorize(f, signature='(),()->()')
    f_vals = f_vmap(theta_0, theta_1)

    # Create a figure with 2 subplots (3d surface and 2d contour)
    fig = plt.figure(figsize=(12, 4))
    ax1 = fig.add_subplot(121, projection='3d')
    ax2 = fig.add_subplot(122)

    # Plot the surface
    temp = ax1.plot_surface(theta_0, theta_1, f_vals, cmap='viridis')

    # Contour plot with marked level-set values using clabel
    # Set 11 levels between 0.5 and max(f_vals) + 0.5
    levels = jnp.linspace(0.5, int(jnp.max(f_vals)) + 0.5, 11)
    contours = ax2.contour(theta_0, theta_1, f_vals, levels=levels, cmap='viridis')
    ax2.clabel(contours, inline=True, fontsize=8)

    # Fill using imshow
    ax2.imshow(f_vals, extent=[-2, 2, -2, 2], origin='lower', cmap='viridis', alpha=0.5)

    # Find the global minimum of f using jax.scipy.optimize.minimize
    from jax.scipy.optimize import minimize
    def f_min(theta):
        return f(theta[0], theta[1])
    res = minimize(f_min, jnp.array([0., 0.]), method='BFGS')
    theta_min = res.x
    f_min_val = res.fun
    print(f'Global minimum: {f_min_val} at {theta_min}')

    # Plot the global minimum
    ax2.scatter(theta_min[0], theta_min[1], marker='x', color='red', s=100)
    ax2.set_aspect('equal')

    # Add labels
    ax1.set_xlabel(r'$\theta_0$')
    ax1.set_ylabel(r'$\theta_1$')
    ax1.set_zlabel(r'$f(\theta_0, \theta_1)$')
    ax2.set_xlabel(r'$\theta_0$')
    ax2.set_ylabel(r'$\theta_1$')

    # Add colorbar
    fig.colorbar(temp, ax=ax1, shrink=0.5, aspect=5)

    # Tight layout
    plt.tight_layout()

create_plot(f)
Global minimum: 0.0 at [0. 0.]
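The cell comment above mentions jax.vmap, but create_plot actually evaluates the grid with jnp.vectorize. A minimal sketch of an equivalent vmap version (grid_0, grid_1 and f_grid are names introduced here, not from the notebook):

grid_0, grid_1 = jnp.meshgrid(jnp.linspace(-2, 2, 100), jnp.linspace(-2, 2, 100))
f_grid = vmap(vmap(f))(grid_0, grid_1)   # map f over both grid axes
print(jnp.allclose(f_grid, jnp.vectorize(f, signature='(),()->()')(grid_0, grid_1)))   # True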
# Gradient of f at a given point
def grad_f(theta_0, theta_1):
    return grad(f, argnums=(0, 1))(theta_0, theta_1)

grad_f(2., 1.)
(Array(4., dtype=float32, weak_type=True),
Array(2., dtype=float32, weak_type=True))
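value_and_grad (already imported above) returns the function value and the gradient in a single call; a small sketch, not part of the original notebook:

val, grads = value_and_grad(f, argnums=(0, 1))(2., 1.)
print(val)    # 5.0
print(grads)  # (4.0, 2.0), same as grad_f(2., 1.)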
theta = jnp.array([2., 1.])
theta
Array([2., 1.], dtype=float32)
f(*theta)
Array(5., dtype=float32)
jnp.array(grad_f(*theta))
Array([4., 2.], dtype=float32)
lr = 0.1
theta = theta - lr * jnp.array(grad_f(*theta))
theta
Array([1.6, 0.8], dtype=float32)
f(*theta)
Array(3.2000003, dtype=float32)
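A quick sanity check on this single step: for f(theta) = theta_0**2 + theta_1**2 the gradient is 2*theta, so the update theta - lr * 2*theta = 0.8 * theta just rescales theta, and f shrinks by a factor of 0.8**2 = 0.64 per step, matching 5.0 -> 3.2 above.

print(0.8 * jnp.array([2., 1.]))   # [1.6, 0.8], the updated theta
print(0.64 * 5.0)                  # 3.2, the new cost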
# Gradient descent loop
# Initial parameters
theta = jnp.array([2., 1.])

# Store parameters and function values for plotting
theta_vals = [theta]
f_vals = [f(*theta)]

for i in range(10):
    theta = theta - lr * jnp.array(grad_f(*theta))
    theta_vals.append(theta)
    f_vals.append(f(*theta))
    print(f'Iteration {i}: theta = {theta}, f = {f(*theta)}')

theta_vals = jnp.array(theta_vals)
f_vals = jnp.array(f_vals)
Iteration 0: theta = [1.6 0.8], f = 3.200000286102295
Iteration 1: theta = [1.28 0.64], f = 2.047999858856201
Iteration 2: theta = [1.0239999 0.51199996], f = 1.3107198476791382
Iteration 3: theta = [0.8191999 0.40959996], f = 0.8388606309890747
Iteration 4: theta = [0.6553599 0.32767996], f = 0.5368707776069641
Iteration 5: theta = [0.52428794 0.26214397], f = 0.34359729290008545
Iteration 6: theta = [0.41943035 0.20971517], f = 0.21990226209163666
Iteration 7: theta = [0.3355443 0.16777214], f = 0.14073745906352997
Iteration 8: theta = [0.26843542 0.13421771], f = 0.09007196873426437
Iteration 9: theta = [0.21474834 0.10737417], f = 0.05764605849981308
# Plot the cost vs iterations
plt.plot(f_vals)
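The explicit Python loop above is fine for a handful of iterations; for longer runs the same update can be written with jax.lax.scan so the whole descent is traced as one computation. A minimal sketch under the same settings (lr = 0.1, 10 steps); gd_step is a helper name introduced here:

from jax import lax

def gd_step(theta, _):
    theta_new = theta - lr * jnp.array(grad_f(*theta))
    return theta_new, f(*theta_new)

theta_final, f_history = lax.scan(gd_step, jnp.array([2., 1.]), None, length=10)
print(theta_final)   # ~[0.2147, 0.1074], matching the last iteration above
print(f_history)     # same cost sequence as f_vals[1:]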
# Simple dataset for linear regression
X = jnp.array([[1.], [2.], [3.]])
y = jnp.array([1., 2.2, 2.8])
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X, y)
LinearRegression()
lr.coef_, lr.intercept_
(array([0.9000001], dtype=float32), 0.19999981)
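The sklearn fit can be cross-checked directly in JAX with ordinary least squares on a design matrix that includes a column of ones for the intercept; a small sketch using jnp.linalg.lstsq (A is a name introduced here):

A = jnp.hstack([jnp.ones_like(X), X])   # design matrix [1, x]
coef, *_ = jnp.linalg.lstsq(A, y)
print(coef)   # ~[0.2, 0.9]: intercept and slope, matching sklearn above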
# Cost function (mean squared error) for linear regression
def cost(theta_0, theta_1):
    y_hat = (theta_0 + theta_1 * X).flatten()
    # print(y_hat, y, y - y_hat, (y - y_hat)**2)
    return jnp.mean((y_hat - y)**2)
# Plot surface and contour plots for cost function
#create_plot(cost)
cost(2.0, 2.0)
Array(16.826666, dtype=float32)
# Manual check: the residuals at theta = (2, 2) are 3, 3.8 and 5.2
(3**2 + 3.8**2 + 5.2**2)/3.
16.826666666666668
# Gradient of cost function at a given point
def grad_cost(theta_0, theta_1):
    return jnp.array(grad(cost, argnums=(0, 1))(theta_0, theta_1))

grad_cost(2.0, 2.0)
Array([ 8. , 17.466667], dtype=float32)
# Manual (analytical) gradient of the MSE cost, for comparison with autodiff
def grad_cost_manual(theta_0, theta_1):
    y_hat = (theta_0 + theta_1 * X).flatten()
    return jnp.array([2*jnp.mean(y_hat - y), 2*jnp.mean((y_hat - y) * X.flatten())])

grad_cost_manual(2.0, 2.0)
Array([ 8. , 17.466667], dtype=float32)
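The autodiff gradient and the manual gradient agree; as an extra sanity check (not in the original notebook), central finite differences around (2.0, 2.0) give approximately the same values, up to float32 noise:

eps = 1e-3
fd_0 = (cost(2.0 + eps, 2.0) - cost(2.0 - eps, 2.0)) / (2 * eps)
fd_1 = (cost(2.0, 2.0 + eps) - cost(2.0, 2.0 - eps)) / (2 * eps)
print(fd_0, fd_1)   # roughly 8.0 and 17.47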
# Plotting cost surface and contours for three points in X individually
def cost_i(theta_0, theta_1, i=1):
    y_hat = theta_0 + theta_1 * X[i-1:i]
    return jnp.mean((y_hat - y[i-1:i])**2)

(cost_i(2.0, 2.0, 1) + cost_i(2.0, 2.0, 2) + cost_i(2.0, 2.0, 3))/3.0
Array(16.826666, dtype=float32)
from functools import partial
# Plot surface and contour plots for cost function
for i in range(1, 4):
    cost_i_p = partial(cost_i, i=i)
    create_plot(cost_i_p)
Global minimum: 0.0 at [0.5 0.5]
Global minimum: 0.0 at [0.44000003 0.88000005]
Global minimum: 0.0 at [0.28000003 0.84 ]
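Each single-point cost has a whole line of zero-cost minima (one data point, two parameters), so BFGS simply returns some point on the line theta_0 + theta_1 * x_i = y_i. A quick check that the three reported minima satisfy this (the mins array below just copies the printed values):

mins = jnp.array([[0.5, 0.5], [0.44, 0.88], [0.28, 0.84]])
print(mins[:, 0] + mins[:, 1] * X.flatten())   # ~[1., 2.2, 2.8], i.e. y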
grad_cost_1 = grad(cost_i, argnums=(0, 1))
grad_cost_1(2.0, 2.0)
(Array(6., dtype=float32, weak_type=True),
Array(6., dtype=float32, weak_type=True))
jnp.array(grad_cost_1(2.0, 2.0, 1)), jnp.array(grad_cost_1(2.0, 2.0, 2)), jnp.array(grad_cost_1(2.0, 2.0, 3))
(Array([6., 6.], dtype=float32),
Array([ 7.6, 15.2], dtype=float32),
Array([10.4 , 31.199999], dtype=float32))
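Averaging the three per-point gradients recovers the full-batch gradient from grad_cost(2.0, 2.0), which is the property stochastic gradient descent exploits; a quick check (the stacking is only for illustration):

per_point = jnp.stack([jnp.array(grad_cost_1(2.0, 2.0, i)) for i in (1, 2, 3)])
print(per_point.mean(axis=0))   # ~[8., 17.47], same as grad_cost(2.0, 2.0)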