import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Use 64 bit precision for JAX (arrays then default to float64 instead of float32)
jax.config.update("jax_enable_x64", True)
# Create a 2d function
def f(x):
    """Return sin(x[0] * x[1]) for a two-component input ``x``.

    ``x`` only needs to support indexing with 0 and 1 (the file calls it
    with both JAX arrays and plain Python lists).
    """
    return jnp.sin(x[0] * x[1])
# Build an n-by-n evaluation grid over the rectangle x_range x y_range.
x_range = (-2, 2)
y_range = (-2, 2)
n = 100
x = jnp.linspace(x_range[0], x_range[1], n)
y = jnp.linspace(y_range[0], y_range[1], n)
# meshgrid turns the two 1-D axes into 2-D coordinate arrays X and Y.
X, Y = jnp.meshgrid(x, y)
# Evaluate the function at a grid of points using vmap
def eval_grid(f, x_range=(-2, 2), y_range=(-2, 2), n=100):
    """Evaluate ``f`` on an ``n`` x ``n`` grid using ``jax.vmap``.

    Args:
        f: Function taking one two-component array ``[x, y]`` and returning
            a scalar.
        x_range: ``(min, max)`` of the first coordinate.
        y_range: ``(min, max)`` of the second coordinate.
        n: Number of sample points along each axis.

    Returns:
        A tuple ``(X, Y, Z)``: the two meshgrid coordinate arrays and the
        values of ``f`` at each grid point.
    """
    x = jnp.linspace(x_range[0], x_range[1], n)
    y = jnp.linspace(y_range[0], y_range[1], n)
    X, Y = jnp.meshgrid(x, y)
    # Stack the coordinates so the last axis holds the (x, y) pair, then map
    # f over the two leading grid axes.
    points = jnp.stack([X, Y], axis=-1)
    f_on_grid = jax.vmap(jax.vmap(f, in_axes=0), in_axes=0)
    return X, Y, f_on_grid(points)
# Plot the contour of the function
def plot_contour(f, x_range=(-2, 2), y_range=(-2, 2), n=100, ax=None, levels=None, **kwargs):
    """Draw labelled contour lines of ``f`` over a rectangular grid.

    Args:
        f: Function taking a two-component array and returning a scalar.
        x_range: ``(min, max)`` of the first coordinate.
        y_range: ``(min, max)`` of the second coordinate.
        n: Number of grid points along each axis.
        ax: Matplotlib axes to draw on; a new figure is created when None.
        levels: Contour levels passed to ``ax.contour``. When None, 11 evenly
            spaced levels from -1.0 to ``int(max(Z)) + 0.5`` are used.
        **kwargs: Forwarded to ``ax.contour`` (e.g. ``colors``).

    Returns:
        The axes that were drawn on.
    """
    X, Y, Z = eval_grid(f, x_range, y_range, n)
    if ax is None:
        _, ax = plt.subplots()

    if levels is None:
        # Default level range; pass `levels` explicitly for functions whose
        # values fall outside [-1.0, int(max(Z)) + 0.5].
        levels = jnp.linspace(-1.0, int(jnp.max(Z)) + 0.5, 11)
    contours = ax.contour(X, Y, Z, levels=levels, **kwargs)
    ax.clabel(contours, inline=True, fontsize=8)
    return ax
# Contour plot of f with black contour lines
plot_contour(f, colors='k')
# Plot surface of the function
def plot_surface(f, x_range=(-2, 2), y_range=(-2, 2), n=100, ax=None, **kwargs):
    """Draw a 3-D surface plot of ``f`` with matplotlib.

    Args:
        f: Function taking a two-component array and returning a scalar.
        x_range: ``(min, max)`` of the first coordinate.
        y_range: ``(min, max)`` of the second coordinate.
        n: Number of grid points along each axis.
        ax: Axes to draw on; when None a new figure with a 3-D projection
            is created.
        **kwargs: Forwarded to ``ax.plot_surface`` (e.g. ``cmap``).

    Returns:
        The axes that were drawn on.
    """
    X, Y, Z = eval_grid(f, x_range, y_range, n)

    if ax is None:
        _, ax = plt.subplots(subplot_kw={"projection": "3d"})

    ax.plot_surface(X, Y, Z, **kwargs)
    return ax
# Plot surface in Plotly
import plotly.graph_objects as go
def plot_surface_plotly(f, x_range=(-2, 2), y_range=(-2, 2), n=100, **kwargs):
    """Show an interactive Plotly surface plot of ``f``.

    Args:
        f: Function taking a two-component array and returning a scalar.
        x_range: ``(min, max)`` of the first coordinate.
        y_range: ``(min, max)`` of the second coordinate.
        n: Number of grid points along each axis.
        **kwargs: Forwarded to ``go.Surface``.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    X, Y, Z = eval_grid(f, x_range, y_range, n)
    fig = go.Figure(data=[go.Surface(z=Z, x=X, y=Y, **kwargs)])
    fig.update_layout(scene=dict(
        xaxis_title='x',
        yaxis_title='y',
        zaxis_title='z'))
    fig.show()
# Surface of f: static matplotlib version, then interactive Plotly version
plot_surface(f, cmap='viridis')
plot_surface_plotly(f)
Unable to display output for mime type(s): application/vnd.plotly.v1+json
# Gradient and Hessian of f via automatic differentiation
g = jax.grad(f)
H = jax.hessian(f)
# Gradient at (1, 1); the input is a plain list, so wrap the result in a single array
jnp.array(g([1.0, 1.0]))
Array([0.54030231, 0.54030231], dtype=float64)
# Hessian at (1, 1), converted to a single 2x2 array
jnp.array(H([1.0, 1.0]))
Array([[-0.84147098, -0.30116868],
[-0.30116868, -0.84147098]], dtype=float64)
# With a plain Python list as input, the value returned by H is a list rather
# than a JAX Array (see the printed types); wrapping in jnp.array, as done for
# g here, yields an Array.
print(type(f([1.0, 1.0])), type(jnp.array(g([1.0, 1.0]))), type(H([1.0, 1.0])))
<class 'jaxlib.xla_extension.Array'> <class 'jaxlib.xla_extension.Array'> <class 'list'>
# First order Taylor approximation around x0
def taylor1(f, x0):
    """Build the first-order Taylor approximation of ``f`` around ``x0``.

    Args:
        f: Scalar function of a two-component array.
        x0: Expansion point (array or list of two floats).

    Returns:
        A function ``t(x) = f(x0) + grad f(x0) . (x - x0)``.

    Side effects:
        Prints the approximation with its numeric coefficients.
    """
    x0 = jnp.asarray(x0)
    # Evaluate the expansion data once instead of on every call of t.
    f0 = f(x0)
    grad0 = jnp.asarray(jax.grad(f)(x0))

    def t(x):
        return f0 + grad0 @ (x - x0)

    # Print the Taylor approximation
    print("f(x) = {:.2f} + {:.2f} (x1 - {:.2f}) + {:.2f} (x2 - {:.2f})".format(f0, grad0[0], x0[0], grad0[1], x0[1]))
    return t
# First-order approximation around (1, 1), evaluated at (0, 1)
taylor1(f, jnp.array([1.0, 1.0]))(jnp.array([0.0, 1.0]))
f(x) = 0.84 + 0.54 (x1 - 1.00) + 0.54 (x2 - 1.00)
Array(0.30116868, dtype=float64)
# Second order Taylor approximation around x0
def taylor2(f, x0):
    """Build the second-order Taylor approximation of ``f`` around ``x0``.

    Args:
        f: Scalar function of a two-component array.
        x0: Expansion point (array or list of two floats).

    Returns:
        A function
        ``t(x) = f(x0) + g . (x - x0) + 0.5 * (x - x0) . H . (x - x0)``
        where ``g`` and ``H`` are the gradient and Hessian of ``f`` at ``x0``.

    Side effects:
        Prints the approximation with its numeric coefficients.
    """
    x0 = jnp.asarray(x0)
    # Evaluate the expansion data once instead of on every call of t.
    f0 = f(x0)
    grad0 = jnp.asarray(jax.grad(f)(x0))
    hess0 = jnp.asarray(jax.hessian(f)(x0))

    def t(x):
        d = x - x0
        return f0 + grad0 @ d + 0.5 * d @ hess0 @ d

    # Print the Taylor approximation. The squared terms carry 0.5 * H_ii; the
    # mixed term carries H_01 because 0.5 * (H_01 + H_10) == H_01 for a
    # symmetric Hessian.
    print("f(x) = {:.2f} + {:.2f} (x1 - {:.2f}) + {:.2f} (x2 - {:.2f}) + {:.2f} (x1 - {:.2f})^2 + {:.2f} (x2 - {:.2f})^2 + {:.2f} (x1 - {:.2f})(x2 - {:.2f})".format(f0, grad0[0], x0[0], grad0[1], x0[1], 0.5 * hess0[0, 0], x0[0], 0.5 * hess0[1, 1], x0[1], hess0[0, 1], x0[0], x0[1]))
    return t
# Second-order approximation around (1, 1), evaluated at (0, 1)
taylor2(f, jnp.array([1.0, 1.0]))(jnp.array([0.0, 1.0]))
f(x) = 0.84 + 0.54 (x1 - 1.00) + 0.54 (x2 - 1.00) + -0.84 (x1 - 1.00)^2 + -0.84 (x2 - 1.00)^2 + -0.30 (x1 - 1.00)(x2 - 1.00)
Array(-0.11956681, dtype=float64)
# Hessian of f at (1, 1)
H(jnp.array([1.0, 1.0]))
Array([[-0.84147098, -0.30116868],
[-0.30116868, -0.84147098]], dtype=float64)
# Gradient of f at (1, 1)
g(jnp.array([1.0, 1.0]))
Array([0.54030231, 0.54030231], dtype=float64)
# Plot contour of the Taylor approximation around x0 for both first and second order in comparison with the original function
# 3 subplots
def plot_taylor(f, x0, x_range=(-2, 2), y_range=(-2, 2), n=100, ax=None):
    """Plot contours of ``f`` and its first/second-order Taylor approximations.

    Three panels are drawn side by side: the original function, the
    first-order approximation around ``x0`` and the second-order
    approximation around ``x0``. The expansion point is marked with a red
    cross in every panel.

    Args:
        f: Scalar function of a two-component array.
        x0: Expansion point.
        x_range: ``(min, max)`` of the first coordinate.
        y_range: ``(min, max)`` of the second coordinate.
        n: Number of grid points along each axis.
        ax: Sequence of three axes to draw on; created when None.
    """
    t1 = taylor1(f, x0)
    t2 = taylor2(f, x0)

    if ax is None:
        _, ax = plt.subplots(1, 3, figsize=(15, 5))

    # One panel per function: f, first-order, second-order approximation.
    for axis, fn in zip(ax, (f, t1, t2)):
        plot_contour(fn, x_range, y_range, n, ax=axis, colors='black')
        # Mark the point x0
        axis.scatter(x0[0], x0[1], marker='x', color='red', s=100)
# Compare f with its Taylor approximations around (1, 1) as contour plots
plot_taylor(f, jnp.array([1.0, 1.0]))
f(x) = 0.84 + 0.54 (x1 - 1.00) + 0.54 (x2 - 1.00)
f(x) = 0.84 + 0.54 (x1 - 1.00) + 0.54 (x2 - 1.00) + -0.84 (x1 - 1.00)^2 + -0.84 (x2 - 1.00)^2 + -0.30 (x1 - 1.00)(x2 - 1.00)
# Plot surface of the Taylor approximation around x0 for both first and second order in comparison with the original function
# 3 subplots
def plot_taylor_surface(f, x0, x_range=(-2, 2), y_range=(-2, 2), n=100, ax=None):
    """Plot surfaces of ``f`` and its first/second-order Taylor approximations.

    Three 3-D panels are drawn side by side: the original function, the
    first-order approximation around ``x0`` and the second-order
    approximation around ``x0``. The point ``(x0, f(x0))`` is marked with a
    red cross in every panel.

    Args:
        f: Scalar function of a two-component array.
        x0: Expansion point.
        x_range: ``(min, max)`` of the first coordinate.
        y_range: ``(min, max)`` of the second coordinate.
        n: Number of grid points along each axis.
        ax: Sequence of three 3-D axes to draw on; created when None.
    """
    t1 = taylor1(f, x0)
    t2 = taylor2(f, x0)

    if ax is None:
        _, ax = plt.subplots(1, 3, figsize=(15, 5), subplot_kw={"projection": "3d"})

    # Height of the marker: the function value at the expansion point.
    z0 = f(x0)

    # One panel per function: f, first-order, second-order approximation.
    for axis, fn in zip(ax, (f, t1, t2)):
        plot_surface(fn, x_range, y_range, n, ax=axis)
        # Mark the point x0
        axis.scatter(x0[0], x0[1], z0, marker='x', color='red', s=100)
# Compare f with its Taylor approximations around (1, 1) as surface plots
plot_taylor_surface(f, jnp.array([1.0, 1.0]))
f(x) = 0.84 + 0.54 (x1 - 1.00) + 0.54 (x2 - 1.00)
f(x) = 0.84 + 0.54 (x1 - 1.00) + 0.54 (x2 - 1.00) + -0.84 (x1 - 1.00)^2 + -0.84 (x2 - 1.00)^2 + -0.30 (x1 - 1.00)(x2 - 1.00)
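The second-order Taylor series expansion of a function f around a point (x0, y0) is given below in component form. In vector notation, with x = (x, y) and x0 = (x0, y0), the same expansion reads \(f(\mathbf{x}) \approx f(\mathbf{x}_0) + \nabla f(\mathbf{x}_0)^\top (\mathbf{x}-\mathbf{x}_0) + \frac{1}{2} (\mathbf{x}-\mathbf{x}_0)^\top H(\mathbf{x}_0) (\mathbf{x}-\mathbf{x}_0)\), where H is the Hessian of f. The mixed term carries no factor 1/2, because the two equal off-diagonal Hessian entries each contribute half of it.
\[f(x,y) \approx f(x_0,y_0) + \frac{\partial f}{\partial x}(x_0,y_0)(x-x_0) + \frac{\partial f}{\partial y}(x_0,y_0)(y-y_0) + \frac{1}{2} \frac{\partial^2 f}{\partial x^2}(x_0,y_0)(x-x_0)^2 + \frac{1}{2} \frac{\partial^2 f}{\partial y^2}(x_0,y_0)(y-y_0)^2 + \frac{\partial^2 f}{\partial x \partial y}(x_0,y_0)(x-x_0)(y-y_0)\]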