import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Linear Regression: Geometric Perspective
# Two three-dimensional vectors. They are used below as the columns of the
# least-squares design matrix, so their span is the plane y gets projected onto.
v1 = np.array([1, 1, 1])
v2 = np.array([2, -2, 2])

# y-vector: the target that is approximated by a linear combination of v1, v2.
y = np.array([2.5, -0.8, 1.2])
# Plot v1, v2 and y in 3D as arrows starting at the origin.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
for vec, color, label in ((v1, 'r', 'v1'), (v2, 'b', 'v2'), (y, 'g', 'y')):
    ax.quiver(0, 0, 0, *vec, color=color, label=label)

# The y-axis must extend below zero: v2 (y = -2) and y (y = -0.8) have
# negative y-components, which a (0, 4) y-range would leave outside the box.
ax.set_xlim(0, 4)
ax.set_ylim(-4, 4)
ax.set_zlim(0, 4)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.legend()
ax.view_init(elev=45, azim=60)
# Least-squares fit: theta minimises ||A @ theta - y||, where the design
# matrix A has v1 and v2 as its columns (shape 3 x 2).
theta = np.linalg.lstsq(np.column_stack((v1, v2)), y, rcond=None)[0]
theta
# Output: array([0.525 , 0.6625])
# Projection of y onto the plane spanned by v1 and v2: the fitted values
# A @ theta, with v1 and v2 as the columns of A.
y_proj = np.column_stack((v1, v2)) @ theta
y_proj
# Output: array([ 1.85, -0.8 , 1.85])
# The residual y - y_proj = [0.65, 0, -0.65] has zero dot product with both
# v1 and v2, i.e. it is orthogonal to the plane, so y_proj is the point of
# that plane closest to y.
# Plot the plane spanned by v1 and v2 (shaded black) together with v1, v2, y
# and the projection of y onto that plane.
#
# Any combination a*v1 + b*v2 = (a + 2b, a - 2b, a + 2b) has equal x and z
# components, so span{v1, v2} is the plane x = z, with y free.
#
# plt.figure + add_subplot creates only the 3D axes; plt.subplots would also
# create a 2D axes that remains visible behind the 3D one.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

# Axis ranges, reused for the plane mesh so the surface fills the axes box.
# y extends below zero because v2 and y have negative y-components.
x_range = (0, 4)
y_range = (-4, 4)
z_range = x_range  # on the plane z = x, z spans the same range as x

xx, yy = np.meshgrid(np.linspace(*x_range, 100), np.linspace(*y_range, 100))
zz = xx  # points (x, y, x) lie on the plane x = z
ax.plot_surface(xx, yy, zz, alpha=0.2, color='k')

# Plot the vectors in 3D as arrows from the origin.
for vec, color, label in ((v1, 'r', 'v1'), (v2, 'b', 'v2'), (y, 'g', 'y')):
    ax.quiver(0, 0, 0, *vec, color=color, label=label)

# Projection of y onto the plane spanned by v1 and v2. The gap between the
# tips of y and y_proj is the residual, which is perpendicular to the plane.
ax.quiver(0, 0, 0, *y_proj, color='k',
          label='Projection of y onto\n the plane spanned by v1 and v2')

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_xlim(*x_range)
ax.set_ylim(*y_range)
ax.set_zlim(*z_range)

# Camera angle kept from the original notebook (an earlier view_init call it
# fully overrode, and commented-out alternatives, were removed).
# NOTE(review): this angle was tuned while the y = 0 plane was being drawn;
# re-check visually that the x = z plane and the residual read clearly.
ax.view_init(elev=120, azim=-120, roll=-120)

ax.legend()
<matplotlib.legend.Legend at 0x111d52730>