import jax.numpy as jnp
from jax import random, jit, vmap, grad, jacfwd, jacrev, hessian, value_and_grad
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Taylor Series
# The target function we approximate with Taylor polynomials: sine.
def f(x):
    """Return the sine of ``x`` (element-wise when ``x`` is an array)."""
    value = jnp.sin(x)
    return value
# Plot the function over [-pi, pi]; this `x` grid is reused by the next cells.
x = jnp.linspace(-jnp.pi, jnp.pi, 100)
plt.plot(x, f(x))
# First-order (linear) Taylor approximation of f about x0 (default 0).
def taylor1(f, x, x0=0.):
    """Return the tangent-line approximation ``f(x0) + f'(x0) * (x - x0)``.

    ``f'`` is obtained with ``jax.grad``, so ``f`` must accept a float and
    return a scalar.
    """
    intercept = f(x0)
    slope = grad(f)(x0)
    return intercept + slope * (x - x0)
# Plot f alongside its first-order Taylor approximation about x = 0.
plt.plot(x, f(x), label='f(x)')
plt.plot(x, taylor1(f, x), label='Taylor approximation')
# Without a legend the labels above are never displayed.
plt.legend()
# Factorial function in JAX.
def factorial(n):
    """Return ``n!`` as a floating-point JAX scalar.

    The product is taken in the default *float* dtype. Without x64 mode,
    JAX's default integer dtype is 32-bit, and 13! already exceeds it and
    wraps silently. That would corrupt every Taylor coefficient of order
    >= 13.

    Args:
        n: Non-negative integer.

    Returns:
        A 0-d JAX array holding ``n!``. For ``n == 0`` this is the empty
        product, 1.0.
    """
    return jnp.prod(jnp.arange(1, n + 1, dtype=float))
# Find the nth-order Taylor approximation of f(x) about x = x0 (default 0).
def taylor(f, x, n, x0=0.):
    """Evaluate the order-``n`` Taylor polynomial of ``f`` about ``x0`` at ``x``.

    Computes ``sum_{i=0}^{n} f^(i)(x0) / i! * (x - x0)**i``. Each derivative
    comes from applying ``jax.grad`` repeatedly.

    Args:
        f: Function differentiable by ``jax.grad`` (float in, scalar out).
        x: Point(s) at which to evaluate the polynomial (scalar or array).
        n: Non-negative integer order of the approximation.
        x0: Expansion point. It is converted to a float array because
            ``jax.grad`` rejects integer inputs.

    Returns:
        A JAX array with the same shape as ``x``.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"Taylor order must be non-negative, got n={n}")
    x0 = jnp.asarray(x0, dtype=float)
    offset = x - x0
    # Broadcast the constant term to x's shape so that n == 0 still yields
    # one value per evaluation point instead of a bare scalar.
    output = jnp.full(jnp.shape(x), f(x0))
    deriv = f
    inv_factorial = 1.0
    for i in range(1, n + 1):
        deriv = grad(deriv)  # now the i-th derivative of f
        # Keep 1/i! as a running float so i! is never formed as an integer,
        # which would overflow for large orders.
        inv_factorial /= i
        output = output + deriv(x0) * offset ** i * inv_factorial
    return output
# Compare f with Taylor approximations of increasing order about x = 0.
plt.plot(x, f(x), label='f(x)', lw=5)
for order in (1, 3, 5):
    plt.plot(x, taylor(f, x, order), label=f'Taylor approximation, n={order}')
plt.legend()
# Out: <matplotlib.legend.Legend at 0x1aea5ea90>
# Evaluation grid for the quadratic example below.
x = jnp.linspace(-4, 4, 100)
def g(x):
    """Return ``x`` squared; a quadratic test function."""
    squared = x ** 2
    return squared
# Taylor approximations of g about x = 0. The order-2 polynomial of a
# quadratic is exact, so it is drawn dashed on top of g.
plt.plot(x, g(x), label='g(x)', lw=4, alpha=0.5)
plt.plot(x, taylor(g, x, 1), label='Taylor approximation, n=1')
plt.plot(x, taylor(g, x, 2), label='Taylor approximation, n=2', ls='--')
plt.legend()
# Same comparison, but expanding about x0 = 4.1 instead of 0.
plt.plot(x, g(x), label='g(x)', lw=4, alpha=0.5)
plt.plot(x, taylor(g, x, 1, 4.1), label='Taylor approximation, n=1')
plt.plot(x, taylor(g, x, 2, 4.1), label='Taylor approximation, n=2', ls='--')
plt.legend()
# Clip the y-axis so the steep tangent line does not dominate the plot.
plt.ylim((-2, 20))
# Out: (-2.0, 20.0)