import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
Movie Recommendation using KNN and Matrix Factorization
# Generate some toy user and movie data
# Number of users
n_users = 100
# Number of movies
n_movies = 10
# Number of ratings
n_ratings = 1000

# Generate random user ids
user_ids = np.random.randint(0, n_users, n_ratings)
# Generate random movie ids
movie_ids = np.random.randint(0, n_movies, n_ratings)
# Generate random ratings
ratings = np.random.randint(1, 6, n_ratings)

# Create a dataframe with the data
df = pd.DataFrame({'user_id': user_ids, 'movie_id': movie_ids, 'rating': ratings})

# We should not have any duplicate ratings for the same user and movie
# Drop any rows that have duplicate user_id and movie_id pairs
df = df.drop_duplicates(['user_id', 'movie_id'])
df
 | user_id | movie_id | rating |
---|---|---|---|
0 | 66 | 7 | 2 |
1 | 54 | 7 | 4 |
2 | 38 | 5 | 3 |
3 | 56 | 6 | 1 |
4 | 4 | 0 | 4 |
... | ... | ... | ... |
987 | 77 | 8 | 3 |
992 | 99 | 3 | 3 |
994 | 8 | 5 | 3 |
998 | 22 | 2 | 3 |
999 | 88 | 9 | 1 |
642 rows × 3 columns
# Create a user-item matrix
A = df.pivot(index='user_id', columns='movie_id', values='rating')
A
movie_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
user_id | ||||||||||
0 | 3.0 | 4.0 | NaN | 5.0 | 5.0 | 1.0 | 2.0 | 5.0 | 2.0 | 4.0 |
1 | 4.0 | 2.0 | NaN | 1.0 | NaN | 3.0 | 3.0 | 5.0 | 3.0 | 1.0 |
2 | 3.0 | NaN | NaN | 4.0 | 1.0 | 5.0 | NaN | 2.0 | NaN | 2.0 |
3 | 4.0 | NaN | 4.0 | 2.0 | NaN | 3.0 | NaN | 2.0 | NaN | 4.0 |
4 | 4.0 | 3.0 | 3.0 | 3.0 | NaN | NaN | 4.0 | NaN | NaN | NaN |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
95 | 1.0 | 3.0 | 1.0 | 2.0 | 3.0 | NaN | 1.0 | NaN | 5.0 | NaN |
96 | 1.0 | 2.0 | NaN | NaN | NaN | NaN | 3.0 | 2.0 | 2.0 | 5.0 |
97 | 3.0 | 1.0 | 4.0 | NaN | 3.0 | 1.0 | NaN | NaN | 5.0 | NaN |
98 | 5.0 | 3.0 | NaN | NaN | 2.0 | NaN | 1.0 | 3.0 | 4.0 | 3.0 |
99 | 1.0 | 5.0 | 2.0 | 3.0 | NaN | 2.0 | 4.0 | 3.0 | 3.0 | NaN |
100 rows × 10 columns
# Fill in the missing values with zeros
A = A.fillna(0)
A
movie_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
user_id | ||||||||||
0 | 3.0 | 4.0 | 0.0 | 5.0 | 5.0 | 1.0 | 2.0 | 5.0 | 2.0 | 4.0 |
1 | 4.0 | 2.0 | 0.0 | 1.0 | 0.0 | 3.0 | 3.0 | 5.0 | 3.0 | 1.0 |
2 | 3.0 | 0.0 | 0.0 | 4.0 | 1.0 | 5.0 | 0.0 | 2.0 | 0.0 | 2.0 |
3 | 4.0 | 0.0 | 4.0 | 2.0 | 0.0 | 3.0 | 0.0 | 2.0 | 0.0 | 4.0 |
4 | 4.0 | 3.0 | 3.0 | 3.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
95 | 1.0 | 3.0 | 1.0 | 2.0 | 3.0 | 0.0 | 1.0 | 0.0 | 5.0 | 0.0 |
96 | 1.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 2.0 | 2.0 | 5.0 |
97 | 3.0 | 1.0 | 4.0 | 0.0 | 3.0 | 1.0 | 0.0 | 0.0 | 5.0 | 0.0 |
98 | 5.0 | 3.0 | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 3.0 | 4.0 | 3.0 |
99 | 1.0 | 5.0 | 2.0 | 3.0 | 0.0 | 2.0 | 4.0 | 3.0 | 3.0 | 0.0 |
100 rows × 10 columns
# Cosine similarity between user 0 and user 1
# User 1
u1 = A.loc[0]
# User 2
u2 = A.loc[1]

# Compute the dot product
dot = np.dot(u1, u2)

# Compute the L2 norms
norm_u1 = np.linalg.norm(u1)
norm_u2 = np.linalg.norm(u2)

# Compute the cosine similarity
cos_sim = dot / (norm_u1 * norm_u2)
cos_sim
0.7174278379758501
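As a quick sanity check (a minimal sketch, assuming SciPy is available), the same number can be recovered from scipy.spatial.distance.cosine, which returns the cosine distance, i.e. one minus the similarity:

# Sanity check: cosine similarity = 1 - cosine distance
from scipy.spatial.distance import cosine

assert np.isclose(1 - cosine(u1, u2), cos_sim)
1 - cosine(u1, u2)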
# Calculate the cosine similarity between all pairs of users
from sklearn.metrics.pairwise import cosine_similarity

sim_matrix = cosine_similarity(A)
pd.DataFrame(sim_matrix)
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.000000 | 0.717428 | 0.663734 | 0.565794 | 0.547289 | 0.742468 | 0.539472 | 0.599021 | 0.748964 | 0.523832 | ... | 0.876593 | 0.593722 | 0.615664 | 0.722496 | 0.783190 | 0.657754 | 0.665375 | 0.446627 | 0.774667 | 0.703313 |
1 | 0.717428 | 1.000000 | 0.650769 | 0.591169 | 0.559964 | 0.585314 | 0.300491 | 0.255039 | 0.693395 | 0.500193 | ... | 0.609392 | 0.313893 | 0.510454 | 0.550459 | 0.624622 | 0.493197 | 0.644346 | 0.476288 | 0.802740 | 0.781611 |
2 | 0.663734 | 0.650769 | 1.000000 | 0.758954 | 0.406780 | 0.372046 | 0.416654 | 0.270593 | 0.746685 | 0.560180 | ... | 0.519274 | 0.322243 | 0.772529 | 0.253842 | 0.712485 | 0.257761 | 0.322830 | 0.283373 | 0.441886 | 0.459929 |
3 | 0.565794 | 0.591169 | 0.758954 | 1.000000 | 0.549030 | 0.540128 | 0.671775 | 0.572892 | 0.611794 | 0.489225 | ... | 0.424052 | 0.586110 | 0.456327 | 0.506719 | 0.715831 | 0.210494 | 0.506585 | 0.492312 | 0.551652 | 0.424052 |
4 | 0.547289 | 0.559964 | 0.406780 | 0.549030 | 1.000000 | 0.354329 | 0.304478 | 0.541185 | 0.821353 | 0.202287 | ... | 0.667638 | 0.527306 | 0.339913 | 0.435159 | 0.582943 | 0.478699 | 0.417780 | 0.450063 | 0.502836 | 0.741820 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
95 | 0.657754 | 0.493197 | 0.257761 | 0.210494 | 0.478699 | 0.519615 | 0.452602 | 0.571548 | 0.551553 | 0.405674 | ... | 0.564076 | 0.556890 | 0.604211 | 0.682793 | 0.661382 | 1.000000 | 0.412568 | 0.796715 | 0.678637 | 0.693008 |
96 | 0.665375 | 0.644346 | 0.322830 | 0.506585 | 0.417780 | 0.853538 | 0.359095 | 0.538977 | 0.351369 | 0.191776 | ... | 0.498686 | 0.640033 | 0.190421 | 0.690704 | 0.682163 | 0.412568 | 1.000000 | 0.280141 | 0.734105 | 0.581800 |
97 | 0.446627 | 0.476288 | 0.283373 | 0.492312 | 0.450063 | 0.400743 | 0.709211 | 0.532239 | 0.469979 | 0.566223 | ... | 0.262641 | 0.446564 | 0.501441 | 0.689500 | 0.637007 | 0.796715 | 0.280141 | 1.000000 | 0.659366 | 0.481508 |
98 | 0.774667 | 0.802740 | 0.441886 | 0.551652 | 0.502836 | 0.844146 | 0.446610 | 0.473016 | 0.617575 | 0.335738 | ... | 0.546861 | 0.500390 | 0.305585 | 0.673754 | 0.687116 | 0.678637 | 0.734105 | 0.659366 | 1.000000 | 0.600213 |
99 | 0.703313 | 0.781611 | 0.459929 | 0.424052 | 0.741820 | 0.527274 | 0.294579 | 0.578997 | 0.732042 | 0.435869 | ... | 0.818182 | 0.615435 | 0.581559 | 0.645439 | 0.623673 | 0.693008 | 0.581800 | 0.481508 | 0.600213 | 1.000000 |
100 rows × 100 columns
import seaborn as sns
sns.heatmap(sim_matrix, cmap='Greys')
<AxesSubplot:>
# Find the most similar users to user u
def k_nearest_neighbors(A, u, k):
    """Find the k nearest neighbors for user u"""
    # Find the index of the user in the matrix
    u_index = A.index.get_loc(u)
    # Compute the similarity between the user and all other users
    sim_matrix = cosine_similarity(A)
    # Find the k most similar users (excluding the user itself)
    k_nearest = np.argsort(sim_matrix[u_index])[::-1][1:k+1]
    # Return the user ids
    return A.index[k_nearest]

k_nearest_neighbors(A, 0, 5)
Int64Index([28, 46, 90, 32, 87], dtype='int64', name='user_id')
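Note that k_nearest_neighbors recomputes the full similarity matrix on every call. For repeated queries it can be precomputed once; a small optional sketch (the helper name and default argument are ours, not part of the notebook):

# Optional: precompute the similarity matrix once and reuse it for every query.
precomputed_sim = cosine_similarity(A)

def k_nearest_neighbors_fast(A, u, k, sim=precomputed_sim):
    """Same as k_nearest_neighbors, but reuses a precomputed similarity matrix."""
    u_index = A.index.get_loc(u)
    k_nearest = np.argsort(sim[u_index])[::-1][1:k+1]
    return A.index[k_nearest]

k_nearest_neighbors_fast(A, 0, 5)  # should match k_nearest_neighbors(A, 0, 5)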
# Show matrix of movie ratings for u and k nearest neighbors
def show_neighbors(A, u, k):
    """Show the movie ratings for user u and its k nearest neighbors"""
    # Get the user ids of the k nearest neighbors
    neighbors = k_nearest_neighbors(A, u, k)
    # Get the movie ratings for user u and the k nearest neighbors
    df = A.loc[[u] + list(neighbors)]
    # Return the dataframe
    return df

show_neighbors(A, 0, 5)
movie_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
user_id | ||||||||||
0 | 3.0 | 4.0 | 0.0 | 5.0 | 5.0 | 1.0 | 2.0 | 5.0 | 2.0 | 4.0 |
28 | 5.0 | 0.0 | 0.0 | 5.0 | 4.0 | 2.0 | 1.0 | 4.0 | 0.0 | 5.0 |
46 | 3.0 | 0.0 | 2.0 | 5.0 | 5.0 | 1.0 | 0.0 | 4.0 | 2.0 | 2.0 |
90 | 1.0 | 5.0 | 1.0 | 5.0 | 2.0 | 0.0 | 2.0 | 4.0 | 0.0 | 1.0 |
32 | 3.0 | 2.0 | 2.0 | 5.0 | 4.0 | 5.0 | 3.0 | 5.0 | 0.0 | 3.0 |
87 | 1.0 | 0.0 | 0.0 | 5.0 | 4.0 | 0.0 | 3.0 | 4.0 | 4.0 | 2.0 |
# Predict the rating of user u for movie m as the mean of the neighbors' nonzero ratings
# (e.g., if two neighbors rated it 4.0 and 3.0, the prediction is (4.0 + 3.0) / 2 = 3.5; zeros are discarded)
def predict_rating(A, u, m, k=5):
    """Predict the rating for user u for movie m"""
    # Get the user ids of the k nearest neighbors
    neighbors = k_nearest_neighbors(A, u, k)
    # Get the movie ratings for user u and the k nearest neighbors
    df = A.loc[[u] + list(neighbors)]
    # Get the ratings for movie m
    ratings = df[m]
    # Calculate the mean of the neighbors' nonzero ratings (skip user u's own row)
    mean = ratings[1:][ratings != 0].mean()
    # Return the mean
    return mean

predict_rating(A, 0, 5)
2.6666666666666665
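To turn this into a recommendation, we can predict a rating for every movie the user has not rated and pick the highest. A small sketch (the recommend helper is ours, not in the original notebook; it assumes 0 means "missing" in the zero-filled matrix A):

# Recommend the unrated movies with the highest predicted rating for a given user.
def recommend(A, u, k=5):
    # Movies the user has not rated (0 means "missing" in the zero-filled matrix A)
    unrated = A.columns[A.loc[u] == 0]
    # Predict a rating for each unrated movie
    preds = {m: predict_rating(A, u, m, k) for m in unrated}
    # Return predictions sorted from best to worst
    return pd.Series(preds).sort_values(ascending=False)

recommend(A, 0)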
# Now working with real data
# Load the data
df = pd.read_excel("mov-rec.xlsx")
df.head()
 | Timestamp | Your name | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2023-04-11 10:58:44.990 | Nipun | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 5.0 | 4.0 | 5.0 |
1 | 2023-04-11 10:59:49.617 | Gautam Vashishtha | 3.0 | 4.0 | 4.0 | 5.0 | 3.0 | 1.0 | 5.0 | 5.0 | 4.0 | 3.0 |
2 | 2023-04-11 11:12:44.033 | Eshan Gujarathi | 4.0 | NaN | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 |
3 | 2023-04-11 11:13:48.674 | Sai Krishna Avula | 5.0 | 3.0 | 3.0 | 4.0 | 4.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
4 | 2023-04-11 11:13:55.658 | Ankit Yadav | 3.0 | 3.0 | 2.0 | 5.0 | 2.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
# Discard the timestamp column
df = df.drop('Timestamp', axis=1)

# Make the "Your name" column the index
df = df.set_index('Your name')
df
 | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
---|---|---|---|---|---|---|---|---|---|---|
Your name | ||||||||||
Nipun | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 5.0 | 4.0 | 5.0 |
Gautam Vashishtha | 3.0 | 4.0 | 4.0 | 5.0 | 3.0 | 1.0 | 5.0 | 5.0 | 4.0 | 3.0 |
Eshan Gujarathi | 4.0 | NaN | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 |
Sai Krishna Avula | 5.0 | 3.0 | 3.0 | 4.0 | 4.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
Ankit Yadav | 3.0 | 3.0 | 2.0 | 5.0 | 2.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
Dhruv | NaN | NaN | 5.0 | 5.0 | 3.0 | NaN | 5.0 | 5.0 | 4.0 | 5.0 |
Saatvik Rao | 4.0 | 3.0 | 4.0 | 5.0 | 2.0 | 2.0 | 4.0 | 5.0 | 3.0 | 5.0 |
Zeel B Patel | 5.0 | 4.0 | 5.0 | 4.0 | 4.0 | 4.0 | NaN | 2.0 | NaN | 5.0 |
Neel | 4.0 | NaN | 5.0 | 5.0 | 3.0 | 3.0 | 5.0 | 5.0 | NaN | 4.0 |
Sachin Jalan | 4.0 | NaN | 5.0 | 5.0 | 3.0 | 4.0 | 4.0 | 5.0 | NaN | 3.0 |
Ayush Shrivastava | 5.0 | 4.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 | 4.0 | NaN | 4.0 |
.... | 4.0 | 4.0 | NaN | 4.0 | 4.0 | 4.0 | NaN | 4.0 | NaN | 4.0 |
Hari Hara Sudhan | 4.0 | 3.0 | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 5.0 | 3.0 | 5.0 |
Etikikota Hrushikesh | NaN | NaN | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | NaN | NaN |
Chirag | 5.0 | 3.0 | 4.0 | 5.0 | 5.0 | 2.0 | 3.0 | 4.0 | 2.0 | 5.0 |
Aaryan Darad | 5.0 | 4.0 | 4.0 | 5.0 | 3.0 | 4.0 | 3.0 | 3.0 | NaN | 4.0 |
Hetvi Patel | 4.0 | 3.0 | 2.0 | 5.0 | 4.0 | 4.0 | 5.0 | 3.0 | 3.0 | 5.0 |
Kalash Kankaria | 4.0 | NaN | 4.0 | 5.0 | 3.0 | 4.0 | NaN | NaN | NaN | 3.0 |
Rachit Verma | NaN | NaN | 4.0 | 5.0 | 3.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 |
shriraj | 3.0 | 2.0 | 5.0 | 4.0 | 2.0 | 3.0 | 4.0 | 5.0 | 4.0 | 5.0 |
Bhavini Korthi | NaN | NaN | NaN | NaN | 4.0 | 5.0 | NaN | NaN | NaN | 5.0 |
Hitarth Gandhi | 3.0 | NaN | 4.0 | 5.0 | 3.0 | 4.0 | 5.0 | 5.0 | NaN | NaN |
Radhika Joglekar | 3.0 | 3.0 | 3.0 | 4.0 | 5.0 | 5.0 | 2.0 | 1.0 | 2.0 | 5.0 |
Medhansh Singh | 4.0 | 3.0 | 5.0 | 5.0 | 3.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 |
Arun Mani | NaN | NaN | 4.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | 4.0 | NaN |
Satyam | 3.0 | 5.0 | 5.0 | 5.0 | 4.0 | 3.0 | 5.0 | 5.0 | NaN | 5.0 |
Karan Kumar | 4.0 | 3.0 | 5.0 | 4.0 | 5.0 | 5.0 | 3.0 | 5.0 | 5.0 | 4.0 |
R Yeeshu Dhurandhar | 5.0 | NaN | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | NaN | NaN |
Satyam Gupta | 5.0 | 5.0 | NaN | 5.0 | 4.0 | 4.0 | NaN | 4.0 | NaN | 2.0 |
rushali | NaN | NaN | NaN | 5.0 | 4.0 | 3.0 | NaN | NaN | NaN | NaN |
shridhar | 5.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 3.0 | 3.0 |
Tanvi Jain | 4.0 | 3.0 | NaN | NaN | 4.0 | 5.0 | NaN | NaN | NaN | 5.0 |
Manish Prabhubhai Salvi | 4.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 | NaN | 5.0 |
Varun Barala | 5.0 | 5.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 3.0 | 4.0 |
Kevin Shah | 3.0 | 4.0 | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 4.0 | 3.0 | 5.0 |
Inderjeet | 4.0 | NaN | 4.0 | 5.0 | 4.0 | 3.0 | 5.0 | 5.0 | NaN | 3.0 |
Gangaram Siddam | 4.0 | 4.0 | 3.0 | 3.0 | 5.0 | 5.0 | 4.0 | 4.0 | 3.0 | 5.0 |
Aditi | 4.0 | 4.0 | NaN | 5.0 | 1.0 | 3.0 | 5.0 | 4.0 | NaN | 4.0 |
Madhuri Awachar | 5.0 | 4.0 | 5.0 | 4.0 | 5.0 | 3.0 | 5.0 | 5.0 | 4.0 | 5.0 |
Anupam | 5.0 | 5.0 | NaN | 5.0 | 5.0 | 5.0 | NaN | NaN | NaN | 5.0 |
Jinay | 3.0 | 1.0 | 4.0 | 3.0 | 4.0 | 3.0 | 5.0 | 5.0 | 4.0 | 3.0 |
Shrutimoy | 5.0 | 5.0 | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 2.0 |
Aadesh Desai | 4.0 | 4.0 | 3.0 | 5.0 | 3.0 | 3.0 | 5.0 | 5.0 | 4.0 | 5.0 |
Dhairya | 5.0 | 4.0 | 4.0 | 5.0 | 3.0 | 5.0 | NaN | 4.0 | NaN | 4.0 |
Rahul C | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 4.0 | 4.0 | 5.0 | NaN | NaN |
df.index
Index(['Nipun', 'Gautam Vashishtha', 'Eshan Gujarathi', 'Sai Krishna Avula',
'Ankit Yadav ', 'Dhruv', 'Saatvik Rao ', 'Zeel B Patel', 'Neel ',
'Sachin Jalan ', 'Ayush Shrivastava', '....', 'Hari Hara Sudhan',
'Etikikota Hrushikesh', 'Chirag', 'Aaryan Darad', 'Hetvi Patel',
'Kalash Kankaria', 'Rachit Verma', 'shriraj', 'Bhavini Korthi ',
'Hitarth Gandhi ', 'Radhika Joglekar ', 'Medhansh Singh', 'Arun Mani',
'Satyam ', 'Karan Kumar ', 'R Yeeshu Dhurandhar', 'Satyam Gupta',
'rushali', 'shridhar', 'Tanvi Jain ', 'Manish Prabhubhai Salvi ',
'Varun Barala', 'Kevin Shah ', 'Inderjeet', 'Gangaram Siddam ', 'Aditi',
'Madhuri Awachar', 'Anupam', 'Jinay', 'Shrutimoy', 'Aadesh Desai',
'Dhairya', 'Rahul C'],
dtype='object', name='Your name')
# Pick a user and check that the name exists in the index
user = 'Rahul C'
print(user in df.index)

# Get the movie ratings for the user
user_ratings = df.loc[user]
user_ratings
True
Sholay 3.0
Swades (We The People) 3.0
The Matrix (I) 4.0
Interstellar 4.0
Dangal 4.0
Taare Zameen Par 4.0
Shawshank Redemption 4.0
The Dark Knight 5.0
Notting Hill NaN
Uri: The Surgical Strike NaN
Name: Rahul C, dtype: float64
df_copy = df.copy()
df_copy.fillna(0, inplace=True)
show_neighbors(df_copy, user, 5)
 | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
---|---|---|---|---|---|---|---|---|---|---|
Your name | ||||||||||
Rahul C | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 4.0 | 4.0 | 5.0 | 0.0 | 0.0 |
Shrutimoy | 5.0 | 5.0 | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | 0.0 | 2.0 |
Hitarth Gandhi | 3.0 | 0.0 | 4.0 | 5.0 | 3.0 | 4.0 | 5.0 | 5.0 | 0.0 | 0.0 |
R Yeeshu Dhurandhar | 5.0 | 0.0 | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 0.0 | 0.0 |
shridhar | 5.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 3.0 | 3.0 |
Sachin Jalan | 4.0 | 0.0 | 5.0 | 5.0 | 3.0 | 4.0 | 4.0 | 5.0 | 0.0 | 3.0 |
df.describe()
 | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
---|---|---|---|---|---|---|---|---|---|---|
count | 39.000000 | 32.000000 | 38.000000 | 43.000000 | 45.000000 | 44.000000 | 35.000000 | 40.000000 | 21.000000 | 39.000000 |
mean | 4.102564 | 3.718750 | 4.131579 | 4.581395 | 3.644444 | 3.977273 | 4.400000 | 4.250000 | 3.476190 | 4.230769 |
std | 0.753758 | 0.958304 | 0.991070 | 0.793802 | 1.003529 | 1.067242 | 0.976187 | 1.080123 | 0.813575 | 0.902089 |
min | 3.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 2.000000 | 2.000000 |
25% | 4.000000 | 3.000000 | 4.000000 | 4.000000 | 3.000000 | 3.000000 | 4.000000 | 4.000000 | 3.000000 | 4.000000 |
50% | 4.000000 | 4.000000 | 4.000000 | 5.000000 | 4.000000 | 4.000000 | 5.000000 | 5.000000 | 3.000000 | 4.000000 |
75% | 5.000000 | 4.000000 | 5.000000 | 5.000000 | 4.000000 | 5.000000 | 5.000000 | 5.000000 | 4.000000 | 5.000000 |
max | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 |
# Predict the rating for user u for movie m
predict_rating(df_copy, user, 'The Dark Knight')
4.8
predict_rating(df_copy, user, 'Sholay')
4.4
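The same helper can also fill in the two movies this user has not rated; a small sketch using the zero-filled copy of the real data:

# Predict ratings for the user's actual missing entries
for m in df.columns[df.loc[user].isna()]:
    print(m, predict_rating(df_copy, user, m))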
# Generic Matrix Factorization (without missing values)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# A is a matrix of size (n_users, n_movies) with randomly generated ratings between 1 and 5
A = torch.randint(1, 6, (n_users, n_movies), dtype=torch.float)
A
tensor([[1., 5., 4., 3., 4., 3., 1., 3., 3., 4.],
[3., 3., 4., 4., 4., 2., 5., 2., 2., 4.],
[2., 5., 5., 3., 3., 1., 4., 5., 2., 5.],
[4., 2., 1., 5., 1., 5., 1., 1., 5., 1.],
[4., 3., 5., 3., 4., 1., 2., 4., 1., 5.],
[1., 2., 2., 1., 2., 4., 1., 1., 3., 3.],
[4., 2., 2., 3., 5., 3., 3., 2., 5., 2.],
[2., 2., 1., 1., 2., 4., 5., 1., 4., 3.],
[1., 3., 2., 1., 3., 4., 1., 5., 5., 1.],
[4., 4., 2., 4., 1., 4., 2., 1., 4., 4.],
[1., 3., 1., 2., 3., 1., 3., 2., 4., 4.],
[5., 3., 5., 1., 4., 5., 4., 2., 1., 5.],
[5., 1., 5., 3., 4., 4., 4., 2., 2., 2.],
[1., 1., 2., 3., 5., 4., 3., 1., 1., 1.],
[3., 2., 5., 4., 3., 5., 2., 4., 4., 3.],
[2., 2., 4., 1., 2., 3., 2., 4., 3., 5.],
[4., 2., 4., 5., 2., 2., 1., 2., 1., 1.],
[4., 3., 2., 1., 1., 1., 5., 2., 1., 1.],
[2., 1., 2., 1., 1., 1., 5., 5., 1., 4.],
[1., 3., 5., 1., 5., 5., 5., 5., 2., 2.],
[1., 2., 2., 5., 1., 1., 5., 4., 2., 5.],
[5., 1., 2., 2., 1., 5., 2., 3., 5., 2.],
[2., 2., 1., 5., 5., 5., 5., 3., 4., 3.],
[1., 1., 4., 1., 2., 5., 5., 5., 2., 1.],
[5., 3., 3., 1., 2., 3., 3., 2., 2., 3.],
[2., 4., 5., 1., 2., 3., 5., 3., 1., 5.],
[2., 1., 1., 1., 2., 2., 5., 5., 2., 1.],
[3., 1., 1., 3., 3., 3., 2., 1., 3., 2.],
[4., 4., 4., 5., 1., 3., 4., 3., 2., 2.],
[4., 4., 4., 5., 2., 5., 2., 1., 4., 3.],
[3., 3., 5., 1., 4., 3., 3., 3., 3., 2.],
[2., 4., 2., 3., 4., 2., 4., 3., 3., 3.],
[5., 4., 1., 5., 4., 3., 5., 1., 3., 4.],
[3., 5., 3., 2., 4., 2., 5., 1., 2., 5.],
[2., 5., 4., 1., 1., 5., 1., 5., 2., 4.],
[1., 3., 3., 4., 4., 5., 2., 5., 3., 5.],
[3., 3., 5., 3., 4., 1., 3., 1., 1., 3.],
[2., 4., 1., 3., 5., 1., 5., 2., 4., 1.],
[5., 2., 2., 3., 1., 4., 5., 5., 4., 2.],
[5., 3., 5., 1., 5., 4., 3., 1., 1., 3.],
[4., 2., 2., 2., 2., 1., 2., 1., 1., 3.],
[2., 2., 1., 4., 1., 4., 5., 2., 5., 1.],
[3., 4., 2., 1., 3., 1., 2., 5., 3., 5.],
[1., 3., 3., 5., 3., 2., 1., 2., 5., 1.],
[3., 4., 2., 4., 2., 3., 4., 1., 1., 1.],
[1., 1., 5., 1., 3., 2., 5., 5., 5., 4.],
[4., 4., 4., 4., 4., 3., 1., 4., 1., 1.],
[3., 4., 5., 4., 1., 5., 2., 3., 1., 3.],
[5., 2., 5., 5., 2., 4., 5., 4., 4., 5.],
[3., 5., 5., 4., 1., 5., 1., 2., 5., 1.],
[2., 4., 3., 5., 4., 5., 2., 5., 3., 3.],
[3., 2., 3., 1., 1., 4., 1., 1., 2., 1.],
[4., 1., 2., 4., 4., 3., 2., 2., 2., 4.],
[2., 2., 3., 2., 3., 2., 2., 5., 4., 3.],
[4., 2., 3., 4., 4., 1., 1., 4., 2., 3.],
[2., 2., 3., 2., 2., 1., 4., 3., 5., 4.],
[5., 5., 5., 4., 2., 1., 3., 1., 3., 3.],
[2., 4., 5., 1., 2., 2., 5., 3., 3., 1.],
[1., 4., 3., 2., 3., 5., 3., 2., 4., 4.],
[1., 3., 2., 1., 2., 3., 2., 2., 5., 2.],
[4., 5., 5., 3., 1., 2., 1., 3., 4., 4.],
[4., 2., 3., 2., 3., 4., 5., 1., 4., 3.],
[3., 2., 1., 3., 2., 2., 5., 5., 5., 1.],
[2., 5., 5., 2., 3., 5., 3., 4., 3., 2.],
[1., 2., 2., 4., 4., 5., 5., 2., 5., 1.],
[4., 4., 1., 5., 2., 4., 4., 2., 4., 2.],
[4., 1., 5., 3., 1., 4., 5., 2., 2., 2.],
[3., 2., 2., 1., 1., 1., 2., 4., 2., 2.],
[1., 1., 3., 2., 2., 4., 3., 3., 2., 2.],
[2., 4., 3., 1., 1., 2., 4., 2., 5., 2.],
[1., 4., 3., 5., 4., 2., 4., 2., 2., 5.],
[4., 4., 5., 5., 4., 3., 4., 1., 4., 5.],
[5., 2., 2., 1., 4., 4., 3., 5., 5., 5.],
[3., 1., 3., 1., 2., 5., 5., 4., 3., 1.],
[5., 5., 5., 3., 5., 3., 3., 3., 3., 3.],
[5., 3., 3., 3., 4., 3., 5., 4., 5., 1.],
[5., 5., 3., 5., 4., 1., 2., 1., 3., 2.],
[5., 3., 4., 1., 5., 2., 4., 4., 2., 1.],
[1., 2., 1., 2., 1., 3., 4., 1., 2., 1.],
[5., 3., 3., 1., 5., 1., 4., 5., 5., 1.],
[2., 2., 5., 2., 3., 5., 1., 5., 4., 2.],
[3., 3., 1., 2., 5., 5., 4., 4., 4., 5.],
[5., 3., 5., 3., 5., 5., 3., 4., 1., 4.],
[5., 3., 2., 5., 1., 2., 4., 3., 1., 4.],
[5., 2., 5., 2., 4., 3., 2., 2., 4., 5.],
[2., 3., 5., 2., 4., 5., 3., 1., 4., 3.],
[3., 3., 1., 4., 5., 2., 1., 3., 4., 4.],
[5., 2., 3., 4., 2., 2., 5., 4., 2., 1.],
[3., 1., 1., 2., 5., 1., 5., 1., 5., 3.],
[3., 5., 1., 5., 3., 4., 5., 5., 4., 3.],
[2., 4., 1., 5., 4., 1., 4., 5., 1., 4.],
[5., 5., 4., 1., 2., 2., 2., 5., 1., 3.],
[5., 2., 3., 3., 4., 3., 3., 4., 3., 2.],
[1., 5., 4., 4., 5., 3., 4., 2., 4., 1.],
[2., 4., 1., 1., 4., 2., 1., 1., 3., 4.],
[4., 4., 2., 3., 2., 2., 3., 5., 4., 1.],
[5., 4., 3., 4., 5., 4., 5., 1., 5., 5.],
[1., 1., 1., 1., 4., 4., 5., 2., 4., 2.],
[2., 4., 1., 2., 4., 3., 5., 1., 4., 4.],
[2., 5., 2., 1., 3., 5., 5., 4., 1., 4.]])
A.shape
torch.Size([100, 10])
Let us decompose A as the product WH, where W is of shape (n_users, k) and H is of shape (k, n_movies) for some small number of latent factors k, so that A ≈ WH.
# Randomly initialize W and H
W = torch.randn(n_users, 2, requires_grad=True)
H = torch.randn(2, n_movies, requires_grad=True)

# Compute the loss
loss = torch.norm(torch.mm(W, H) - A)
loss
tensor(110.7991, grad_fn=<LinalgVectorNormBackward0>)
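Here torch.norm of the residual matrix is its Frobenius norm, i.e. the square root of the sum of squared errors over all entries. A one-line check (a sketch):

# The loss above equals sqrt of the sum of squared entries of the residual
residual = torch.mm(W, H) - A
assert torch.isclose(torch.norm(residual), torch.sqrt((residual ** 2).sum()))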
pd.DataFrame(torch.mm(W, H).detach().numpy())
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
0 | -1.733831 | 2.962563 | -0.009936 | -0.591927 | 2.442282 | -0.533001 | -0.500535 | -0.777075 | -0.427938 | -0.050505 |
1 | -1.605388 | 2.875087 | 0.171515 | -0.770741 | 2.594427 | -0.652586 | -0.512620 | -0.953935 | -0.329276 | -0.023146 |
2 | 0.159360 | -0.289060 | -0.022038 | 0.082685 | -0.266777 | 0.069192 | 0.052250 | 0.101196 | 0.030828 | 0.001642 |
3 | -3.637741 | 4.346515 | -2.580031 | 1.911339 | 0.407374 | 1.134387 | -0.353923 | 1.689442 | -1.846117 | -0.440421 |
4 | 1.706123 | -2.207651 | 0.978520 | -0.611154 | -0.617788 | -0.328235 | 0.228982 | -0.492014 | 0.780051 | 0.176302 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
95 | 1.362280 | -2.751892 | -0.572962 | 1.180665 | -2.989312 | 0.929992 | 0.551276 | 1.363935 | 0.121039 | -0.036218 |
96 | 1.432943 | -1.514237 | 1.287248 | -1.086741 | 0.338909 | -0.685343 | 0.065700 | -1.016967 | 0.827599 | 0.208897 |
97 | 0.381086 | 0.345669 | 1.366953 | -1.551476 | 1.978570 | -1.084161 | -0.261282 | -1.599606 | 0.599750 | 0.189461 |
98 | -3.364117 | 6.450801 | 0.942667 | -2.333749 | 6.511645 | -1.880903 | -1.232884 | -2.755594 | -0.473888 | 0.027722 |
99 | 0.877697 | -1.631227 | -0.175042 | 0.521516 | -1.568212 | 0.428319 | 0.302370 | 0.626961 | 0.149908 | 0.002033 |
100 rows × 10 columns
pd.DataFrame(A)
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
0 | 1.0 | 5.0 | 4.0 | 3.0 | 4.0 | 3.0 | 1.0 | 3.0 | 3.0 | 4.0 |
1 | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 2.0 | 5.0 | 2.0 | 2.0 | 4.0 |
2 | 2.0 | 5.0 | 5.0 | 3.0 | 3.0 | 1.0 | 4.0 | 5.0 | 2.0 | 5.0 |
3 | 4.0 | 2.0 | 1.0 | 5.0 | 1.0 | 5.0 | 1.0 | 1.0 | 5.0 | 1.0 |
4 | 4.0 | 3.0 | 5.0 | 3.0 | 4.0 | 1.0 | 2.0 | 4.0 | 1.0 | 5.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
95 | 4.0 | 4.0 | 2.0 | 3.0 | 2.0 | 2.0 | 3.0 | 5.0 | 4.0 | 1.0 |
96 | 5.0 | 4.0 | 3.0 | 4.0 | 5.0 | 4.0 | 5.0 | 1.0 | 5.0 | 5.0 |
97 | 1.0 | 1.0 | 1.0 | 1.0 | 4.0 | 4.0 | 5.0 | 2.0 | 4.0 | 2.0 |
98 | 2.0 | 4.0 | 1.0 | 2.0 | 4.0 | 3.0 | 5.0 | 1.0 | 4.0 | 4.0 |
99 | 2.0 | 5.0 | 2.0 | 1.0 | 3.0 | 5.0 | 5.0 | 4.0 | 1.0 | 4.0 |
100 rows × 10 columns
# Optimizer
optimizer = optim.Adam([W, H], lr=0.01)

# Train the model
for i in range(1000):
    # Compute the loss
    loss = torch.norm(torch.mm(W, H) - A)
    # Zero the gradients
    optimizer.zero_grad()
    # Backpropagate
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Print the loss
    if i % 10 == 0:
        print(loss.item())
110.79912567138672
108.5261001586914
106.6722412109375
104.86959075927734
102.57525634765625
99.17333984375
94.17224884033203
87.41124725341797
79.23120880126953
70.51470184326172
62.338897705078125
55.442466735839844
50.24897003173828
46.78135681152344
44.567237854003906
43.125022888183594
42.18255615234375
41.56866455078125
41.16780471801758
40.899253845214844
40.70838928222656
40.56161117553711
40.439727783203125
40.33248519897461
40.23463821411133
40.143455505371094
40.05741500854492
39.9755744934082
39.89726638793945
39.82197952270508
39.749305725097656
39.678890228271484
39.610443115234375
39.54371643066406
39.478515625
39.4146842956543
39.35213088989258
39.290809631347656
39.23072052001953
39.17192077636719
39.114505767822266
39.058589935302734
39.00433349609375
38.95188903808594
38.90142059326172
38.85308837890625
38.8070182800293
38.76333999633789
38.722137451171875
38.6834716796875
38.64735412597656
38.61379623413086
38.5827522277832
38.554161071777344
38.527931213378906
38.503971099853516
38.482154846191406
38.46236038208008
38.444454193115234
38.42829132080078
38.41374588012695
38.40068435668945
38.38897705078125
38.378501892089844
38.36914825439453
38.36080551147461
38.35336685180664
38.34674835205078
38.34086608886719
38.33564376831055
38.330997467041016
38.326881408691406
38.3232307434082
38.31999206542969
38.31712341308594
38.314579010009766
38.31232833862305
38.310325622558594
38.308555603027344
38.30698013305664
38.30558776855469
38.30434799194336
38.30324935913086
38.30227279663086
38.301395416259766
38.300621032714844
38.2999267578125
38.29930114746094
38.29874801635742
38.298248291015625
38.29779815673828
38.297393798828125
38.297027587890625
38.29669189453125
38.29639434814453
38.296119689941406
38.295867919921875
38.29563903808594
38.29542922973633
38.29523468017578
pd.DataFrame(torch.mm(W, H).detach().numpy()).head(2)
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
0 | 3.622862 | 3.536414 | 3.945925 | 2.985436 | 2.999869 | 2.749094 | 2.605484 | 2.758482 | 2.209353 | 3.361388 |
1 | 3.622746 | 3.547434 | 3.798462 | 3.133817 | 3.271342 | 3.176287 | 3.221505 | 3.073355 | 2.856927 | 3.383140 |
pd.DataFrame(A).head(2)
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
0 | 1.0 | 5.0 | 4.0 | 3.0 | 4.0 | 3.0 | 1.0 | 3.0 | 3.0 | 4.0 |
1 | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 2.0 | 5.0 | 2.0 | 2.0 | 4.0 |
def factorize(A, k):
    """Factorize the matrix A into W and H
    A: input matrix of size (n_users, n_movies)
    k: number of latent features
    Returns W and H
    W: matrix of size (n_users, k)
    H: matrix of size (k, n_movies)
    """
    # Randomly initialize W and H
    W = torch.randn(A.shape[0], k, requires_grad=True)
    H = torch.randn(k, A.shape[1], requires_grad=True)

    # Optimizer
    optimizer = optim.Adam([W, H], lr=0.01)

    # Train the model
    for i in range(1000):
        # Compute the loss
        loss = torch.norm(torch.mm(W, H) - A)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the parameters
        optimizer.step()
    return W, H, loss

for k in [1, 2, 3, 4, 5, 6, 9, 10, 11]:
    W, H, loss = factorize(A, k)
    print(k, loss.item())
1 42.103797912597656
2 38.35541534423828
3 34.45906448364258
4 30.72266387939453
5 27.430004119873047
6 23.318540573120117
9 9.963604927062988
10 0.16205351054668427
11 0.18895108997821808
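The loss drops to essentially zero once k reaches 10: A has shape (100, 10), so its rank is at most 10, and a rank-10 factorization WH can reproduce it exactly; smaller k forces a lossy low-rank approximation. A quick sketch to visualize this (reusing the matplotlib import from the top; the list of k values is an assumption):

# Plot reconstruction loss against the number of latent factors k
ks = [1, 2, 3, 4, 5, 6, 9, 10, 11]
losses = [factorize(A, k)[2].item() for k in ks]
plt.plot(ks, losses, marker='o')
plt.xlabel('k (latent factors)')
plt.ylabel('reconstruction loss')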
pd.DataFrame(torch.mm(W, H).detach().numpy()).head(2)
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
0 | 1.014621 | 5.012642 | 4.012176 | 3.013239 | 4.012656 | 3.010292 | 1.012815 | 3.011410 | 3.015494 | 4.012996 |
1 | 2.987797 | 2.986945 | 3.989305 | 3.987498 | 3.988829 | 1.989375 | 4.989910 | 1.987138 | 1.986990 | 3.987026 |
pd.DataFrame(A).head(2)
 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
0 | 1.0 | 5.0 | 4.0 | 3.0 | 4.0 | 3.0 | 1.0 | 3.0 | 3.0 | 4.0 |
1 | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 2.0 | 5.0 | 2.0 | 2.0 | 4.0 |
# With missing values
# Randomly replace some entries with NaN
A = torch.randint(1, 6, (n_users, n_movies), dtype=torch.float)
A[torch.rand(A.shape) < 0.5] = float('nan')
A
tensor([[nan, 5., 5., nan, 1., 1., 4., 2., nan, 4.],
[nan, 2., 2., 1., nan, nan, nan, nan, 3., nan],
[nan, nan, 5., 1., nan, nan, 2., nan, nan, nan],
[nan, nan, nan, 1., 4., 2., nan, 2., nan, 4.],
[nan, nan, 2., nan, 3., 1., nan, nan, 4., 4.],
[nan, 2., 5., nan, 2., nan, 1., nan, 3., nan],
[nan, nan, nan, nan, 2., 1., nan, 3., nan, 5.],
[nan, 4., nan, 1., 5., nan, 4., 5., 4., nan],
[nan, nan, 2., 5., nan, nan, 5., nan, nan, 2.],
[nan, nan, 3., 2., nan, 1., 1., 4., 5., nan],
[nan, nan, 5., nan, nan, nan, 2., 2., nan, 3.],
[nan, 5., 4., 2., nan, nan, nan, 1., 4., 3.],
[nan, nan, 1., nan, 4., 4., nan, nan, 3., nan],
[nan, 1., nan, nan, 3., nan, nan, nan, nan, 5.],
[4., nan, 2., nan, nan, nan, nan, nan, 4., 3.],
[4., nan, 3., nan, 3., nan, 4., 1., 1., nan],
[5., nan, 3., nan, 3., nan, nan, 1., nan, nan],
[2., 5., nan, 5., 3., 4., 3., 3., 5., 5.],
[3., 3., nan, nan, nan, nan, nan, nan, 4., nan],
[1., nan, 1., 3., 4., 1., nan, nan, 2., nan],
[nan, nan, 5., nan, nan, nan, 2., nan, 1., nan],
[3., 3., nan, nan, 2., 2., 3., 4., 4., nan],
[2., nan, nan, 2., nan, nan, nan, nan, nan, 4.],
[nan, 2., nan, 4., 5., 5., nan, 3., 5., nan],
[2., 2., 2., 4., nan, 4., nan, 1., nan, nan],
[5., 4., 5., 1., nan, 3., 5., 5., 1., 4.],
[2., nan, nan, 2., nan, nan, nan, nan, nan, nan],
[nan, 1., 5., nan, nan, nan, 2., 2., nan, nan],
[2., 3., 4., nan, nan, nan, 1., 4., 4., nan],
[2., nan, 1., 4., 1., 1., nan, 5., nan, 1.],
[5., nan, 3., 1., 5., nan, nan, nan, 2., nan],
[nan, 3., 4., nan, nan, 3., nan, 5., 5., nan],
[4., 3., 2., 3., nan, nan, nan, nan, 1., nan],
[1., 1., nan, 5., 1., 5., nan, nan, 5., nan],
[4., 5., nan, nan, nan, 3., nan, 2., 5., 5.],
[nan, 5., 5., nan, nan, nan, nan, 4., nan, nan],
[5., nan, 4., 2., 3., 3., nan, nan, 2., 1.],
[nan, nan, nan, nan, 5., nan, 2., 2., 4., nan],
[3., 5., 1., 5., 3., nan, 5., nan, 2., 4.],
[nan, 2., 4., nan, 1., 4., nan, 5., nan, 3.],
[5., 5., 2., 5., nan, 4., nan, 2., nan, 3.],
[5., nan, nan, 5., 1., 5., nan, nan, 3., 3.],
[nan, 4., 2., 5., nan, 2., 3., nan, nan, 1.],
[nan, nan, 4., 5., 5., nan, nan, 4., nan, 2.],
[nan, nan, nan, nan, nan, nan, nan, 1., 3., nan],
[3., 3., 1., 1., 1., 4., nan, nan, nan, 2.],
[nan, nan, nan, 4., 5., nan, nan, 3., nan, nan],
[5., nan, 5., 2., 4., nan, nan, nan, 5., nan],
[nan, 1., 1., nan, 3., 1., nan, 1., nan, 1.],
[nan, nan, 1., 2., nan, 3., nan, 2., 2., 4.],
[nan, nan, 5., 1., 3., 2., 2., nan, 1., nan],
[nan, 2., 2., nan, 4., nan, nan, nan, 3., nan],
[nan, nan, nan, nan, 5., 4., nan, 3., nan, nan],
[nan, 2., nan, nan, nan, nan, nan, 3., nan, nan],
[nan, 2., 5., 2., nan, 3., 4., nan, 1., 2.],
[1., nan, 1., nan, nan, 3., 4., 2., nan, 1.],
[nan, 4., 4., 1., nan, nan, 5., nan, nan, 2.],
[nan, nan, 3., nan, nan, 4., 1., nan, 3., nan],
[nan, 1., nan, 3., 2., 3., nan, 2., nan, 4.],
[2., 3., 3., nan, 1., 4., 3., nan, nan, 1.],
[5., nan, 3., 1., 1., nan, 4., nan, 3., nan],
[5., nan, nan, nan, nan, nan, 4., nan, nan, nan],
[2., nan, nan, 2., nan, nan, 2., 3., nan, 4.],
[3., nan, 4., 5., nan, nan, nan, 1., nan, 3.],
[3., 3., 1., nan, 2., 5., 5., nan, 2., nan],
[1., 1., 3., 1., nan, nan, 4., 3., 4., 5.],
[4., nan, nan, 5., 2., nan, 1., nan, nan, nan],
[nan, nan, 2., nan, 5., nan, 2., 1., nan, 4.],
[nan, 5., nan, 4., 3., 2., 2., 1., nan, 4.],
[nan, nan, 1., nan, nan, nan, 2., nan, nan, nan],
[1., nan, nan, 4., 5., 3., nan, 2., nan, nan],
[2., nan, 5., 1., 3., nan, 5., nan, nan, nan],
[3., 5., nan, nan, nan, 5., 2., nan, nan, 2.],
[nan, 2., 5., nan, nan, nan, 2., nan, 1., 2.],
[nan, 3., 2., nan, nan, 4., nan, 2., nan, 5.],
[nan, 4., 5., 2., 5., nan, nan, nan, nan, 5.],
[3., 1., 5., 3., nan, nan, nan, 3., nan, 4.],
[nan, 5., nan, 5., nan, 5., 1., 1., 3., 1.],
[2., nan, nan, nan, 5., nan, nan, 4., nan, 4.],
[nan, 3., 4., 2., nan, 2., nan, 2., nan, nan],
[nan, 5., nan, 2., 2., 4., 5., 5., nan, nan],
[nan, nan, nan, nan, nan, 5., nan, nan, 5., 4.],
[nan, nan, nan, nan, nan, 2., 2., 2., 2., 4.],
[nan, nan, nan, 2., 5., 3., 3., nan, nan, nan],
[1., 2., 4., 2., nan, 2., 5., 4., nan, 5.],
[4., nan, nan, 5., 4., nan, 5., 1., 3., nan],
[4., nan, 1., 4., nan, nan, 2., nan, 4., 3.],
[4., 2., 1., 3., nan, nan, 1., nan, nan, 1.],
[nan, 1., 3., 1., 2., nan, 3., nan, 5., nan],
[nan, 1., 1., nan, 1., 1., nan, 4., nan, nan],
[nan, nan, nan, nan, 1., 5., nan, 5., 3., nan],
[3., 4., nan, 4., 3., nan, 2., 1., nan, nan],
[1., 1., nan, 2., nan, nan, nan, 2., 4., 1.],
[nan, 4., 2., nan, 3., nan, 2., 1., nan, 2.],
[2., nan, 5., nan, 3., nan, 5., 1., nan, nan],
[3., 3., nan, nan, 3., 3., nan, 4., 2., nan],
[2., nan, 5., nan, 5., 3., nan, nan, 5., nan],
[nan, nan, nan, nan, 5., nan, nan, nan, nan, 5.],
[4., nan, nan, 1., 4., 5., 1., nan, 2., 4.],
[1., nan, nan, nan, 2., nan, 4., nan, nan, nan]])
W, H, loss = factorize(A, 2)
loss
tensor(nan, grad_fn=<LinalgVectorNormBackward0>)
As expected, the above call does not work: the loss is NaN because the missing entries propagate through torch.norm. Our current loss function does not handle missing values.
mask = ~torch.isnan(A)
mask
tensor([[False, True, True, False, True, True, True, True, False, True],
[False, True, True, True, False, False, False, False, True, False],
[False, False, True, True, False, False, True, False, False, False],
[False, False, False, True, True, True, False, True, False, True],
[False, False, True, False, True, True, False, False, True, True],
[False, True, True, False, True, False, True, False, True, False],
[False, False, False, False, True, True, False, True, False, True],
[False, True, False, True, True, False, True, True, True, False],
[False, False, True, True, False, False, True, False, False, True],
[False, False, True, True, False, True, True, True, True, False],
[False, False, True, False, False, False, True, True, False, True],
[False, True, True, True, False, False, False, True, True, True],
[False, False, True, False, True, True, False, False, True, False],
[False, True, False, False, True, False, False, False, False, True],
[ True, False, True, False, False, False, False, False, True, True],
[ True, False, True, False, True, False, True, True, True, False],
[ True, False, True, False, True, False, False, True, False, False],
[ True, True, False, True, True, True, True, True, True, True],
[ True, True, False, False, False, False, False, False, True, False],
[ True, False, True, True, True, True, False, False, True, False],
[False, False, True, False, False, False, True, False, True, False],
[ True, True, False, False, True, True, True, True, True, False],
[ True, False, False, True, False, False, False, False, False, True],
[False, True, False, True, True, True, False, True, True, False],
[ True, True, True, True, False, True, False, True, False, False],
[ True, True, True, True, False, True, True, True, True, True],
[ True, False, False, True, False, False, False, False, False, False],
[False, True, True, False, False, False, True, True, False, False],
[ True, True, True, False, False, False, True, True, True, False],
[ True, False, True, True, True, True, False, True, False, True],
[ True, False, True, True, True, False, False, False, True, False],
[False, True, True, False, False, True, False, True, True, False],
[ True, True, True, True, False, False, False, False, True, False],
[ True, True, False, True, True, True, False, False, True, False],
[ True, True, False, False, False, True, False, True, True, True],
[False, True, True, False, False, False, False, True, False, False],
[ True, False, True, True, True, True, False, False, True, True],
[False, False, False, False, True, False, True, True, True, False],
[ True, True, True, True, True, False, True, False, True, True],
[False, True, True, False, True, True, False, True, False, True],
[ True, True, True, True, False, True, False, True, False, True],
[ True, False, False, True, True, True, False, False, True, True],
[False, True, True, True, False, True, True, False, False, True],
[False, False, True, True, True, False, False, True, False, True],
[False, False, False, False, False, False, False, True, True, False],
[ True, True, True, True, True, True, False, False, False, True],
[False, False, False, True, True, False, False, True, False, False],
[ True, False, True, True, True, False, False, False, True, False],
[False, True, True, False, True, True, False, True, False, True],
[False, False, True, True, False, True, False, True, True, True],
[False, False, True, True, True, True, True, False, True, False],
[False, True, True, False, True, False, False, False, True, False],
[False, False, False, False, True, True, False, True, False, False],
[False, True, False, False, False, False, False, True, False, False],
[False, True, True, True, False, True, True, False, True, True],
[ True, False, True, False, False, True, True, True, False, True],
[False, True, True, True, False, False, True, False, False, True],
[False, False, True, False, False, True, True, False, True, False],
[False, True, False, True, True, True, False, True, False, True],
[ True, True, True, False, True, True, True, False, False, True],
[ True, False, True, True, True, False, True, False, True, False],
[ True, False, False, False, False, False, True, False, False, False],
[ True, False, False, True, False, False, True, True, False, True],
[ True, False, True, True, False, False, False, True, False, True],
[ True, True, True, False, True, True, True, False, True, False],
[ True, True, True, True, False, False, True, True, True, True],
[ True, False, False, True, True, False, True, False, False, False],
[False, False, True, False, True, False, True, True, False, True],
[False, True, False, True, True, True, True, True, False, True],
[False, False, True, False, False, False, True, False, False, False],
[ True, False, False, True, True, True, False, True, False, False],
[ True, False, True, True, True, False, True, False, False, False],
[ True, True, False, False, False, True, True, False, False, True],
[False, True, True, False, False, False, True, False, True, True],
[False, True, True, False, False, True, False, True, False, True],
[False, True, True, True, True, False, False, False, False, True],
[ True, True, True, True, False, False, False, True, False, True],
[False, True, False, True, False, True, True, True, True, True],
[ True, False, False, False, True, False, False, True, False, True],
[False, True, True, True, False, True, False, True, False, False],
[False, True, False, True, True, True, True, True, False, False],
[False, False, False, False, False, True, False, False, True, True],
[False, False, False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, False, False, False],
[ True, True, True, True, False, True, True, True, False, True],
[ True, False, False, True, True, False, True, True, True, False],
[ True, False, True, True, False, False, True, False, True, True],
[ True, True, True, True, False, False, True, False, False, True],
[False, True, True, True, True, False, True, False, True, False],
[False, True, True, False, True, True, False, True, False, False],
[False, False, False, False, True, True, False, True, True, False],
[ True, True, False, True, True, False, True, True, False, False],
[ True, True, False, True, False, False, False, True, True, True],
[False, True, True, False, True, False, True, True, False, True],
[ True, False, True, False, True, False, True, True, False, False],
[ True, True, False, False, True, True, False, True, True, False],
[ True, False, True, False, True, True, False, False, True, False],
[False, False, False, False, True, False, False, False, False, True],
[ True, False, False, True, True, True, True, False, True, True],
[ True, False, False, False, True, False, True, False, False, False]])
mask.sum()
tensor(517)
W = torch.randn(A.shape[0], k, requires_grad=True)
H = torch.randn(k, A.shape[1], requires_grad=True)

diff_matrix = torch.mm(W, H) - A
diff_matrix.shape
torch.Size([100, 10])
# Mask the matrix
diff_matrix[mask].shape
torch.Size([517])
# Modify the loss function to ignore NaN values
def factorize(A, k):
    """Factorize the matrix A into W and H, ignoring NaN entries"""
    # Mask of observed (non-NaN) entries
    mask = ~torch.isnan(A)

    # Randomly initialize W and H
    W = torch.randn(A.shape[0], k, requires_grad=True)
    H = torch.randn(k, A.shape[1], requires_grad=True)

    # Optimizer
    optimizer = optim.Adam([W, H], lr=0.01)

    # Train the model
    for i in range(1000):
        # Compute the loss only over the observed entries
        diff_matrix = torch.mm(W, H) - A
        diff_vector = diff_matrix[mask]
        loss = torch.norm(diff_vector)

        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the parameters
        optimizer.step()
    return W, H, loss

W, H, loss = factorize(A, 5)
loss
tensor(7.1147, grad_fn=<LinalgVectorNormBackward0>)
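To judge the fit in rating units, we can look at the root-mean-squared error on the observed entries only; a small sketch reusing the mask computed above:

# RMSE on the observed (non-NaN) entries only
with torch.no_grad():
    recon = torch.mm(W, H)
    rmse = torch.sqrt(torch.mean((recon[mask] - A[mask]) ** 2))
print(rmse.item())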
torch.mm(W, H)
tensor([[-3.7026e+00, 5.1163e+00, 4.9660e+00, -5.8455e-01, 1.0200e+00,
9.1470e-01, 3.8727e+00, 2.2007e+00, 1.3782e+01, 3.9267e+00],
[ 6.4588e-02, 2.0002e+00, 1.9979e+00, 1.0001e+00, 3.6942e+00,
1.1512e+00, 3.8181e+00, 3.5421e+00, 3.0014e+00, 5.0035e+00],
[-3.3540e-01, -1.0050e+00, 5.0000e+00, 1.0010e+00, 1.7959e-01,
8.6454e-01, 1.9994e+00, 2.9560e+00, 1.0767e+01, -3.4057e-01],
[ 2.6321e+00, 3.5523e+00, 2.1706e+00, 9.9769e-01, 3.9995e+00,
2.0030e+00, 2.8652e+00, 2.0006e+00, 3.5503e-02, 3.9991e+00],
[ 5.3195e-02, -1.3399e+00, 1.9455e+00, 2.2700e+00, 3.0539e+00,
1.0086e+00, 3.8842e+00, 5.1128e+00, 4.0248e+00, 3.9763e+00],
[ 5.5045e+00, 2.0044e+00, 4.9388e+00, 1.6216e+00, 2.0865e+00,
2.8957e+00, 9.2770e-01, 8.8714e-01, 3.0420e+00, -9.2438e-01],
[ 6.1366e-01, -1.3394e+00, -6.3292e+00, 4.0202e+00, 2.0028e+00,
1.0014e+00, 2.1030e+00, 3.0001e+00, -1.0639e+01, 4.9979e+00],
[ 1.6036e+00, 3.6518e+00, 3.7664e+00, 1.2319e+00, 5.1780e+00,
2.1295e+00, 4.7475e+00, 4.0635e+00, 4.0645e+00, 5.9223e+00],
[-4.1593e+00, -1.1450e+01, 2.0064e+00, 5.0104e+00, 1.7877e-01,
-6.8356e-01, 4.9576e+00, 1.0525e+01, 1.2403e+01, 2.0250e+00],
[ 1.7769e+00, -4.9137e+00, 3.1228e+00, 2.2575e+00, 5.8793e-01,
6.6961e-01, 1.2031e+00, 3.7871e+00, 4.9702e+00, -1.4083e+00],
[ 9.1273e+00, 1.4212e+00, 5.0018e+00, -1.0804e+00, 8.9608e+00,
1.6892e+00, 2.0130e+00, 1.9936e+00, -7.2410e+00, 2.9931e+00],
[ 3.2326e+00, 5.2245e+00, 3.8534e+00, 1.8050e+00, 2.7570e+00,
3.2246e+00, 2.7488e+00, 1.4501e+00, 3.9743e+00, 2.7026e+00],
[ 9.2220e-01, 6.5451e+00, 9.9373e-01, 4.2199e+00, 4.0030e+00,
4.0019e+00, 5.8552e+00, 4.3514e+00, 3.0028e+00, 8.0414e+00],
[-2.3552e+00, 1.0013e+00, 9.6726e-01, -4.6631e-01, 3.0005e+00,
-4.1134e-01, 3.3274e+00, 3.0678e+00, 3.0946e+00, 4.9978e+00],
[ 4.0009e+00, 5.9036e+00, 1.9983e+00, 6.0610e+00, 1.0184e+00,
5.6167e+00, 3.5338e+00, 2.5504e+00, 4.0003e+00, 3.0002e+00],
[ 4.1839e+00, 9.1531e+00, 2.5008e+00, 3.1248e+00, 3.4309e+00,
4.8701e+00, 3.6236e+00, 9.7109e-01, 1.3077e+00, 4.9725e+00],
[ 5.0034e+00, -2.6358e+00, 2.9996e+00, -3.1773e-01, 2.9981e+00,
3.3085e-01, -3.0125e-01, 1.0002e+00, -2.6015e+00, -1.5107e+00],
[ 2.1562e+00, 5.1078e+00, 2.3091e+00, 4.3035e+00, 2.3000e+00,
4.1470e+00, 4.2497e+00, 3.3351e+00, 4.7425e+00, 4.5406e+00],
[ 2.9999e+00, 2.9999e+00, 3.0624e+00, 4.4992e+00, 4.5262e+00,
3.8824e+00, 5.4591e+00, 5.5541e+00, 4.0005e+00, 5.9015e+00],
[ 8.1177e-01, -2.9699e+00, 1.2606e+00, 3.0946e+00, 3.9370e+00,
1.0488e+00, 4.3280e+00, 6.3317e+00, 1.8619e+00, 4.5940e+00],
[ 2.1477e+00, 2.4168e+00, 5.0014e+00, -3.6888e+00, 5.9741e+00,
-5.6343e-01, 1.9977e+00, 9.9599e-01, 1.0008e+00, 3.2156e+00],
[ 2.5660e+00, 2.4872e+00, 2.9711e+00, 3.0048e+00, 2.5035e+00,
2.9551e+00, 3.2466e+00, 3.1642e+00, 4.0380e+00, 2.8930e+00],
[ 2.0003e+00, 1.1864e+01, -1.4367e+00, 1.9999e+00, 4.6078e-01,
4.1701e+00, 1.3254e+00, -3.0082e+00, -2.7641e+00, 3.9993e+00],
[ 1.0449e+01, 2.0032e+00, 8.9659e+00, 3.6977e+00, 4.8085e+00,
5.2882e+00, 2.4920e+00, 3.2403e+00, 4.9395e+00, -8.7002e-01],
[ 2.0717e+00, 2.1115e+00, 2.0717e+00, 4.2801e+00, -2.0380e+00,
3.5172e+00, 9.3884e-01, 9.7155e-01, 6.2354e+00, -1.3328e+00],
[ 5.2955e+00, 3.4266e+00, 5.1299e+00, 1.4202e+00, 6.8671e+00,
2.9375e+00, 4.3992e+00, 3.9995e+00, 1.2817e+00, 5.0996e+00],
[ 2.0001e+00, 1.4791e+00, 2.3388e+00, 2.0000e+00, -6.6995e-01,
2.1166e+00, 4.5259e-01, 3.8568e-01, 4.1443e+00, -1.1447e+00],
[ 2.9065e+00, 1.0088e+00, 5.0003e+00, -4.4267e-01, 3.6696e+00,
1.0351e+00, 1.9735e+00, 2.0205e+00, 3.8361e+00, 1.3199e+00],
[ 2.3390e+00, 2.3721e+00, 3.7185e+00, 1.3874e+00, 2.8777e+00,
2.1103e+00, 2.7381e+00, 2.4701e+00, 4.1805e+00, 2.4318e+00],
[ 1.7971e+00, -4.8495e+00, 9.1239e-01, 3.9631e+00, 1.2733e+00,
1.2781e+00, 2.2735e+00, 5.0206e+00, 1.8186e+00, 7.7334e-01],
[ 4.4556e+00, 1.2257e+01, 3.8113e+00, 1.3722e+00, 4.8218e+00,
4.8774e+00, 3.8956e+00, -1.3451e-01, 1.5725e+00, 6.1126e+00],
[ 2.3871e+00, 3.0010e+00, 3.9991e+00, 2.8334e+00, 4.8395e+00,
2.9996e+00, 5.1329e+00, 5.0009e+00, 5.0012e+00, 5.6052e+00],
[ 3.9920e+00, 3.0004e+00, 2.0131e+00, 3.0049e+00, 1.3189e+00,
3.2595e+00, 1.5054e+00, 1.0929e+00, 9.9195e-01, 8.6129e-01],
[ 1.3517e+00, 1.3225e+00, 1.2075e+00, 5.6267e+00, 9.9029e-01,
3.6713e+00, 4.0205e+00, 4.6894e+00, 5.0908e+00, 3.3918e+00],
[ 3.9141e+00, 5.4294e+00, 6.3565e+00, 3.9410e-02, 5.9122e+00,
2.6582e+00, 4.0295e+00, 2.6446e+00, 4.8764e+00, 4.5279e+00],
[ 5.9703e+00, 5.0000e+00, 5.0006e+00, 5.4552e+00, 3.4764e+00,
5.7107e+00, 4.4197e+00, 3.9984e+00, 5.4614e+00, 3.3112e+00],
[ 4.9808e+00, 2.5274e+00, 4.0592e+00, 2.0167e+00, 2.9650e+00,
2.9926e+00, 1.9145e+00, 1.7219e+00, 1.9706e+00, 1.0164e+00],
[ 3.3675e+00, 7.9931e-01, 6.4143e+00, -1.9437e+00, 5.0061e+00,
3.9840e-01, 1.9918e+00, 2.0016e+00, 4.0036e+00, 1.4088e+00],
[ 3.1524e+00, 4.8411e+00, 1.3102e+00, 5.0749e+00, 2.6352e+00,
4.5228e+00, 4.2075e+00, 3.4787e+00, 2.0158e+00, 4.7858e+00],
[ 1.1647e+00, 2.0442e+00, 3.9774e+00, 5.2040e+00, 1.0064e+00,
3.9541e+00, 4.5550e+00, 5.0695e+00, 1.0370e+01, 2.9545e+00],
[ 4.9155e+00, 4.8887e+00, 1.8611e+00, 4.5726e+00, 2.4214e+00,
4.6615e+00, 2.9405e+00, 2.1223e+00, 4.9785e-01, 2.9426e+00],
[ 4.6095e+00, 7.0277e+00, 2.3640e+00, 4.9826e+00, 1.3639e+00,
5.4668e+00, 2.9430e+00, 1.3600e+00, 2.9275e+00, 2.6572e+00],
[-9.4300e+00, 3.9400e+00, 1.8115e+00, 4.6790e+00, -7.7135e+00,
2.3401e+00, 3.3902e+00, 2.5588e+00, 2.3759e+01, 7.8572e-01],
[ 8.8489e+00, 1.3677e+00, 4.0011e+00, 4.9984e+00, 5.0002e+00,
4.8811e+00, 3.1127e+00, 4.0020e+00, -1.6249e+00, 1.9996e+00],
[-5.3638e+00, -1.6611e+01, -1.2622e+00, -2.9465e+00, -5.3122e+00,
-6.6081e+00, -4.5960e+00, 1.0021e+00, 3.0004e+00, -8.7048e+00],
[ 2.9505e+00, 3.4944e+00, 1.0966e+00, 2.0104e+00, 1.4645e+00,
2.5452e+00, 1.3437e+00, 4.9608e-01, -3.2980e-01, 1.5523e+00],
[ 8.8610e+00, 3.4098e+00, 5.4534e+00, 4.0010e+00, 5.0039e+00,
5.0527e+00, 2.9819e+00, 2.9969e+00, 4.8445e-01, 1.7580e+00],
[ 4.7645e+00, 7.2438e+00, 5.3527e+00, 2.1598e+00, 3.9257e+00,
4.3363e+00, 3.5150e+00, 1.6544e+00, 4.8212e+00, 3.4858e+00],
[ 3.6301e+00, 8.2968e-01, 1.0952e+00, 5.1852e-01, 2.9273e+00,
1.1501e+00, 7.9266e-01, 7.6517e-01, -3.4503e+00, 1.2149e+00],
[ 1.2349e+00, 4.7494e+00, 1.1553e+00, 2.2902e+00, 2.1858e+00,
2.6806e+00, 3.0526e+00, 1.8136e+00, 1.9800e+00, 4.0943e+00],
[ 5.4207e+00, 4.6174e-01, 4.6315e+00, 1.0934e+00, 3.5323e+00,
2.1171e+00, 1.4216e+00, 1.9254e+00, 1.2497e+00, 1.5768e-01],
[ 8.0981e-01, 1.9982e+00, 1.9823e+00, 2.3342e+00, 4.0156e+00,
1.9755e+00, 4.5111e+00, 4.4963e+00, 3.0107e+00, 5.5846e+00],
[ 3.4230e+00, 6.7811e+00, 1.3001e+00, 3.2712e+00, 5.0004e+00,
4.0008e+00, 4.7148e+00, 2.9985e+00, -8.9564e-01, 6.9733e+00],
[ 4.3992e+00, 1.9999e+00, 1.7722e+00, 3.3742e+00, 3.4675e+00,
3.1595e+00, 2.8761e+00, 3.0001e+00, -8.2256e-01, 3.0558e+00],
[ 6.3852e+00, 1.8676e+00, 5.1417e+00, 2.1083e+00, 5.4183e+00,
3.1672e+00, 3.1282e+00, 3.3945e+00, 1.0754e+00, 2.5551e+00],
[ 8.1366e-01, 2.6100e+00, 1.2431e+00, 4.6053e+00, -7.3305e-01,
3.4299e+00, 2.5706e+00, 2.4739e+00, 5.9303e+00, 1.5624e+00],
[-1.5185e+01, 3.9601e+00, 4.1178e+00, 1.1159e+00, -7.8194e+00,
-2.6153e-01, 4.5528e+00, 3.2349e+00, 3.2557e+01, 2.2662e+00],
[ 4.1539e+00, 6.2482e+00, 3.0004e+00, 2.5108e+00, 4.1870e-01,
4.0017e+00, 9.9885e-01, -7.3851e-01, 2.9998e+00, 1.7461e-01],
[ 2.4292e+00, 1.3324e+00, -2.4186e+00, 3.4744e+00, 2.2288e+00,
2.2009e+00, 2.0940e+00, 2.1399e+00, -5.7555e+00, 3.6885e+00],
[ 2.2950e+00, 2.9436e+00, 3.1712e+00, 4.1268e+00, 5.4157e-01,
3.7216e+00, 2.8408e+00, 2.7181e+00, 6.7460e+00, 1.4747e+00],
[ 5.0664e+00, 4.5652e+01, 2.7328e+00, 1.0374e+00, 1.2891e+00,
1.3143e+01, 3.7184e+00, -1.3608e+01, 3.1697e+00, 1.1363e+01],
[ 5.0006e+00, 6.9239e+00, 9.1918e-01, 7.2628e+00, 1.5445e+00,
6.5242e+00, 4.0004e+00, 2.7856e+00, 1.5080e+00, 4.0918e+00],
[ 2.0015e+00, -1.7211e+00, -2.6718e+00, 1.9970e+00, 3.6355e+00,
4.8944e-01, 2.0002e+00, 3.0038e+00, -7.9397e+00, 3.9988e+00],
[ 2.9915e+00, 1.0068e+01, 4.0126e+00, 5.0140e+00, 2.1576e-01,
6.2162e+00, 3.6692e+00, 9.8564e-01, 9.0477e+00, 2.9998e+00],
[ 3.1704e+00, 3.0173e+00, 6.3333e-01, 6.8363e+00, 2.2725e+00,
4.9317e+00, 4.7652e+00, 5.0099e+00, 2.2206e+00, 4.9007e+00],
[ 7.5330e-01, 1.4262e+00, 3.0599e+00, 8.5747e-01, 4.0712e+00,
1.1776e+00, 3.8169e+00, 3.8298e+00, 3.8094e+00, 4.5052e+00],
[ 4.0428e+00, -2.0650e+01, 2.4385e+00, 4.9425e+00, 1.9279e+00,
-1.5739e+00, 1.0876e+00, 1.0297e+01, 9.8782e-01, -4.4288e+00],
[ 1.0773e+00, 2.2150e+00, 2.0096e+00, -2.4710e+00, 4.9950e+00,
-5.0779e-01, 1.9650e+00, 1.0100e+00, -1.8899e+00, 4.0205e+00],
[ 3.4996e+00, 4.5594e+00, -2.5208e+00, 3.3402e+00, 2.6908e+00,
3.0928e+00, 2.0393e+00, 8.1359e-01, -7.3751e+00, 4.3822e+00],
[-4.8442e+00, 3.4518e+00, 9.9972e-01, -1.2881e+00, -3.5755e-01,
-5.3701e-01, 1.9997e+00, 4.9337e-01, 7.5069e+00, 3.0513e+00],
[ 1.0026e+00, 7.6105e+00, -7.1104e+00, 3.9509e+00, 4.9733e+00,
3.0444e+00, 4.5575e+00, 2.0356e+00, -1.3359e+01, 1.0958e+01],
[ 1.9708e+00, 2.2459e+01, 5.0115e+00, 1.0357e+00, 3.0386e+00,
7.0322e+00, 4.9501e+00, -3.0939e+00, 8.5776e+00, 8.3022e+00],
[ 3.0000e+00, 5.0024e+00, -1.4123e+00, 6.3346e+00, -9.3349e-01,
4.9944e+00, 2.0182e+00, 1.1012e+00, -1.8768e-01, 1.9884e+00],
[ 4.4226e+00, 2.0069e+00, 5.0098e+00, -8.1680e-01, 5.0611e+00,
1.2995e+00, 1.9966e+00, 1.6425e+00, 9.9608e-01, 1.9906e+00],
[ 1.1858e+01, 3.1806e+00, 1.9399e+00, 2.0242e+00, 9.7272e+00,
3.8295e+00, 2.5867e+00, 2.2628e+00, -1.4185e+01, 4.8150e+00],
[ 3.0139e+00, 4.0015e+00, 4.9996e+00, 2.0000e+00, 5.0004e+00,
3.0423e+00, 4.6969e+00, 4.0977e+00, 5.4217e+00, 4.9981e+00],
[ 3.1070e+00, 1.5608e+00, 4.4961e+00, 2.4194e+00, 3.9092e+00,
2.6510e+00, 3.8237e+00, 4.1656e+00, 5.0091e+00, 3.2677e+00],
[ 3.6123e+00, 5.0130e+00, 1.2513e+00, 5.0780e+00, -6.2890e-01,
4.7738e+00, 1.7050e+00, 7.5555e-01, 2.9369e+00, 7.0034e-01],
[ 1.9988e+00, -1.0755e+00, 2.1010e+00, 4.8154e-01, 5.0034e+00,
4.1238e-01, 3.1198e+00, 4.0017e+00, -6.4504e-01, 3.9979e+00],
[-2.4511e+00, 3.0049e+00, 4.0049e+00, 2.0168e+00, -1.7567e+00,
1.9773e+00, 2.5612e+00, 1.9963e+00, 1.3746e+01, 7.7954e-01],
[ 1.2900e+00, 4.8906e+00, 1.2285e+01, 2.0178e+00, 2.0017e+00,
4.0381e+00, 5.2765e+00, 4.7427e+00, 2.3263e+01, 1.4459e+00],
[ 3.4409e+00, 1.6323e+00, 2.1447e+00, 7.1434e+00, 2.1349e+00,
4.9995e+00, 4.9455e+00, 5.8905e+00, 5.0000e+00, 4.0006e+00],
[-8.3402e+00, 1.6671e+00, -9.7906e+00, 7.5550e+00, -7.1860e+00,
1.9594e+00, 2.1880e+00, 1.9448e+00, 1.9841e+00, 3.9175e+00],
[ 6.1855e+00, 1.0277e+00, 6.0344e+00, 2.0010e+00, 5.0009e+00,
2.9985e+00, 3.0008e+00, 3.6213e+00, 3.0914e+00, 1.7165e+00],
[ 9.9264e-01, 2.2125e+00, 3.9075e+00, 1.8425e+00, 4.0571e+00,
2.0497e+00, 4.5854e+00, 4.5379e+00, 6.0676e+00, 4.9087e+00],
[ 4.1464e+00, 1.4062e+01, 2.4284e+00, 4.7693e+00, 3.6919e+00,
6.9404e+00, 5.3224e+00, 1.0755e+00, 2.9501e+00, 7.6420e+00],
[ 3.7800e+00, 1.9199e+01, 1.3223e+00, 4.2259e+00, -1.6001e+00,
7.9395e+00, 1.7065e+00, -4.9606e+00, 3.8810e+00, 3.1246e+00],
[ 4.0234e+00, 2.0156e+00, 9.3090e-01, 2.9352e+00, 1.4458e+00,
2.8364e+00, 1.1944e+00, 1.0778e+00, -1.3489e+00, 8.8513e-01],
[ 2.9993e-01, 1.0106e+00, 2.8743e+00, 1.0596e+00, 2.1897e+00,
1.1528e+00, 2.7986e+00, 2.9140e+00, 5.0939e+00, 2.6367e+00],
[-2.7522e+00, 1.0011e+00, 9.9939e-01, 2.1638e+00, 9.9803e-01,
9.9864e-01, 3.8559e+00, 4.0029e+00, 6.9205e+00, 4.3317e+00],
[ 4.8687e+00, -1.5672e-01, 1.6703e+00, 7.5356e+00, 9.9502e-01,
4.9990e+00, 3.3880e+00, 5.0098e+00, 2.9938e+00, 1.4576e+00],
[ 3.0192e+00, 3.9819e+00, -5.9198e+00, 3.9832e+00, 2.9753e+00,
2.7502e+00, 2.0750e+00, 9.5908e-01, -1.2959e+01, 5.9251e+00],
[ 1.0148e+00, 9.7130e-01, 1.9768e+00, 2.0122e+00, 6.6718e-01,
1.7201e+00, 1.7755e+00, 1.9466e+00, 4.0138e+00, 1.0400e+00],
[ 5.3930e+00, 4.0298e+00, 2.0109e+00, 2.7548e+00, 2.9962e+00,
3.5641e+00, 1.8677e+00, 1.0778e+00, -1.7539e+00, 2.0255e+00],
[ 1.9953e+00, 1.2843e+01, 5.0056e+00, 2.5516e+00, 2.9983e+00,
5.4422e+00, 5.0090e+00, 9.8768e-01, 8.9960e+00, 6.2394e+00],
[ 2.8324e+00, 2.8434e+00, 1.9018e+00, 3.8000e+00, 3.1922e+00,
3.3458e+00, 3.8883e+00, 3.7678e+00, 2.0126e+00, 4.2686e+00],
[ 2.0652e+00, 8.7485e+00, 4.9003e+00, -7.1472e-02, 5.0456e+00,
2.9448e+00, 4.2957e+00, 1.4602e+00, 5.0622e+00, 6.0046e+00],
[ 1.2368e+00, -8.0124e-01, 4.5406e-01, 5.8132e-01, 5.0025e+00,
1.8401e-01, 3.2316e+00, 3.9075e+00, -2.7376e+00, 4.9980e+00],
[ 4.4707e+00, 1.3100e+01, 3.8913e+00, 5.0670e-01, 3.1494e+00,
4.7096e+00, 2.2617e+00, -2.3287e+00, 1.8589e+00, 3.9039e+00],
[ 9.9929e-01, 1.5253e+01, 2.1834e-03, 2.6820e+00, 2.0005e+00,
5.3245e+00, 3.9995e+00, -1.3442e+00, 1.4326e+00, 7.5636e+00]],
grad_fn=<MmBackward0>)
# Now use matrix factorization to predict the ratings
import torch
import torch.nn as nn
import torch.nn.functional as F
# Create a class for the model
class MatrixFactorization(nn.Module):
def __init__(self, n_users, n_movies, n_factors=20):
super().__init__()
self.user_factors = nn.Embedding(n_users, n_factors)
self.movie_factors = nn.Embedding(n_movies, n_factors)
def forward(self, user, movie):
return (self.user_factors(user) * self.movie_factors(movie)).sum(1)
model = MatrixFactorization(n_users, n_movies, 2)
model
MatrixFactorization(
(user_factors): Embedding(100, 2)
(movie_factors): Embedding(10, 2)
)
model(torch.tensor([0]), torch.tensor([2]))
tensor([-0.0271], grad_fn=<SumBackward1>)
A[0, 2]
tensor(5.)
type(A)
torch.Tensor
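Note that model(u, m) is simply the dot product between row u of the user embeddings and row m of the movie embeddings, i.e. the (u, m) entry of the product of the two embedding matrices. A minimal check (variable names here are ours):

# model(u, m) equals the (u, m) entry of user_factors.weight @ movie_factors.weight.T
with torch.no_grad():
    W_emb = model.user_factors.weight         # (n_users, n_factors)
    H_emb = model.movie_factors.weight        # (n_movies, n_factors)
    full = W_emb @ H_emb.t()                  # (n_users, n_movies) predicted matrix
print(full[0, 2], model(torch.tensor([0]), torch.tensor([2])))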
mask = ~torch.isnan(A)

# Get the indices of the non-NaN values
i, j = torch.where(mask)

# Get the values of the non-NaN entries
v = A[mask]

# Store in PyTorch tensors
users = i.to(torch.int64)
movies = j.to(torch.int64)
ratings = v.to(torch.float32)
pd.DataFrame({'user': users, 'movie': movies, 'rating': ratings})
 | user | movie | rating |
---|---|---|---|
0 | 0 | 1 | 5.0 |
1 | 0 | 2 | 5.0 |
2 | 0 | 4 | 1.0 |
3 | 0 | 5 | 1.0 |
4 | 0 | 6 | 4.0 |
... | ... | ... | ... |
512 | 98 | 8 | 2.0 |
513 | 98 | 9 | 4.0 |
514 | 99 | 0 | 1.0 |
515 | 99 | 4 | 2.0 |
516 | 99 | 6 | 4.0 |
517 rows × 3 columns
# Fit the Matrix Factorization model
model = MatrixFactorization(n_users, n_movies, 4)
optimizer = optim.Adam(model.parameters(), lr=0.01)

for i in range(1000):
    # Compute the loss
    pred = model(users, movies)
    loss = F.mse_loss(pred, ratings)

    # Zero the gradients
    optimizer.zero_grad()
    # Backpropagate
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Print the loss
    if i % 100 == 0:
        print(loss.item())
14.604362487792969
4.332712650299072
1.0960761308670044
0.6966323852539062
0.5388827919960022
0.45243579149246216
0.4012693464756012
0.3728969395160675
0.35568001866340637
0.34289655089378357
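Since the printed loss is the mean squared error over the observed ratings, its square root is roughly the typical prediction error in rating units; a quick sketch:

# RMSE on the observed entries, in rating units
with torch.no_grad():
    pred = model(users, movies)
    rmse = torch.sqrt(F.mse_loss(pred, ratings))
print(rmse.item())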
model(users, movies)
tensor([3.5693, 4.5338, 2.6934, 1.8316, 4.8915, 2.0194, 2.7778, 1.8601, 2.1124,
1.1378, 2.9079, 5.0470, 0.9911, 1.9791, 1.0050, 3.9618, 2.0085, 2.0034,
4.0113, 1.9218, 2.9801, 1.0432, 3.8993, 4.1292, 1.9357, 3.8285, 3.5266,
1.3640, 2.3989, 2.4166, 0.8559, 3.3685, 4.4493, 3.4018, 1.4722, 4.8378,
4.6684, 4.4473, 4.3097, 2.0022, 5.0147, 5.0113, 1.9599, 2.8305, 1.4493,
1.6750, 0.9520, 3.8460, 5.1279, 4.7453, 2.1484, 1.8009, 3.2104, 4.2068,
4.5473, 2.9229, 1.1817, 3.1108, 3.2157, 1.2238, 3.6272, 3.9029, 3.2554,
0.9945, 3.0062, 5.0030, 4.1144, 1.6314, 3.7945, 3.3091, 4.0727, 3.6212,
2.4359, 3.5707, 1.2826, 0.9663, 4.9973, 3.0163, 2.9916, 1.0014, 3.1734,
3.8712, 4.3364, 3.7119, 4.5313, 2.3875, 4.0274, 4.7121, 4.3851, 2.8072,
3.2066, 3.9684, 0.9307, 1.4160, 2.7484, 2.8771, 1.2753, 2.5825, 4.9857,
2.0109, 1.0080, 2.2618, 2.7936, 2.5859, 3.0972, 3.1443, 3.8655, 3.2023,
2.0061, 1.9866, 4.0068, 2.7861, 4.2803, 4.5452, 3.8177, 2.9951, 5.6040,
1.9013, 2.4498, 1.9155, 4.3239, 3.2993, 1.1275, 4.1460, 4.3619, 4.5913,
1.1365, 3.0865, 6.0992, 3.9207, 1.8325, 3.8526, 2.0000, 2.0000, 1.0255,
5.0484, 1.9858, 1.9716, 2.0854, 2.7726, 3.5053, 1.7365, 3.2607, 4.7006,
1.8830, 0.3378, 2.6222, 0.6491, 3.1144, 3.8404, 2.0632, 4.3311, 4.6138,
1.6119, 3.0444, 2.4498, 2.7860, 4.0832, 3.2681, 5.0069, 4.8745, 3.6048,
3.5065, 1.8871, 2.9108, 0.9982, 0.6787, 1.4787, 5.6890, 1.5117, 4.0205,
4.6624, 4.1051, 4.5437, 3.5410, 1.8761, 4.6653, 5.2577, 5.0311, 4.9531,
4.0022, 4.7229, 4.0406, 2.5056, 3.2014, 2.9312, 1.3051, 1.3258, 4.9233,
2.0871, 1.8761, 4.1304, 3.9279, 4.5547, 1.6639, 4.3163, 1.8149, 4.7431,
3.4398, 3.5483, 3.2541, 2.8282, 1.5230, 2.8214, 4.9214, 3.5371, 4.4522,
4.7928, 2.4932, 4.8619, 4.7693, 2.3121, 2.3021, 4.6609, 4.7478, 1.7768,
5.3910, 2.8929, 2.4920, 4.1654, 1.8382, 3.4456, 3.7108, 2.3989, 1.0480,
4.2870, 5.2730, 3.8976, 3.2956, 3.1010, 0.9666, 3.0227, 2.9491, 3.2218,
1.4033, 1.7050, 0.5457, 3.1223, 2.1794, 3.9939, 5.0190, 2.9944, 4.5670,
5.4131, 2.4720, 4.0239, 4.5689, 0.9386, 1.1827, 2.1553, 1.4232, 0.4127,
1.7502, 1.5601, 2.0953, 2.3530, 2.7085, 2.7157, 2.7900, 4.6793, 1.1711,
3.4493, 1.7701, 2.1576, 0.8294, 1.9908, 2.1975, 3.6396, 3.1952, 4.9882,
4.0091, 2.9998, 1.9985, 3.0063, 3.7149, 4.2018, 1.1992, 2.2550, 3.5132,
1.3794, 2.4434, 2.0220, 0.6691, 1.9927, 3.4666, 2.8342, 1.0393, 4.1834,
4.0819, 0.8652, 4.8585, 2.0225, 2.9662, 3.9794, 1.0412, 3.0180, 1.0117,
3.1728, 2.6027, 2.5184, 2.5120, 3.3483, 3.0602, 3.2139, 1.9144, 1.4899,
2.8253, 2.9160, 1.6410, 3.8484, 2.9637, 2.1272, 2.0490, 4.4915, 1.6853,
5.0641, 3.9315, 1.8812, 2.1028, 2.1203, 2.8898, 4.0215, 3.3032, 3.6131,
4.8380, 0.7513, 3.4182, 3.0527, 3.6200, 1.3040, 1.1657, 4.2705, 4.8569,
2.7037, 0.9142, 1.5261, 3.3970, 0.6980, 3.1131, 3.8348, 3.4478, 4.9488,
4.0077, 4.9673, 2.0816, 0.9995, 3.4965, 3.2899, 1.0624, 1.4977, 4.3810,
4.2225, 3.1540, 4.0191, 3.4060, 2.0160, 1.2814, 2.9608, 1.1513, 1.9530,
0.9872, 3.9914, 4.9865, 3.0477, 1.9701, 2.5473, 4.6173, 0.5252, 3.4813,
4.7750, 3.5971, 4.1422, 5.2524, 2.0691, 2.0039, 2.1420, 5.0156, 1.8481,
0.7781, 2.2183, 2.6443, 2.2889, 4.3373, 2.2449, 4.5594, 3.8624, 5.2333,
2.1537, 4.7373, 5.0298, 2.1322, 2.7047, 4.2308, 2.4907, 2.3259, 4.7678,
5.0266, 5.2597, 4.7697, 0.9911, 1.0440, 2.7818, 1.1585, 2.1967, 4.4422,
3.4151, 4.8498, 2.8759, 4.0001, 1.8111, 2.3611, 1.9162, 4.3893, 2.4916,
2.0815, 3.9944, 5.5013, 4.6856, 4.9904, 4.9865, 4.0133, 1.7178, 2.2895,
1.9270, 2.6042, 3.5529, 2.0969, 4.9558, 2.9138, 3.0347, 1.6650, 2.2754,
3.7474, 0.9570, 2.3175, 4.1652, 4.4857, 5.1715, 5.4181, 3.7923, 4.0080,
3.5960, 2.1456, 2.9088, 3.2796, 1.2589, 4.6664, 2.3447, 3.6577, 2.8478,
2.8383, 2.9852, 0.9179, 3.0814, 1.2291, 0.7621, 0.8185, 2.3877, 1.1364,
2.8636, 3.2811, 4.5790, 0.7829, 1.2192, 0.8060, 1.2341, 3.9681, 1.1422,
4.9521, 5.1383, 2.8294, 3.6377, 3.8734, 3.5638, 2.9749, 1.5051, 1.4258,
0.5846, 1.0630, 2.7671, 1.7912, 2.8734, 1.8916, 3.9858, 2.5439, 2.6786,
1.6704, 1.3576, 1.7701, 2.0514, 5.1632, 2.6842, 4.8280, 1.1577, 2.9727,
3.3256, 2.7693, 2.6430, 4.0140, 2.2755, 2.1973, 4.6792, 5.3522, 2.8058,
4.9664, 5.2983, 4.7749, 3.8326, 2.4944, 4.2989, 2.8784, 2.0605, 2.8428,
3.1081, 1.0411, 1.9078, 3.9594], grad_fn=<SumBackward1>)
# Now, let's predict the ratings for our df dataframe
A = torch.from_numpy(df.values)
A.shape
torch.Size([45, 10])
mask = ~torch.isnan(A)

# Get the indices of the non-NaN values
i, j = torch.where(mask)

# Get the values of the non-NaN entries
v = A[mask]

# Store in PyTorch tensors
users = i.to(torch.int64)
movies = j.to(torch.int64)
ratings = v.to(torch.float32)
pd.DataFrame({'user': users, 'movie': movies, 'rating': ratings})
 | user | movie | rating |
---|---|---|---|
0 | 0 | 0 | 4.0 |
1 | 0 | 1 | 5.0 |
2 | 0 | 2 | 4.0 |
3 | 0 | 3 | 4.0 |
4 | 0 | 4 | 5.0 |
... | ... | ... | ... |
371 | 44 | 3 | 4.0 |
372 | 44 | 4 | 4.0 |
373 | 44 | 5 | 4.0 |
374 | 44 | 6 | 4.0 |
375 | 44 | 7 | 5.0 |
376 rows × 3 columns
# Fit the Matrix Factorization model
n_users = A.shape[0]
n_movies = A.shape[1]
model = MatrixFactorization(n_users, n_movies, 4)
optimizer = optim.Adam(model.parameters(), lr=0.01)

for i in range(1000):
    # Compute the loss
    pred = model(users, movies)
    loss = F.mse_loss(pred, ratings)

    # Zero the gradients
    optimizer.zero_grad()
    # Backpropagate
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Print the loss
    if i % 100 == 0:
        print(loss.item())
19.889324188232422
3.1148574352264404
0.6727441549301147
0.5543633103370667
0.5081750750541687
0.4629250764846802
0.4147825837135315
0.36878159642219543
0.32987719774246216
0.29975879192352295
# Now, let us predict the rating for a user and movie from df for which we already have the rating
username = 'Dhruv'
movie = 'The Dark Knight'

# Get the user and movie indices
user_idx = df.index.get_loc(username)
movie_idx = df.columns.get_loc(movie)

# Predict the rating
pred = model(torch.tensor([user_idx]), torch.tensor([movie_idx]))
pred.item(), df.loc[username, movie]
(5.259384632110596, 5.0)
df.loc[username]
Sholay NaN
Swades (We The People) NaN
The Matrix (I) 5.0
Interstellar 5.0
Dangal 3.0
Taare Zameen Par NaN
Shawshank Redemption 5.0
The Dark Knight 5.0
Notting Hill 4.0
Uri: The Surgical Strike 5.0
Name: Dhruv, dtype: float64
# Now, let us predict the rating for a user and movie from df for which we do not have the rating
username = 'Dhruv'
movie = 'Sholay'

# Get the user and movie indices
user_idx = df.index.get_loc(username)
movie_idx = df.columns.get_loc(movie)

# Predict the rating
pred = model(torch.tensor([user_idx]), torch.tensor([movie_idx]))
pred, df.loc[username, movie]
(tensor([3.7885], grad_fn=<SumBackward1>), nan)
# Complete the matrix
with torch.no_grad():
    completed_matrix = pd.DataFrame(model.user_factors.weight @ model.movie_factors.weight.t(), index=df.index, columns=df.columns)
    # round to nearest integer
    completed_matrix = completed_matrix.round()
completed_matrix.head()
 | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
---|---|---|---|---|---|---|---|---|---|---|
Your name | ||||||||||
Nipun | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 |
Gautam Vashishtha | 3.0 | 3.0 | 4.0 | 4.0 | 2.0 | 3.0 | 4.0 | 5.0 | 4.0 | 3.0 |
Eshan Gujarathi | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 |
Sai Krishna Avula | 4.0 | 4.0 | 3.0 | 4.0 | 4.0 | 6.0 | 4.0 | 3.0 | 3.0 | 4.0 |
Ankit Yadav | 3.0 | 2.0 | 3.0 | 4.0 | 3.0 | 5.0 | 4.0 | 3.0 | 3.0 | 4.0 |
df.head()
 | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
---|---|---|---|---|---|---|---|---|---|---|
Your name | ||||||||||
Nipun | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 5.0 | 4.0 | 5.0 |
Gautam Vashishtha | 3.0 | 4.0 | 4.0 | 5.0 | 3.0 | 1.0 | 5.0 | 5.0 | 4.0 | 3.0 |
Eshan Gujarathi | 4.0 | NaN | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 |
Sai Krishna Avula | 5.0 | 3.0 | 3.0 | 4.0 | 4.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
Ankit Yadav | 3.0 | 3.0 | 2.0 | 5.0 | 2.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
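With the completed matrix, recommendations follow by keeping only the predictions for movies a person has not rated and taking the best remaining one per user. A closing sketch (the helper name is ours, not from the notebook):

# For each user with at least one unrated movie, recommend the highest-predicted unrated movie.
def recommend_top_movie(df, completed_matrix):
    # Predictions only where the original rating is missing
    unseen_preds = completed_matrix.where(df.isna())
    # Keep users who still have at least one unrated movie
    has_unrated = df.isna().any(axis=1)
    # For each such user, the unrated movie with the highest predicted rating
    return unseen_preds[has_unrated].idxmax(axis=1)

recommend_top_movie(df, completed_matrix).head()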