# Plot the SVD transformations above as a 2x2 grid of subplots, with arrows
# showing the direction of each transformation:
#   top-left (x) --A--> top-right (A x)
#   top-left (x) --V^T--> bottom-left --S--> bottom-right --U--> top-right
# Half-width of the square viewing window (data units) shared by every panel.
fs = 8
# Explicit black figure background: the title, arrows and arrow labels are drawn
# in white outside the axes, so they would be invisible on a white figure.
fig, ax = plt.subplots(2, 2, figsize=(8, 8), facecolor='black')
def plot(x, ax):
    """Draw 2-D points on ``ax`` as numbered, index-coloured dots on a black panel.

    Args:
        x: array of points, one per row; columns 0 and 1 are used as the
            horizontal and vertical coordinates.
        ax: the matplotlib Axes to draw on (inside this function it shadows the
            module-level grid of axes).

    The view is fixed to [-fs, fs] on both axes (module-level ``fs``) with equal
    aspect, so every panel uses the same scale and panels can be compared directly.
    """
    n_points = x.shape[0]
    ax.set_facecolor((0.0, 0.0, 0.0, 1))
    # Colour by point index so the same point can be followed from panel to panel.
    ax.scatter(x[:, 0], x[:, 1], c=jnp.arange(n_points), cmap='viridis', s=100)
    # Label each dot with its index; convert coordinates once instead of
    # indexing the array element-by-element inside the loop.
    for i, (px, py) in enumerate(x[:, :2].tolist()):
        ax.text(px, py, str(i), color='white', fontsize=8, ha='center', va='center')
    ax.set_xlim(-fs, fs)
    ax.set_ylim(-fs, fs)
    ax.set_aspect('equal')
    # No ticks (an empty tick list also leaves no tick labels) and no axis labels.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    # Lower the axes' draw order (default 0) via the setter so matplotlib is told
    # the artist changed; presumably so figure-level artists such as the arrows
    # between panels are drawn on top — confirm.
    ax.set_zorder(-1)
    # Hide the frame; the black facecolor alone delimits the panel.
    for spine in ax.spines.values():
        spine.set_visible(False)
# Plot the original points
plot(x, ax[0, 0])

# Plot the points transformed by A (matrix-vector product mapped over points with jax.vmap)
xa = vmap_matvec(A, x)
plot(xa, ax[0, 1])

# Plot the points rotated by V^T
xv = vmap_matvec(VT, x)
# Rotation angles (radians) read from each factor's first column [cos θ, sin θ].
# NOTE(review): only a pure rotation angle if VT / U have det = +1; an SVD factor
# can also be a reflection — confirm for this A.
angle_v = jnp.arctan2(VT[1, 0], VT[0, 0])
angle_u = jnp.arctan2(U[1, 0], U[0, 0])
angle_v_deg = float(jnp.degrees(angle_v))
angle_u_deg = float(jnp.degrees(angle_u))
# angle_v is measured on VT, so it is the rotation applied by V^T (not by V).
print('Angle of rotation by V^T: {:.2f} degrees'.format(angle_v_deg))
plot(xv, ax[1, 0])

# Scale the rotated points by S (presumably the vector of two singular values,
# broadcast so column 0 is scaled by S[0] and column 1 by S[1]).
xs = xv * S
plot(xs, ax[1, 1])


def _connect(ax_from, xy_from, ax_to, xy_to):
    """Draw a white arrow from data point xy_from on ax_from to xy_to on ax_to."""
    patch = ConnectionPatch(xyA=xy_from, xyB=xy_to, coordsA="data", coordsB="data",
                            axesA=ax_from, axesB=ax_to,
                            color="w", zorder=1, arrowstyle='->', lw=2)
    # Add to the figure rather than to one axes so the arrow is drawn above both
    # panels it connects (the approach matplotlib documents for cross-axes arrows).
    fig.add_artist(patch)
    return patch


# A: top-left -> top-right, leaving the right edge and entering the left edge.
con = _connect(ax[0, 0], (fs, 0), ax[0, 1], (-fs, 0))
ax[0, 0].text(2*fs, 2, 'Transformed via A', color='w', fontsize=10,
              ha='center', va='center', zorder=1)
# V^T: top-left -> bottom-left.
con1 = _connect(ax[0, 0], (0, -fs), ax[1, 0], (0, fs))
ax[0, 0].text(2, -2*fs, r'$V^T$' + f' (Rotation by {angle_v_deg:0.2f}°)', color='w',
              fontsize=10, ha='center', va='center', zorder=1, rotation=90)
# S: bottom-left -> bottom-right.
con2 = _connect(ax[1, 0], (fs, 0), ax[1, 1], (-fs, 0))
ax[1, 0].text(2*fs, 2,
              f'Scaled by S\n Horizontally by {float(S[0]):0.2f}\n Vertically by {float(S[1]):0.2f}',
              color='w', fontsize=10, ha='center', va='center', zorder=1)
# U: bottom-right -> top-right, completing A = U S V^T.
con3 = _connect(ax[1, 1], (0, fs), ax[0, 1], (0, -fs))
ax[1, 1].text(2, 2*fs, f'U (Rotation by {angle_u_deg:0.2f}°)', color='w',
              fontsize=10, ha='center', va='center', zorder=1, rotation=90)

fig.suptitle('Singular Value Decomposition\n' + r'$A = USV^T$', fontsize=18, color='w')
# Spacing is left to tight_layout(): it recomputes all subplot parameters
# (including hspace/wspace), so a subplots_adjust() call before it has no effect.
# The arrow labels are axes children placed outside the data limits, so
# tight_layout sizes the gaps to fit them.
fig.tight_layout()
fig.savefig('svd.png', dpi=600, bbox_inches='tight')