import pandas as pd
import numpy as np
import jax.numpy as jnp
import jax
Maths and JAX
Maths for ML and JAX
def func(x, y, z):
    """Return x**2 + sin(y) + z, using ``jnp.sin`` for the sine term.

    Written with ``jax.numpy`` so the function can be passed to
    ``jax.grad`` for automatic differentiation.
    """
    return x**2 + jnp.sin(y) + z
func(1, 2, 3)
DeviceArray(4.9092975, dtype=float32, weak_type=True)
from sympy import *
init_printing()
x, y, z = symbols('x y z')
f = x**2 + sin(y) + z
f
\(\displaystyle x^{2} + z + \sin{\left(y \right)}\)
diff(f, x)
\(\displaystyle 2 x\)
# Find the derivative of f with respect to x, y, and z using sympy
del_x, del_y, del_z = diff(f, x), diff(f, y), diff(f, z)
del_x, del_y, del_z
\(\displaystyle \left( 2 x, \ \cos{\left(y \right)}, \ 1\right)\)
grad_f = lambdify((x, y, z), [del_x, del_y, del_z])
grad_f(1, 2, 3)
\(\displaystyle \left[ 2, \ -0.416146836547142, \ 1\right]\)
grad_f_jax = jax.grad(func, argnums=(0, 1, 2))
grad_f_jax(1., 2., 3.)
(DeviceArray(2., dtype=float32, weak_type=True),
DeviceArray(-0.41614684, dtype=float32, weak_type=True),
DeviceArray(1., dtype=float32, weak_type=True))
n = 20
A = jax.random.normal(shape=(1, n), key=jax.random.PRNGKey(0), dtype=jnp.float32)
theta = jax.random.normal(shape=(n, 1), key=jax.random.PRNGKey(0), dtype=jnp.float32)
b = A @ theta
b
DeviceArray([[28.684494]], dtype=float32)
b.flatten(), b.item()
(DeviceArray([28.684494], dtype=float32), 28.684494018554688)
def a_theta(A, theta):
    """Return the matrix product ``A @ theta``."""
    return A @ theta
a_theta(A, theta)
DeviceArray([[28.684494]], dtype=float32)
grad_a_theta = jax.grad(a_theta, argnums=1)
jax.jacobian(a_theta, argnums=1)(A, theta)[0, 0, :].shape
\(\displaystyle \left( 20, \ 1\right)\)
A.shape
\(\displaystyle \left( 1, \ 20\right)\)
# Sympy version
A = MatrixSymbol('A', 1, n)
theta = MatrixSymbol('theta', n, 1)
A, theta
\(\displaystyle \left( A, \ \theta\right)\)
diff(A*theta, theta)
\(\displaystyle A^{T}\)