Taylor Series

Author: Nipun Batra

Published: February 14, 2023

import jax.numpy as jnp
from jax import random, jit, vmap, grad, jacfwd, jacrev, hessian, value_and_grad
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Define the function to be approximated

def f(x):
    """Target function for the Taylor approximations: elementwise sine of ``x``."""
    sine_of_x = jnp.sin(x)
    return sine_of_x
# Plot the function over one full period, from -pi to pi

x = jnp.linspace(start=-jnp.pi, stop=jnp.pi, num=100)
plt.plot(x, f(x))

# First-order Taylor approximation of f(x) around x0 (default x0 = 0)

def taylor1(f, x, x0=0.):
    """Return the first-order Taylor approximation of ``f`` around ``x0``.

    Computes ``f(x0) + f'(x0) * (x - x0)``, with the derivative obtained
    from ``jax.grad``.

    Args:
        f: Function of a scalar returning a scalar, differentiable by JAX.
        x: Point(s) at which to evaluate the approximation; arrays broadcast.
        x0: Expansion point (default 0.). Integer values are accepted and
            promoted to floating point.

    Returns:
        The tangent-line value(s) at ``x``.
    """
    # jax.grad only differentiates with respect to inexact (float/complex)
    # inputs, so promote e.g. ``x0=3`` to ``3.0`` instead of failing.
    x0 = jnp.asarray(x0, dtype=float)
    return f(x0) + grad(f)(x0) * (x - x0)
# Plot the first-order Taylor approximation against the original function

plt.plot(x, f(x), label='f(x)')
plt.plot(x, taylor1(f, x), label='Taylor approximation')
# Without an explicit legend call the labels above are never displayed.
plt.legend()

# factorial function in JAX

def factorial(n):
    """Return ``n!`` (with ``0! == 1``) as a floating-point JAX scalar.

    The product is accumulated in JAX's default float dtype. Results are
    exact while ``n!`` is representable in that dtype and rounded beyond it;
    with float32 they overflow to ``inf`` from ``n = 35``.

    Args:
        n: Non-negative Python int (``jnp.arange`` needs a concrete bound).

    Returns:
        0-d floating-point array holding ``n!``.
    """
    # Use a float dtype: with JAX's default 32-bit integers (x64 disabled)
    # an integer product silently overflows from 13! = 6_227_020_800 onward.
    return jnp.prod(jnp.arange(1, n + 1, dtype=float))
# nth-order Taylor approximation of f(x) around x0 (default x0 = 0)

def taylor(f, x, n, x0=0.):
    """Return the order-``n`` Taylor approximation of ``f`` around ``x0``.

    Evaluates ``sum_{i=0}^{n} f^{(i)}(x0) * (x - x0)**i / i!``, where each
    derivative ``f^{(i)}`` is obtained by nesting ``jax.grad`` ``i`` times.

    Args:
        f: Function of a scalar returning a scalar, differentiable by JAX.
        x: Point(s) at which to evaluate the polynomial; arrays broadcast.
        n: Highest derivative order to include (``n >= 0``).
        x0: Expansion point (default 0.). Integer values are accepted and
            promoted to floating point.

    Returns:
        Array with the shape of ``x`` holding the approximation.

    Note:
        Each extra order wraps one more ``grad`` around the previous
        derivative, so tracing cost grows with ``n``; keep ``n`` modest.
    """
    # jax.grad rejects integer inputs, so promote e.g. ``x0=4`` to ``4.0``.
    x0 = jnp.asarray(x0, dtype=float)
    # Broadcast the constant term so that even n == 0 gives one value per x.
    output = jnp.broadcast_to(f(x0), jnp.shape(x))
    deriv = f
    inv_factorial = 1.0
    for i in range(1, n + 1):
        deriv = grad(deriv)
        # Keep 1/i! incrementally as a Python float; an integer factorial
        # in JAX's default int32 would overflow from 13! onward.
        inv_factorial /= i
        output = output + deriv(x0) * (x - x0) ** i * inv_factorial
    return output
# Overlay increasingly accurate odd-order approximations of f on top of f.
plt.plot(x, f(x), label='f(x)', lw=5)
for order in (1, 3, 5):
    plt.plot(x, taylor(f, x, order), label=f'Taylor approximation, n={order}')
plt.legend()
<matplotlib.legend.Legend at 0x1aea5ea90>

x = jnp.linspace(-4, 4, num=100)

def g(x):
    """Quadratic test function: return ``x`` squared."""
    return pow(x, 2)

plt.plot(x, g(x), label='g(x)', lw=4, alpha=0.5)
plt.plot(x, taylor(g, x, 1), label='Taylor approximation, n=1')
# The second-order expansion of the quadratic g is exact, so it is drawn
# dashed to stay visible on top of g(x). Label matches the order passed.
plt.plot(x, taylor(g, x, 2), label='Taylor approximation, n=2', ls='--')
plt.legend()

# Same comparison, but expanding g around x0 = 4.1 instead of 0.
plt.plot(x, g(x), label='g(x)', lw=4, alpha=0.5)
plt.plot(x, taylor(g, x, 1, 4.1), label='Taylor approximation, n=1')
plt.plot(x, taylor(g, x, 2, 4.1), label='Taylor approximation, n=2', ls='--')
plt.legend()
# Clip the view: the tangent line at x0 = 4.1 falls far below zero on the left.
plt.ylim((-2, 20))
(-2.0, 20.0)