import pandas as pd
import numpy as np
import jax.numpy as jnp
import jax

# Maths and JAX
# Maths for ML and JAX
def func(x, y, z):
    """Evaluate f(x, y, z) = x**2 + sin(y) + z using JAX primitives.

    Written with ``jnp.sin`` so that ``jax.grad`` can differentiate it.
    """
    squared = x**2
    return squared + jnp.sin(y) + z


func(1, 2, 3)  # recorded output: DeviceArray(4.9092975, dtype=float32, weak_type=True)
from sympy import *

# Enable sympy's pretty printing (the notebook outputs below were rendered as LaTeX).
init_printing()

x, y, z = symbols('x y z')
f = x**2 + sin(y) + z
f  # -> x**2 + z + sin(y)
diff(f, x)  # -> 2*x

# Find the derivative of f with respect to x, y, and z using sympy
del_x, del_y, del_z = (diff(f, var) for var in (x, y, z))
del_x, del_y, del_z  # -> (2*x, cos(y), 1)

# Compile the symbolic gradient into a numeric function of (x, y, z).
grad_f = lambdify((x, y, z), [del_x, del_y, del_z])
grad_f(1, 2, 3)  # -> [2, -0.416146836547142, 1]
# Differentiate func with respect to every positional argument; because
# argnums is a tuple, jax.grad returns one partial derivative per argument.
grad_f_jax = jax.grad(func, argnums=tuple(range(3)))
grad_f_jax(1., 2., 3.)
# recorded output:
# (DeviceArray(2., dtype=float32, weak_type=True),
#  DeviceArray(-0.41614684, dtype=float32, weak_type=True),
#  DeviceArray(1., dtype=float32, weak_type=True))
n = 20
# Derive two independent keys from one seed. Drawing both A and theta from
# the same PRNGKey(0) produced the same 20 values for each, so theta was
# exactly A.T and b collapsed to the squared norm of A.
key_a, key_theta = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(shape=(1, n), key=key_a, dtype=jnp.float32)
theta = jax.random.normal(shape=(n, 1), key=key_theta, dtype=jnp.float32)
b = A @ theta
b  # (1, 1) array
b.flatten(), b.item()  # same value as a shape-(1,) array and as a Python float
def a_theta(A, theta):
    """Return the matrix product ``A @ theta``.

    With A of shape (1, n) and theta of shape (n, 1), as built in this
    notebook, the result is a (1, 1) array.
    """
    product = A @ theta
    return product


a_theta(A, theta)  # same value as b above
# jax.grad only supports scalar-valued functions, but a_theta returns a (1, 1)
# array, so jax.grad(a_theta, argnums=1) raised a TypeError when called.
# Squeezing the 1x1 output to a scalar makes the gradient well-defined; the
# result has theta's shape (n, 1) and equals A.T.
grad_a_theta = jax.grad(
    lambda mat, vec: jnp.squeeze(a_theta(mat, vec)), argnums=1
)
# jax.jacobian accepts non-scalar outputs: its result has shape
# (output shape) + (theta shape) = (1, 1, n, 1); [0, 0, :] drops the output axes.
jax.jacobian(a_theta, argnums=1)(A, theta)[0, 0, :].shape  # -> (20, 1)
A.shape  # -> (1, 20)
# Sympy version: the same derivative computed symbolically with matrix symbols.
# (The assignment to A had been swallowed into this comment line, which left A
# bound to the JAX array from above.)
A = MatrixSymbol('A', 1, n)
theta = MatrixSymbol('theta', n, 1)
A, theta  # -> (A, theta)
diff(A*theta, theta)  # -> A.T (rendered as A^T), matching the JAX Jacobian above