Maths and JAX

Maths for ML and JAX
Author

Nipun Batra

Published

January 31, 2023

import pandas as pd
import numpy as np
import jax.numpy as jnp
import jax
def func(x, y, z):
    """Return ``x**2 + sin(y) + z``.

    Written with ``jnp.sin`` so that JAX can trace and differentiate it
    (see the ``jax.grad`` cell further down).
    """
    quadratic_term = x**2
    return quadratic_term + jnp.sin(y) + z


func(1, 2, 3)
DeviceArray(4.9092975, dtype=float32, weak_type=True)
# Symbolic counterpart of `func`. The star import brings sympy's names
# (symbols, sin, diff, lambdify, MatrixSymbol, ...) into scope for the cells
# below; like any star import it also shadows same-named globals.
from sympy import *
# Render results as LaTeX in the notebook output.
init_printing()

x, y, z = symbols('x y z')
f = x**2 + sin(y) + z
f

\(\displaystyle x^{2} + z + \sin{\left(y \right)}\)

# Symbolic partial derivative of f with respect to x.
diff(f, x)

\(\displaystyle 2 x\)

# Partial derivatives of f with respect to each of x, y and z (symbolic, via sympy).
del_x, del_y, del_z = (diff(f, var) for var in (x, y, z))
del_x, del_y, del_z

\(\displaystyle \left( 2 x, \ \cos{\left(y \right)}, \ 1\right)\)

# Compile the symbolic gradient into a numeric function of (x, y, z) and
# evaluate it at (1, 2, 3); compare with the JAX gradient in the next cell.
grad_f = lambdify((x, y, z), [del_x, del_y, del_z])
grad_f(1, 2, 3)

\(\displaystyle \left[ 2, \ -0.416146836547142, \ 1\right]\)

# Automatic differentiation of `func` with JAX: argnums=(0, 1, 2) returns the
# partial derivatives w.r.t. x, y and z as a tuple, matching the sympy result.
# Float literals are passed because jax.grad expects floating-point inputs.
grad_f_jax = jax.grad(func, argnums=(0, 1, 2))
grad_f_jax(1., 2., 3.)
(DeviceArray(2., dtype=float32, weak_type=True),
 DeviceArray(-0.41614684, dtype=float32, weak_type=True),
 DeviceArray(1., dtype=float32, weak_type=True))
n = 20
# Draw A and theta from independent subkeys. Reusing PRNGKey(0) for both makes
# JAX generate the same underlying values (same key, same number of elements),
# so theta would simply be A.T and b would collapse to ||A||^2.
# NOTE: the numeric outputs recorded below were produced with that shared key
# and will differ when this cell is re-run.
key_A, key_theta = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(shape=(1, n), key=key_A, dtype=jnp.float32)  # row vector, (1, n)
theta = jax.random.normal(shape=(n, 1), key=key_theta, dtype=jnp.float32)  # column vector, (n, 1)
b = A @ theta  # (1, n) @ (n, 1) -> (1, 1)

b
DeviceArray([[28.684494]], dtype=float32)
# Two ways to get at the single value in the (1, 1) array b: a flattened
# 1-D array, and a plain Python float via .item().
b.flatten(), b.item()
(DeviceArray([28.684494], dtype=float32), 28.684494018554688)
def a_theta(A, theta):
    """Return the matrix product ``A @ theta``.

    In this notebook A is (1, n) and theta is (n, 1), so the result is a
    (1, 1) array.
    """
    product = A @ theta
    return product
# Same product as the `b = A @ theta` cell, now computed through the function.
a_theta(A, theta)
DeviceArray([[28.684494]], dtype=float32)
# jax.grad only differentiates scalar-valued functions, but a_theta returns a
# (1, 1) array, so differentiate its single entry instead. The resulting
# gradient w.r.t. theta has theta's shape, (n, 1), and equals A.T.
grad_a_theta = jax.grad(lambda A, theta: a_theta(A, theta)[0, 0], argnums=1)
# The full Jacobian of the (1, 1) output w.r.t. the (n, 1) theta has shape
# (1, 1, n, 1); indexing away the two output axes leaves an (n, 1) slice.
jax.jacobian(a_theta, argnums=1)(A, theta)[0, 0, :].shape

\(\displaystyle \left( 20, \ 1\right)\)

# A is a single row of n entries, so A.T has the same (n, 1) shape as the
# Jacobian slice computed above.
A.shape

\(\displaystyle \left( 1, \ 20\right)\)

# Sympy version: symbolic matrices with the same shapes as the JAX arrays.
# NOTE: this rebinds A and theta, replacing the JAX arrays defined earlier;
# re-run the JAX cells before using those names numerically again.
A = MatrixSymbol('A', 1, n)
theta = MatrixSymbol('theta', n, 1)
A, theta

\(\displaystyle \left( A, \ \theta\right)\)

# Matrix derivative of the (1 x 1) product A*theta w.r.t. theta; sympy returns
# A^T, which agrees with the (n, 1) Jacobian slice obtained from JAX.
diff(A*theta, theta)

\(\displaystyle A^{T}\)