from jax import vmap, jit, grad, vmap
import jax.numpy as jnp
# Enable 64-bit mode
from jax.config import config
"jax_enable_x64", True)
config.update(import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
%matplotlib inline
'dark_background')
plt.style.use(
# Retina display
%config InlineBackend.figure_format = 'retina'
# Create an array of 8 points around a unit circle, each at 45 degrees distance
= jnp.array([[1, 0], [0.707, 0.707], [0, 1], [-0.707, 0.707], [-1, 0], [-0.707, -0.707], [0, -1], [0.707, -0.707]])
x
# Define the matrix A
= jnp.array([[3, 0], [4, 5]])
A
# Compute the SVD of A
= jnp.linalg.svd(A, full_matrices=False)
U, S, VT = VT.T V
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
def matvec(A, vec):
"""
vec: (2,)
A: (2, 2)
"""
return A@vec
0]), matvec(A, x[1]) matvec(A, x[
(DeviceArray([3., 4.], dtype=float64),
DeviceArray([2.121, 6.363], dtype=float64))
# vmap the matvec function
= vmap(matvec, in_axes=(None, 0))
vmap_matvec = vmap_matvec(A, x)
Ax Ax
DeviceArray([[ 3. , 4. ],
[ 2.121, 6.363],
[ 0. , 5. ],
[-2.121, 0.707],
[-3. , -4. ],
[-2.121, -6.363],
[ 0. , -5. ],
[ 2.121, -0.707]], dtype=float64)
= vmap_matvec(VT, x)
Vx Vx
DeviceArray([[-7.07106781e-01, -7.07106781e-01],
[-9.99848989e-01, 6.74518966e-18],
[-7.07106781e-01, 7.07106781e-01],
[-6.74518966e-18, 9.99848989e-01],
[ 7.07106781e-01, 7.07106781e-01],
[ 9.99848989e-01, -6.74518966e-18],
[ 7.07106781e-01, -7.07106781e-01],
[ 6.74518966e-18, -9.99848989e-01]], dtype=float64)
= Vx*S
SV SV
DeviceArray([[-4.74341649e+00, -1.58113883e+00],
[-6.70719092e+00, 1.50827026e-17],
[-4.74341649e+00, 1.58113883e+00],
[-4.52481078e-17, 2.23573031e+00],
[ 4.74341649e+00, 1.58113883e+00],
[ 6.70719092e+00, -1.50827026e-17],
[ 4.74341649e+00, -1.58113883e+00],
[ 4.52481078e-17, -2.23573031e+00]], dtype=float64)
vmap_matvec(U, SV)
DeviceArray([[ 3.00000000e+00, 4.00000000e+00],
[ 2.12100000e+00, 6.36300000e+00],
[ 9.53863757e-16, 5.00000000e+00],
[-2.12100000e+00, 7.07000000e-01],
[-3.00000000e+00, -4.00000000e+00],
[-2.12100000e+00, -6.36300000e+00],
[-9.53863757e-16, -5.00000000e+00],
[ 2.12100000e+00, -7.07000000e-01]], dtype=float64)
jnp.allclose(vmap_matvec(U, SV), vmap_matvec(A, x))
DeviceArray(True, dtype=bool)
vmap_matvec(U, SV)
DeviceArray([[ 3.00000000e+00, 4.00000000e+00],
[ 2.12100000e+00, 6.36300000e+00],
[ 9.53863757e-16, 5.00000000e+00],
[-2.12100000e+00, 7.07000000e-01],
[-3.00000000e+00, -4.00000000e+00],
[-2.12100000e+00, -6.36300000e+00],
[-9.53863757e-16, -5.00000000e+00],
[ 2.12100000e+00, -7.07000000e-01]], dtype=float64)
vmap_matvec(A, x)
DeviceArray([[ 3. , 4. ],
[ 2.121, 6.363],
[ 0. , 5. ],
[-2.121, 0.707],
[-3. , -4. ],
[-2.121, -6.363],
[ 0. , -5. ],
[ 2.121, -0.707]], dtype=float64)
# Plot subplots of the above transformations in one figure with arrows showing the direction of transformation.
# We have 2 rows and 2 columns
= 8
fs = plt.subplots(2, 2, figsize=(8, 8))
fig, ax
# Modify the plot function to accept an axis object
def plot(x, ax):
0.0, 0.0, 0.0, 1))
ax.set_facecolor((0], x[:, 1], c=jnp.arange(x.shape[0]), cmap='viridis', s=100)
ax.scatter(x[:, # Add the index of the point as a label
for i in range(x.shape[0]):
0], x[i, 1], str(i), color='white', fontsize=8, ha='center', va='center')
ax.text(x[i, -fs, fs)
ax.set_xlim(-fs, fs)
ax.set_ylim('equal')
ax.set_aspect(
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])'')
ax.set_xlabel('')
ax.set_ylabel(= -1
ax.zorder
# Disable the border axis
'top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.spines[
# Plot the original points
0, 0])
plot(x, ax[
# Plot transformed points by A using JAX.vmap
= vmap_matvec(A, x)
xa 0, 1])
plot(xa, ax[
# Plot rotated points by VT
= vmap_matvec(VT, x)
xv
# Find the angle of rotation by V
= jnp.arctan2(VT[1, 0], VT[0, 0])
angle_v print('Angle of rotation by V: {:.2f} degrees'.format(angle_v*180/jnp.pi))
1, 0])
plot(xv, ax[
= jnp.arctan2(U[1, 0], U[0, 0])
angle_u
# plot the above points scaled by S
= xv*S
xs 1, 1])
plot(xs, ax[
# Add an arrow between [0, 0] and [0, 1] subplots using the matplotlib.patch.ConnectionPatch
# add some text "transformed via A" to the arrow
= ConnectionPatch(xyA=(fs, 0), xyB=(-fs, 0), coordsA="data", coordsB="data", axesA=ax[0, 0], axesB=ax[0, 1],
con ="w", zorder=1, arrowstyle='->', lw=2)
color0, 0].add_artist(con)
ax[0, 0].text(2*fs, 2, 'Transformed via A', color='w', fontsize=10, ha='center', va='center', zorder=1)
ax[
= ConnectionPatch(xyA=(0, -fs), xyB=(0, fs), coordsA="data", coordsB="data", axesA=ax[0, 0], axesB=ax[1, 0],
con1 ="w", zorder=1, arrowstyle='->', lw=2)
color0, 0].add_artist(con1)
ax[0, 0].text(2, -2*fs, r'$V^T$' +f' (Rotation by {angle_v*180/jnp.pi:0.2f})', color='w', fontsize=10, ha='center', va='center', zorder=1, rotation=90)
ax[
= ConnectionPatch(xyA=(fs, 0), xyB=(-fs, 0), coordsA="data", coordsB="data", axesA=ax[1, 0], axesB=ax[1, 1],
con2 ="w", zorder=1, arrowstyle='->', lw=2)
color1, 0].add_artist(con2)
ax[1, 0].text(2*fs, 2, f'Scaled by S\n Horizontally by {S[0]:0.2f}\n Vertically by {S[1]:0.2f}', color='w', fontsize=10, ha='center', va='center', zorder=1)
ax[
= ConnectionPatch(xyA=(0, fs), xyB=(0, -fs), coordsA="data", coordsB="data", axesA=ax[1, 1], axesB=ax[0, 1],
con3 ="w", zorder=1, arrowstyle='->', lw=2)
color1, 1].add_artist(con3)
ax[1, 1].text(2, 2*fs, f'U (Rotation by {angle_u*180/jnp.pi:0.2f})', color='w', fontsize=10, ha='center', va='center', zorder=1, rotation=90)
ax[
# Add a lot of spacing between the subplots
=0.5, wspace=0.5)
fig.subplots_adjust(hspace
'Singular Value Decomposition\n' +r'$A = USV^T$', fontsize=18, color='w')
fig.suptitle(
fig.tight_layout()'svd.png', dpi=600, bbox_inches='tight') fig.savefig(
Angle of rotation by V: -135.00 degrees