import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
Movie Recommendation using KNN and Matrix Factorization
# Generate some toy user and movie data
# Number of users
n_users = 100
# Number of movies
n_movies = 10
# Number of ratings
n_ratings = 1000
# Generate random user ids
user_ids = np.random.randint(0, n_users, n_ratings)
# Generate random movie ids
movie_ids = np.random.randint(0, n_movies, n_ratings)
# Generate random ratings
ratings = np.random.randint(1, 6, n_ratings)
# Create a dataframe with the data
df = pd.DataFrame({'user_id': user_ids, 'movie_id': movie_ids, 'rating': ratings})
# We should not have any duplicate ratings for the same user and movie
# Drop any rows that have duplicate user_id and movie_id pairs
df = df.drop_duplicates(['user_id', 'movie_id'])
df
|   | user_id | movie_id | rating |
|---|---|---|---|
| 0 | 66 | 7 | 2 |
| 1 | 54 | 7 | 4 |
| 2 | 38 | 5 | 3 |
| 3 | 56 | 6 | 1 |
| 4 | 4 | 0 | 4 |
| ... | ... | ... | ... |
| 987 | 77 | 8 | 3 |
| 992 | 99 | 3 | 3 |
| 994 | 8 | 5 | 3 |
| 998 | 22 | 2 | 3 |
| 999 | 88 | 9 | 1 |
642 rows × 3 columns
# Create a user-item matrix
A = df.pivot(index='user_id', columns='movie_id', values='rating')
A
| movie_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| user_id | ||||||||||
| 0 | 3.0 | 4.0 | NaN | 5.0 | 5.0 | 1.0 | 2.0 | 5.0 | 2.0 | 4.0 |
| 1 | 4.0 | 2.0 | NaN | 1.0 | NaN | 3.0 | 3.0 | 5.0 | 3.0 | 1.0 |
| 2 | 3.0 | NaN | NaN | 4.0 | 1.0 | 5.0 | NaN | 2.0 | NaN | 2.0 |
| 3 | 4.0 | NaN | 4.0 | 2.0 | NaN | 3.0 | NaN | 2.0 | NaN | 4.0 |
| 4 | 4.0 | 3.0 | 3.0 | 3.0 | NaN | NaN | 4.0 | NaN | NaN | NaN |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 1.0 | 3.0 | 1.0 | 2.0 | 3.0 | NaN | 1.0 | NaN | 5.0 | NaN |
| 96 | 1.0 | 2.0 | NaN | NaN | NaN | NaN | 3.0 | 2.0 | 2.0 | 5.0 |
| 97 | 3.0 | 1.0 | 4.0 | NaN | 3.0 | 1.0 | NaN | NaN | 5.0 | NaN |
| 98 | 5.0 | 3.0 | NaN | NaN | 2.0 | NaN | 1.0 | 3.0 | 4.0 | 3.0 |
| 99 | 1.0 | 5.0 | 2.0 | 3.0 | NaN | 2.0 | 4.0 | 3.0 | 3.0 | NaN |
100 rows × 10 columns
# Fill in the missing values with zeros
A = A.fillna(0)
A
| movie_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| user_id | ||||||||||
| 0 | 3.0 | 4.0 | 0.0 | 5.0 | 5.0 | 1.0 | 2.0 | 5.0 | 2.0 | 4.0 |
| 1 | 4.0 | 2.0 | 0.0 | 1.0 | 0.0 | 3.0 | 3.0 | 5.0 | 3.0 | 1.0 |
| 2 | 3.0 | 0.0 | 0.0 | 4.0 | 1.0 | 5.0 | 0.0 | 2.0 | 0.0 | 2.0 |
| 3 | 4.0 | 0.0 | 4.0 | 2.0 | 0.0 | 3.0 | 0.0 | 2.0 | 0.0 | 4.0 |
| 4 | 4.0 | 3.0 | 3.0 | 3.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 1.0 | 3.0 | 1.0 | 2.0 | 3.0 | 0.0 | 1.0 | 0.0 | 5.0 | 0.0 |
| 96 | 1.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 2.0 | 2.0 | 5.0 |
| 97 | 3.0 | 1.0 | 4.0 | 0.0 | 3.0 | 1.0 | 0.0 | 0.0 | 5.0 | 0.0 |
| 98 | 5.0 | 3.0 | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 3.0 | 4.0 | 3.0 |
| 99 | 1.0 | 5.0 | 2.0 | 3.0 | 0.0 | 2.0 | 4.0 | 3.0 | 3.0 | 0.0 |
100 rows × 10 columns
# Cosine similarity between U1 and U2
# User 1
u1 = A.loc[0]
# User 2
u2 = A.loc[1]
# Compute the dot product
dot = np.dot(u1, u2)
# Compute the L2 norm
norm_u1 = np.linalg.norm(u1)
norm_u2 = np.linalg.norm(u2)
# Compute the cosine similarity
cos_sim = dot / (norm_u1 * norm_u2)
cos_sim
0.7174278379758501
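For reference, the quantity computed above is the usual cosine similarity between the two rating vectors,

$$\cos(u_1, u_2) = \frac{u_1 \cdot u_2}{\lVert u_1 \rVert_2 \, \lVert u_2 \rVert_2},$$

which lies in [0, 1] here since all the (zero-filled) ratings are non-negative.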
# Calculate the cosine similarity between users
from sklearn.metrics.pairwise import cosine_similarity
sim_matrix = cosine_similarity(A)
pd.DataFrame(sim_matrix)
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.000000 | 0.717428 | 0.663734 | 0.565794 | 0.547289 | 0.742468 | 0.539472 | 0.599021 | 0.748964 | 0.523832 | ... | 0.876593 | 0.593722 | 0.615664 | 0.722496 | 0.783190 | 0.657754 | 0.665375 | 0.446627 | 0.774667 | 0.703313 |
| 1 | 0.717428 | 1.000000 | 0.650769 | 0.591169 | 0.559964 | 0.585314 | 0.300491 | 0.255039 | 0.693395 | 0.500193 | ... | 0.609392 | 0.313893 | 0.510454 | 0.550459 | 0.624622 | 0.493197 | 0.644346 | 0.476288 | 0.802740 | 0.781611 |
| 2 | 0.663734 | 0.650769 | 1.000000 | 0.758954 | 0.406780 | 0.372046 | 0.416654 | 0.270593 | 0.746685 | 0.560180 | ... | 0.519274 | 0.322243 | 0.772529 | 0.253842 | 0.712485 | 0.257761 | 0.322830 | 0.283373 | 0.441886 | 0.459929 |
| 3 | 0.565794 | 0.591169 | 0.758954 | 1.000000 | 0.549030 | 0.540128 | 0.671775 | 0.572892 | 0.611794 | 0.489225 | ... | 0.424052 | 0.586110 | 0.456327 | 0.506719 | 0.715831 | 0.210494 | 0.506585 | 0.492312 | 0.551652 | 0.424052 |
| 4 | 0.547289 | 0.559964 | 0.406780 | 0.549030 | 1.000000 | 0.354329 | 0.304478 | 0.541185 | 0.821353 | 0.202287 | ... | 0.667638 | 0.527306 | 0.339913 | 0.435159 | 0.582943 | 0.478699 | 0.417780 | 0.450063 | 0.502836 | 0.741820 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 0.657754 | 0.493197 | 0.257761 | 0.210494 | 0.478699 | 0.519615 | 0.452602 | 0.571548 | 0.551553 | 0.405674 | ... | 0.564076 | 0.556890 | 0.604211 | 0.682793 | 0.661382 | 1.000000 | 0.412568 | 0.796715 | 0.678637 | 0.693008 |
| 96 | 0.665375 | 0.644346 | 0.322830 | 0.506585 | 0.417780 | 0.853538 | 0.359095 | 0.538977 | 0.351369 | 0.191776 | ... | 0.498686 | 0.640033 | 0.190421 | 0.690704 | 0.682163 | 0.412568 | 1.000000 | 0.280141 | 0.734105 | 0.581800 |
| 97 | 0.446627 | 0.476288 | 0.283373 | 0.492312 | 0.450063 | 0.400743 | 0.709211 | 0.532239 | 0.469979 | 0.566223 | ... | 0.262641 | 0.446564 | 0.501441 | 0.689500 | 0.637007 | 0.796715 | 0.280141 | 1.000000 | 0.659366 | 0.481508 |
| 98 | 0.774667 | 0.802740 | 0.441886 | 0.551652 | 0.502836 | 0.844146 | 0.446610 | 0.473016 | 0.617575 | 0.335738 | ... | 0.546861 | 0.500390 | 0.305585 | 0.673754 | 0.687116 | 0.678637 | 0.734105 | 0.659366 | 1.000000 | 0.600213 |
| 99 | 0.703313 | 0.781611 | 0.459929 | 0.424052 | 0.741820 | 0.527274 | 0.294579 | 0.578997 | 0.732042 | 0.435869 | ... | 0.818182 | 0.615435 | 0.581559 | 0.645439 | 0.623673 | 0.693008 | 0.581800 | 0.481508 | 0.600213 | 1.000000 |
100 rows × 100 columns
import seaborn as sns
sns.heatmap(sim_matrix, cmap='Greys')
<AxesSubplot:>

# Find the most similar users to user u
def k_nearest_neighbors(A, u, k):
    """Find the k nearest neighbors for user u"""
    # Find the index of the user in the matrix
    u_index = A.index.get_loc(u)
    # Compute the similarity between the user and all other users
    sim_matrix = cosine_similarity(A)
    # Find the k most similar users
    k_nearest = np.argsort(sim_matrix[u_index])[::-1][1:k+1]
    # Return the user ids
    return A.index[k_nearest]
k_nearest_neighbors(A, 0, 5)
Int64Index([28, 46, 90, 32, 87], dtype='int64', name='user_id')
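A small efficiency note: the function above recomputes `cosine_similarity(A)` on every call. A possible variant (a sketch, not used in the rest of this notebook) precomputes the similarity matrix once and reuses it:

# Sketch: reuse a precomputed similarity matrix instead of recomputing it per query
sim = cosine_similarity(A)

def k_nearest_neighbors_cached(A, u, k, sim=sim):
    """Same as k_nearest_neighbors, but with a precomputed similarity matrix"""
    u_index = A.index.get_loc(u)
    # Sort by similarity (descending) and skip the user itself at position 0
    k_nearest = np.argsort(sim[u_index])[::-1][1:k + 1]
    return A.index[k_nearest]

k_nearest_neighbors_cached(A, 0, 5)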
# Show matrix of movie ratings for u and k nearest neighbors
def show_neighbors(A, u, k):
    """Show the movie ratings for user u and k nearest neighbors"""
    # Get the user ids of the k nearest neighbors
    neighbors = k_nearest_neighbors(A, u, k)
    # Get the movie ratings for user u and the k nearest neighbors
    df = A.loc[[u] + list(neighbors)]
    # Return the dataframe
    return df
show_neighbors(A, 0, 5)
| movie_id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| user_id | ||||||||||
| 0 | 3.0 | 4.0 | 0.0 | 5.0 | 5.0 | 1.0 | 2.0 | 5.0 | 2.0 | 4.0 |
| 28 | 5.0 | 0.0 | 0.0 | 5.0 | 4.0 | 2.0 | 1.0 | 4.0 | 0.0 | 5.0 |
| 46 | 3.0 | 0.0 | 2.0 | 5.0 | 5.0 | 1.0 | 0.0 | 4.0 | 2.0 | 2.0 |
| 90 | 1.0 | 5.0 | 1.0 | 5.0 | 2.0 | 0.0 | 2.0 | 4.0 | 0.0 | 1.0 |
| 32 | 3.0 | 2.0 | 2.0 | 5.0 | 4.0 | 5.0 | 3.0 | 5.0 | 0.0 | 3.0 |
| 87 | 1.0 | 0.0 | 0.0 | 5.0 | 4.0 | 0.0 | 3.0 | 4.0 | 4.0 | 2.0 |
# Predicted rating for user u and movie m: mean of the neighbours' non-zero ratings
# for that movie, e.g. (4.0 + 3.0) / 2 = 3.5 if only two neighbours rated it (zeros stand for "not rated" and are discarded)
def predict_rating(A, u, m, k=5):
    """Predict the rating for user u for movie m"""
    # Get the user ids of the k nearest neighbors
    neighbors = k_nearest_neighbors(A, u, k)
    # Get the movie ratings for user u and the k nearest neighbors
    df = A.loc[[u] + list(neighbors)]
    # Get the ratings for movie m (drop user u's own row)
    ratings = df[m].iloc[1:]
    # Calculate the mean of the non-zero ratings
    mean = ratings[ratings != 0].mean()
    # Return the mean
    return mean
predict_rating(A, 0, 5)
2.6666666666666665
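A natural refinement (shown here only as a sketch under the same zero-means-missing convention; the rest of the notebook keeps the plain mean) is to weight each neighbour's rating by its cosine similarity to user u:

def predict_rating_weighted(A, u, m, k=5):
    """Predict user u's rating for movie m as a similarity-weighted mean of neighbour ratings"""
    sim = cosine_similarity(A)
    u_index = A.index.get_loc(u)
    # k most similar users, skipping the user itself
    order = np.argsort(sim[u_index])[::-1][1:k + 1]
    neighbors = A.index[order]
    weights = sim[u_index][order]
    neighbor_ratings = A.loc[neighbors, m].values
    keep = neighbor_ratings != 0  # zeros stand for "not rated"
    if keep.sum() == 0:
        return np.nan
    return np.dot(weights[keep], neighbor_ratings[keep]) / weights[keep].sum()

predict_rating_weighted(A, 0, 5)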
# Now working with real data
# Load the data
df = pd.read_excel("mov-rec.xlsx")
df.head()
|   | Timestamp | Your name | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2023-04-11 10:58:44.990 | Nipun | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 5.0 | 4.0 | 5.0 |
| 1 | 2023-04-11 10:59:49.617 | Gautam Vashishtha | 3.0 | 4.0 | 4.0 | 5.0 | 3.0 | 1.0 | 5.0 | 5.0 | 4.0 | 3.0 |
| 2 | 2023-04-11 11:12:44.033 | Eshan Gujarathi | 4.0 | NaN | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 |
| 3 | 2023-04-11 11:13:48.674 | Sai Krishna Avula | 5.0 | 3.0 | 3.0 | 4.0 | 4.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
| 4 | 2023-04-11 11:13:55.658 | Ankit Yadav | 3.0 | 3.0 | 2.0 | 5.0 | 2.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
# Discard the timestamp column
df = df.drop('Timestamp', axis=1)
# Make the "Your Name" column the index
df = df.set_index('Your name')
df
|   | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
|---|---|---|---|---|---|---|---|---|---|---|
| Your name | ||||||||||
| Nipun | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 5.0 | 4.0 | 5.0 |
| Gautam Vashishtha | 3.0 | 4.0 | 4.0 | 5.0 | 3.0 | 1.0 | 5.0 | 5.0 | 4.0 | 3.0 |
| Eshan Gujarathi | 4.0 | NaN | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 |
| Sai Krishna Avula | 5.0 | 3.0 | 3.0 | 4.0 | 4.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
| Ankit Yadav | 3.0 | 3.0 | 2.0 | 5.0 | 2.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
| Dhruv | NaN | NaN | 5.0 | 5.0 | 3.0 | NaN | 5.0 | 5.0 | 4.0 | 5.0 |
| Saatvik Rao | 4.0 | 3.0 | 4.0 | 5.0 | 2.0 | 2.0 | 4.0 | 5.0 | 3.0 | 5.0 |
| Zeel B Patel | 5.0 | 4.0 | 5.0 | 4.0 | 4.0 | 4.0 | NaN | 2.0 | NaN | 5.0 |
| Neel | 4.0 | NaN | 5.0 | 5.0 | 3.0 | 3.0 | 5.0 | 5.0 | NaN | 4.0 |
| Sachin Jalan | 4.0 | NaN | 5.0 | 5.0 | 3.0 | 4.0 | 4.0 | 5.0 | NaN | 3.0 |
| Ayush Shrivastava | 5.0 | 4.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 | 4.0 | NaN | 4.0 |
| .... | 4.0 | 4.0 | NaN | 4.0 | 4.0 | 4.0 | NaN | 4.0 | NaN | 4.0 |
| Hari Hara Sudhan | 4.0 | 3.0 | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 5.0 | 3.0 | 5.0 |
| Etikikota Hrushikesh | NaN | NaN | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | NaN | NaN |
| Chirag | 5.0 | 3.0 | 4.0 | 5.0 | 5.0 | 2.0 | 3.0 | 4.0 | 2.0 | 5.0 |
| Aaryan Darad | 5.0 | 4.0 | 4.0 | 5.0 | 3.0 | 4.0 | 3.0 | 3.0 | NaN | 4.0 |
| Hetvi Patel | 4.0 | 3.0 | 2.0 | 5.0 | 4.0 | 4.0 | 5.0 | 3.0 | 3.0 | 5.0 |
| Kalash Kankaria | 4.0 | NaN | 4.0 | 5.0 | 3.0 | 4.0 | NaN | NaN | NaN | 3.0 |
| Rachit Verma | NaN | NaN | 4.0 | 5.0 | 3.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 |
| shriraj | 3.0 | 2.0 | 5.0 | 4.0 | 2.0 | 3.0 | 4.0 | 5.0 | 4.0 | 5.0 |
| Bhavini Korthi | NaN | NaN | NaN | NaN | 4.0 | 5.0 | NaN | NaN | NaN | 5.0 |
| Hitarth Gandhi | 3.0 | NaN | 4.0 | 5.0 | 3.0 | 4.0 | 5.0 | 5.0 | NaN | NaN |
| Radhika Joglekar | 3.0 | 3.0 | 3.0 | 4.0 | 5.0 | 5.0 | 2.0 | 1.0 | 2.0 | 5.0 |
| Medhansh Singh | 4.0 | 3.0 | 5.0 | 5.0 | 3.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 |
| Arun Mani | NaN | NaN | 4.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | 4.0 | NaN |
| Satyam | 3.0 | 5.0 | 5.0 | 5.0 | 4.0 | 3.0 | 5.0 | 5.0 | NaN | 5.0 |
| Karan Kumar | 4.0 | 3.0 | 5.0 | 4.0 | 5.0 | 5.0 | 3.0 | 5.0 | 5.0 | 4.0 |
| R Yeeshu Dhurandhar | 5.0 | NaN | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | NaN | NaN |
| Satyam Gupta | 5.0 | 5.0 | NaN | 5.0 | 4.0 | 4.0 | NaN | 4.0 | NaN | 2.0 |
| rushali | NaN | NaN | NaN | 5.0 | 4.0 | 3.0 | NaN | NaN | NaN | NaN |
| shridhar | 5.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 3.0 | 3.0 |
| Tanvi Jain | 4.0 | 3.0 | NaN | NaN | 4.0 | 5.0 | NaN | NaN | NaN | 5.0 |
| Manish Prabhubhai Salvi | 4.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 | NaN | 5.0 |
| Varun Barala | 5.0 | 5.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 3.0 | 4.0 |
| Kevin Shah | 3.0 | 4.0 | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 4.0 | 3.0 | 5.0 |
| Inderjeet | 4.0 | NaN | 4.0 | 5.0 | 4.0 | 3.0 | 5.0 | 5.0 | NaN | 3.0 |
| Gangaram Siddam | 4.0 | 4.0 | 3.0 | 3.0 | 5.0 | 5.0 | 4.0 | 4.0 | 3.0 | 5.0 |
| Aditi | 4.0 | 4.0 | NaN | 5.0 | 1.0 | 3.0 | 5.0 | 4.0 | NaN | 4.0 |
| Madhuri Awachar | 5.0 | 4.0 | 5.0 | 4.0 | 5.0 | 3.0 | 5.0 | 5.0 | 4.0 | 5.0 |
| Anupam | 5.0 | 5.0 | NaN | 5.0 | 5.0 | 5.0 | NaN | NaN | NaN | 5.0 |
| Jinay | 3.0 | 1.0 | 4.0 | 3.0 | 4.0 | 3.0 | 5.0 | 5.0 | 4.0 | 3.0 |
| Shrutimoy | 5.0 | 5.0 | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 2.0 |
| Aadesh Desai | 4.0 | 4.0 | 3.0 | 5.0 | 3.0 | 3.0 | 5.0 | 5.0 | 4.0 | 5.0 |
| Dhairya | 5.0 | 4.0 | 4.0 | 5.0 | 3.0 | 5.0 | NaN | 4.0 | NaN | 4.0 |
| Rahul C | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 4.0 | 4.0 | 5.0 | NaN | NaN |
df.index
Index(['Nipun', 'Gautam Vashishtha', 'Eshan Gujarathi', 'Sai Krishna Avula',
'Ankit Yadav ', 'Dhruv', 'Saatvik Rao ', 'Zeel B Patel', 'Neel ',
'Sachin Jalan ', 'Ayush Shrivastava', '....', 'Hari Hara Sudhan',
'Etikikota Hrushikesh', 'Chirag', 'Aaryan Darad', 'Hetvi Patel',
'Kalash Kankaria', 'Rachit Verma', 'shriraj', 'Bhavini Korthi ',
'Hitarth Gandhi ', 'Radhika Joglekar ', 'Medhansh Singh', 'Arun Mani',
'Satyam ', 'Karan Kumar ', 'R Yeeshu Dhurandhar', 'Satyam Gupta',
'rushali', 'shridhar', 'Tanvi Jain ', 'Manish Prabhubhai Salvi ',
'Varun Barala', 'Kevin Shah ', 'Inderjeet', 'Gangaram Siddam ', 'Aditi',
'Madhuri Awachar', 'Anupam', 'Jinay', 'Shrutimoy', 'Aadesh Desai',
'Dhairya', 'Rahul C'],
dtype='object', name='Your name')
# Get index for user and movie
user = 'Rahul C'
print(user in df.index)
# Get the movie ratings for user
user_ratings = df.loc[user]
user_ratings
True
Sholay 3.0
Swades (We The People) 3.0
The Matrix (I) 4.0
Interstellar 4.0
Dangal 4.0
Taare Zameen Par 4.0
Shawshank Redemption 4.0
The Dark Knight 5.0
Notting Hill NaN
Uri: The Surgical Strike NaN
Name: Rahul C, dtype: float64
df_copy = df.copy()
df_copy.fillna(0, inplace=True)
show_neighbors(df_copy, user, 5)
|   | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
|---|---|---|---|---|---|---|---|---|---|---|
| Your name | ||||||||||
| Rahul C | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 4.0 | 4.0 | 5.0 | 0.0 | 0.0 |
| Shrutimoy | 5.0 | 5.0 | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | 0.0 | 2.0 |
| Hitarth Gandhi | 3.0 | 0.0 | 4.0 | 5.0 | 3.0 | 4.0 | 5.0 | 5.0 | 0.0 | 0.0 |
| R Yeeshu Dhurandhar | 5.0 | 0.0 | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 0.0 | 0.0 |
| shridhar | 5.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 3.0 | 3.0 |
| Sachin Jalan | 4.0 | 0.0 | 5.0 | 5.0 | 3.0 | 4.0 | 4.0 | 5.0 | 0.0 | 3.0 |
df.describe()
|   | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
|---|---|---|---|---|---|---|---|---|---|---|
| count | 39.000000 | 32.000000 | 38.000000 | 43.000000 | 45.000000 | 44.000000 | 35.000000 | 40.000000 | 21.000000 | 39.000000 |
| mean | 4.102564 | 3.718750 | 4.131579 | 4.581395 | 3.644444 | 3.977273 | 4.400000 | 4.250000 | 3.476190 | 4.230769 |
| std | 0.753758 | 0.958304 | 0.991070 | 0.793802 | 1.003529 | 1.067242 | 0.976187 | 1.080123 | 0.813575 | 0.902089 |
| min | 3.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 2.000000 | 2.000000 |
| 25% | 4.000000 | 3.000000 | 4.000000 | 4.000000 | 3.000000 | 3.000000 | 4.000000 | 4.000000 | 3.000000 | 4.000000 |
| 50% | 4.000000 | 4.000000 | 4.000000 | 5.000000 | 4.000000 | 4.000000 | 5.000000 | 5.000000 | 3.000000 | 4.000000 |
| 75% | 5.000000 | 4.000000 | 5.000000 | 5.000000 | 4.000000 | 5.000000 | 5.000000 | 5.000000 | 4.000000 | 5.000000 |
| max | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 | 5.000000 |
# Predict the rating for user u for movie m
predict_rating(df_copy, user, 'The Dark Knight')
4.8
predict_rating(df_copy, user, 'Sholay')
4.4
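To see KNN imputation on the whole survey, one could fill every missing entry of `df` with the predictor above (a sketch; `knn_completed` is just an illustrative name, and entries whose neighbours all left the movie unrated stay NaN):

# Fill each missing (user, movie) entry with the KNN prediction on the zero-filled copy
knn_completed = df.copy()
for u in df.index:
    for m in df.columns:
        if pd.isna(df.loc[u, m]):
            knn_completed.loc[u, m] = predict_rating(df_copy, u, m, k=5)
knn_completed.head()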
# Generic Matrix Factorization (without missing values)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# A is a matrix of size (n_users, n_movies) randomly generated values between 1 and 5
A = torch.randint(1, 6, (n_users, n_movies), dtype=torch.float)
A
tensor([[1., 5., 4., 3., 4., 3., 1., 3., 3., 4.],
[3., 3., 4., 4., 4., 2., 5., 2., 2., 4.],
[2., 5., 5., 3., 3., 1., 4., 5., 2., 5.],
[4., 2., 1., 5., 1., 5., 1., 1., 5., 1.],
[4., 3., 5., 3., 4., 1., 2., 4., 1., 5.],
[1., 2., 2., 1., 2., 4., 1., 1., 3., 3.],
[4., 2., 2., 3., 5., 3., 3., 2., 5., 2.],
[2., 2., 1., 1., 2., 4., 5., 1., 4., 3.],
[1., 3., 2., 1., 3., 4., 1., 5., 5., 1.],
[4., 4., 2., 4., 1., 4., 2., 1., 4., 4.],
[1., 3., 1., 2., 3., 1., 3., 2., 4., 4.],
[5., 3., 5., 1., 4., 5., 4., 2., 1., 5.],
[5., 1., 5., 3., 4., 4., 4., 2., 2., 2.],
[1., 1., 2., 3., 5., 4., 3., 1., 1., 1.],
[3., 2., 5., 4., 3., 5., 2., 4., 4., 3.],
[2., 2., 4., 1., 2., 3., 2., 4., 3., 5.],
[4., 2., 4., 5., 2., 2., 1., 2., 1., 1.],
[4., 3., 2., 1., 1., 1., 5., 2., 1., 1.],
[2., 1., 2., 1., 1., 1., 5., 5., 1., 4.],
[1., 3., 5., 1., 5., 5., 5., 5., 2., 2.],
[1., 2., 2., 5., 1., 1., 5., 4., 2., 5.],
[5., 1., 2., 2., 1., 5., 2., 3., 5., 2.],
[2., 2., 1., 5., 5., 5., 5., 3., 4., 3.],
[1., 1., 4., 1., 2., 5., 5., 5., 2., 1.],
[5., 3., 3., 1., 2., 3., 3., 2., 2., 3.],
[2., 4., 5., 1., 2., 3., 5., 3., 1., 5.],
[2., 1., 1., 1., 2., 2., 5., 5., 2., 1.],
[3., 1., 1., 3., 3., 3., 2., 1., 3., 2.],
[4., 4., 4., 5., 1., 3., 4., 3., 2., 2.],
[4., 4., 4., 5., 2., 5., 2., 1., 4., 3.],
[3., 3., 5., 1., 4., 3., 3., 3., 3., 2.],
[2., 4., 2., 3., 4., 2., 4., 3., 3., 3.],
[5., 4., 1., 5., 4., 3., 5., 1., 3., 4.],
[3., 5., 3., 2., 4., 2., 5., 1., 2., 5.],
[2., 5., 4., 1., 1., 5., 1., 5., 2., 4.],
[1., 3., 3., 4., 4., 5., 2., 5., 3., 5.],
[3., 3., 5., 3., 4., 1., 3., 1., 1., 3.],
[2., 4., 1., 3., 5., 1., 5., 2., 4., 1.],
[5., 2., 2., 3., 1., 4., 5., 5., 4., 2.],
[5., 3., 5., 1., 5., 4., 3., 1., 1., 3.],
[4., 2., 2., 2., 2., 1., 2., 1., 1., 3.],
[2., 2., 1., 4., 1., 4., 5., 2., 5., 1.],
[3., 4., 2., 1., 3., 1., 2., 5., 3., 5.],
[1., 3., 3., 5., 3., 2., 1., 2., 5., 1.],
[3., 4., 2., 4., 2., 3., 4., 1., 1., 1.],
[1., 1., 5., 1., 3., 2., 5., 5., 5., 4.],
[4., 4., 4., 4., 4., 3., 1., 4., 1., 1.],
[3., 4., 5., 4., 1., 5., 2., 3., 1., 3.],
[5., 2., 5., 5., 2., 4., 5., 4., 4., 5.],
[3., 5., 5., 4., 1., 5., 1., 2., 5., 1.],
[2., 4., 3., 5., 4., 5., 2., 5., 3., 3.],
[3., 2., 3., 1., 1., 4., 1., 1., 2., 1.],
[4., 1., 2., 4., 4., 3., 2., 2., 2., 4.],
[2., 2., 3., 2., 3., 2., 2., 5., 4., 3.],
[4., 2., 3., 4., 4., 1., 1., 4., 2., 3.],
[2., 2., 3., 2., 2., 1., 4., 3., 5., 4.],
[5., 5., 5., 4., 2., 1., 3., 1., 3., 3.],
[2., 4., 5., 1., 2., 2., 5., 3., 3., 1.],
[1., 4., 3., 2., 3., 5., 3., 2., 4., 4.],
[1., 3., 2., 1., 2., 3., 2., 2., 5., 2.],
[4., 5., 5., 3., 1., 2., 1., 3., 4., 4.],
[4., 2., 3., 2., 3., 4., 5., 1., 4., 3.],
[3., 2., 1., 3., 2., 2., 5., 5., 5., 1.],
[2., 5., 5., 2., 3., 5., 3., 4., 3., 2.],
[1., 2., 2., 4., 4., 5., 5., 2., 5., 1.],
[4., 4., 1., 5., 2., 4., 4., 2., 4., 2.],
[4., 1., 5., 3., 1., 4., 5., 2., 2., 2.],
[3., 2., 2., 1., 1., 1., 2., 4., 2., 2.],
[1., 1., 3., 2., 2., 4., 3., 3., 2., 2.],
[2., 4., 3., 1., 1., 2., 4., 2., 5., 2.],
[1., 4., 3., 5., 4., 2., 4., 2., 2., 5.],
[4., 4., 5., 5., 4., 3., 4., 1., 4., 5.],
[5., 2., 2., 1., 4., 4., 3., 5., 5., 5.],
[3., 1., 3., 1., 2., 5., 5., 4., 3., 1.],
[5., 5., 5., 3., 5., 3., 3., 3., 3., 3.],
[5., 3., 3., 3., 4., 3., 5., 4., 5., 1.],
[5., 5., 3., 5., 4., 1., 2., 1., 3., 2.],
[5., 3., 4., 1., 5., 2., 4., 4., 2., 1.],
[1., 2., 1., 2., 1., 3., 4., 1., 2., 1.],
[5., 3., 3., 1., 5., 1., 4., 5., 5., 1.],
[2., 2., 5., 2., 3., 5., 1., 5., 4., 2.],
[3., 3., 1., 2., 5., 5., 4., 4., 4., 5.],
[5., 3., 5., 3., 5., 5., 3., 4., 1., 4.],
[5., 3., 2., 5., 1., 2., 4., 3., 1., 4.],
[5., 2., 5., 2., 4., 3., 2., 2., 4., 5.],
[2., 3., 5., 2., 4., 5., 3., 1., 4., 3.],
[3., 3., 1., 4., 5., 2., 1., 3., 4., 4.],
[5., 2., 3., 4., 2., 2., 5., 4., 2., 1.],
[3., 1., 1., 2., 5., 1., 5., 1., 5., 3.],
[3., 5., 1., 5., 3., 4., 5., 5., 4., 3.],
[2., 4., 1., 5., 4., 1., 4., 5., 1., 4.],
[5., 5., 4., 1., 2., 2., 2., 5., 1., 3.],
[5., 2., 3., 3., 4., 3., 3., 4., 3., 2.],
[1., 5., 4., 4., 5., 3., 4., 2., 4., 1.],
[2., 4., 1., 1., 4., 2., 1., 1., 3., 4.],
[4., 4., 2., 3., 2., 2., 3., 5., 4., 1.],
[5., 4., 3., 4., 5., 4., 5., 1., 5., 5.],
[1., 1., 1., 1., 4., 4., 5., 2., 4., 2.],
[2., 4., 1., 2., 4., 3., 5., 1., 4., 4.],
[2., 5., 2., 1., 3., 5., 5., 4., 1., 4.]])
A.shape
torch.Size([100, 10])
Let us decompose A as the product WH, where W has shape (n_users, k) and H has shape (k, n_movies), so that A ≈ WH.
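Concretely, the cells below minimise the Frobenius-norm reconstruction error

$$\min_{W \in \mathbb{R}^{n_{\text{users}} \times k},\; H \in \mathbb{R}^{k \times n_{\text{movies}}}} \lVert A - WH \rVert_F,$$

which is exactly what `torch.norm(torch.mm(W, H) - A)` computes (for a matrix argument, `torch.norm` defaults to the Frobenius norm).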
# Randomly initialize W and H
W = torch.randn(n_users, 2, requires_grad=True)
H = torch.randn(2, n_movies, requires_grad=True)
# Compute the loss
loss = torch.norm(torch.mm(W, H) - A)
loss
tensor(110.7991, grad_fn=<LinalgVectorNormBackward0>)
pd.DataFrame(torch.mm(W, H).detach().numpy())
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.733831 | 2.962563 | -0.009936 | -0.591927 | 2.442282 | -0.533001 | -0.500535 | -0.777075 | -0.427938 | -0.050505 |
| 1 | -1.605388 | 2.875087 | 0.171515 | -0.770741 | 2.594427 | -0.652586 | -0.512620 | -0.953935 | -0.329276 | -0.023146 |
| 2 | 0.159360 | -0.289060 | -0.022038 | 0.082685 | -0.266777 | 0.069192 | 0.052250 | 0.101196 | 0.030828 | 0.001642 |
| 3 | -3.637741 | 4.346515 | -2.580031 | 1.911339 | 0.407374 | 1.134387 | -0.353923 | 1.689442 | -1.846117 | -0.440421 |
| 4 | 1.706123 | -2.207651 | 0.978520 | -0.611154 | -0.617788 | -0.328235 | 0.228982 | -0.492014 | 0.780051 | 0.176302 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 1.362280 | -2.751892 | -0.572962 | 1.180665 | -2.989312 | 0.929992 | 0.551276 | 1.363935 | 0.121039 | -0.036218 |
| 96 | 1.432943 | -1.514237 | 1.287248 | -1.086741 | 0.338909 | -0.685343 | 0.065700 | -1.016967 | 0.827599 | 0.208897 |
| 97 | 0.381086 | 0.345669 | 1.366953 | -1.551476 | 1.978570 | -1.084161 | -0.261282 | -1.599606 | 0.599750 | 0.189461 |
| 98 | -3.364117 | 6.450801 | 0.942667 | -2.333749 | 6.511645 | -1.880903 | -1.232884 | -2.755594 | -0.473888 | 0.027722 |
| 99 | 0.877697 | -1.631227 | -0.175042 | 0.521516 | -1.568212 | 0.428319 | 0.302370 | 0.626961 | 0.149908 | 0.002033 |
100 rows × 10 columns
pd.DataFrame(A)
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 | 5.0 | 4.0 | 3.0 | 4.0 | 3.0 | 1.0 | 3.0 | 3.0 | 4.0 |
| 1 | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 2.0 | 5.0 | 2.0 | 2.0 | 4.0 |
| 2 | 2.0 | 5.0 | 5.0 | 3.0 | 3.0 | 1.0 | 4.0 | 5.0 | 2.0 | 5.0 |
| 3 | 4.0 | 2.0 | 1.0 | 5.0 | 1.0 | 5.0 | 1.0 | 1.0 | 5.0 | 1.0 |
| 4 | 4.0 | 3.0 | 5.0 | 3.0 | 4.0 | 1.0 | 2.0 | 4.0 | 1.0 | 5.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 95 | 4.0 | 4.0 | 2.0 | 3.0 | 2.0 | 2.0 | 3.0 | 5.0 | 4.0 | 1.0 |
| 96 | 5.0 | 4.0 | 3.0 | 4.0 | 5.0 | 4.0 | 5.0 | 1.0 | 5.0 | 5.0 |
| 97 | 1.0 | 1.0 | 1.0 | 1.0 | 4.0 | 4.0 | 5.0 | 2.0 | 4.0 | 2.0 |
| 98 | 2.0 | 4.0 | 1.0 | 2.0 | 4.0 | 3.0 | 5.0 | 1.0 | 4.0 | 4.0 |
| 99 | 2.0 | 5.0 | 2.0 | 1.0 | 3.0 | 5.0 | 5.0 | 4.0 | 1.0 | 4.0 |
100 rows × 10 columns
# Optimizer
optimizer = optim.Adam([W, H], lr=0.01)
# Train the model
for i in range(1000):
    # Compute the loss
    loss = torch.norm(torch.mm(W, H) - A)
    # Zero the gradients
    optimizer.zero_grad()
    # Backpropagate
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Print the loss
    if i % 10 == 0:
        print(loss.item())
110.79912567138672
108.5261001586914
106.6722412109375
104.86959075927734
102.57525634765625
99.17333984375
94.17224884033203
87.41124725341797
79.23120880126953
70.51470184326172
62.338897705078125
55.442466735839844
50.24897003173828
46.78135681152344
44.567237854003906
43.125022888183594
42.18255615234375
41.56866455078125
41.16780471801758
40.899253845214844
40.70838928222656
40.56161117553711
40.439727783203125
40.33248519897461
40.23463821411133
40.143455505371094
40.05741500854492
39.9755744934082
39.89726638793945
39.82197952270508
39.749305725097656
39.678890228271484
39.610443115234375
39.54371643066406
39.478515625
39.4146842956543
39.35213088989258
39.290809631347656
39.23072052001953
39.17192077636719
39.114505767822266
39.058589935302734
39.00433349609375
38.95188903808594
38.90142059326172
38.85308837890625
38.8070182800293
38.76333999633789
38.722137451171875
38.6834716796875
38.64735412597656
38.61379623413086
38.5827522277832
38.554161071777344
38.527931213378906
38.503971099853516
38.482154846191406
38.46236038208008
38.444454193115234
38.42829132080078
38.41374588012695
38.40068435668945
38.38897705078125
38.378501892089844
38.36914825439453
38.36080551147461
38.35336685180664
38.34674835205078
38.34086608886719
38.33564376831055
38.330997467041016
38.326881408691406
38.3232307434082
38.31999206542969
38.31712341308594
38.314579010009766
38.31232833862305
38.310325622558594
38.308555603027344
38.30698013305664
38.30558776855469
38.30434799194336
38.30324935913086
38.30227279663086
38.301395416259766
38.300621032714844
38.2999267578125
38.29930114746094
38.29874801635742
38.298248291015625
38.29779815673828
38.297393798828125
38.297027587890625
38.29669189453125
38.29639434814453
38.296119689941406
38.295867919921875
38.29563903808594
38.29542922973633
38.29523468017578
pd.DataFrame(torch.mm(W, H).detach().numpy()).head(2)
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 3.622862 | 3.536414 | 3.945925 | 2.985436 | 2.999869 | 2.749094 | 2.605484 | 2.758482 | 2.209353 | 3.361388 |
| 1 | 3.622746 | 3.547434 | 3.798462 | 3.133817 | 3.271342 | 3.176287 | 3.221505 | 3.073355 | 2.856927 | 3.383140 |
pd.DataFrame(A).head(2)
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 | 5.0 | 4.0 | 3.0 | 4.0 | 3.0 | 1.0 | 3.0 | 3.0 | 4.0 |
| 1 | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 2.0 | 5.0 | 2.0 | 2.0 | 4.0 |
def factorize(A, k):
    """Factorize the matrix A into W and H
    A: input matrix of size (n_users, n_movies)
    k: number of latent features
    Returns W and H
    W: matrix of size (n_users, k)
    H: matrix of size (k, n_movies)
    """
    # Randomly initialize W and H
    W = torch.randn(A.shape[0], k, requires_grad=True)
    H = torch.randn(k, A.shape[1], requires_grad=True)
    # Optimizer
    optimizer = optim.Adam([W, H], lr=0.01)
    # Train the model
    for i in range(1000):
        # Compute the loss
        loss = torch.norm(torch.mm(W, H) - A)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the parameters
        optimizer.step()
    return W, H, loss
for k in [1, 2, 3, 4, 5, 6, 9, 10, 11]:
    W, H, loss = factorize(A, k)
    print(k, loss.item())
1 42.103797912597656
2 38.35541534423828
3 34.45906448364258
4 30.72266387939453
5 27.430004119873047
6 23.318540573120117
9 9.963604927062988
10 0.16205351054668427
11 0.18895108997821808
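The sharp drop at k = 10 is expected: A is 100×10, so its rank is at most 10 and a rank-10 factorization can reproduce it (almost) exactly. A quick check, not part of the original run:

# Rank of the toy matrix is bounded by min(n_users, n_movies) = 10
torch.linalg.matrix_rank(A)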
pd.DataFrame(torch.mm(W, H).detach().numpy()).head(2)
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.014621 | 5.012642 | 4.012176 | 3.013239 | 4.012656 | 3.010292 | 1.012815 | 3.011410 | 3.015494 | 4.012996 |
| 1 | 2.987797 | 2.986945 | 3.989305 | 3.987498 | 3.988829 | 1.989375 | 4.989910 | 1.987138 | 1.986990 | 3.987026 |
pd.DataFrame(A).head(2)
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 | 5.0 | 4.0 | 3.0 | 4.0 | 3.0 | 1.0 | 3.0 | 3.0 | 4.0 |
| 1 | 3.0 | 3.0 | 4.0 | 4.0 | 4.0 | 2.0 | 5.0 | 2.0 | 2.0 | 4.0 |
# With missing values
# Randomly replace some entries with NaN
A = torch.randint(1, 6, (n_users, n_movies), dtype=torch.float)
A[torch.rand(A.shape) < 0.5] = float('nan')
A
tensor([[nan, 5., 5., nan, 1., 1., 4., 2., nan, 4.],
[nan, 2., 2., 1., nan, nan, nan, nan, 3., nan],
[nan, nan, 5., 1., nan, nan, 2., nan, nan, nan],
[nan, nan, nan, 1., 4., 2., nan, 2., nan, 4.],
[nan, nan, 2., nan, 3., 1., nan, nan, 4., 4.],
[nan, 2., 5., nan, 2., nan, 1., nan, 3., nan],
[nan, nan, nan, nan, 2., 1., nan, 3., nan, 5.],
[nan, 4., nan, 1., 5., nan, 4., 5., 4., nan],
[nan, nan, 2., 5., nan, nan, 5., nan, nan, 2.],
[nan, nan, 3., 2., nan, 1., 1., 4., 5., nan],
[nan, nan, 5., nan, nan, nan, 2., 2., nan, 3.],
[nan, 5., 4., 2., nan, nan, nan, 1., 4., 3.],
[nan, nan, 1., nan, 4., 4., nan, nan, 3., nan],
[nan, 1., nan, nan, 3., nan, nan, nan, nan, 5.],
[4., nan, 2., nan, nan, nan, nan, nan, 4., 3.],
[4., nan, 3., nan, 3., nan, 4., 1., 1., nan],
[5., nan, 3., nan, 3., nan, nan, 1., nan, nan],
[2., 5., nan, 5., 3., 4., 3., 3., 5., 5.],
[3., 3., nan, nan, nan, nan, nan, nan, 4., nan],
[1., nan, 1., 3., 4., 1., nan, nan, 2., nan],
[nan, nan, 5., nan, nan, nan, 2., nan, 1., nan],
[3., 3., nan, nan, 2., 2., 3., 4., 4., nan],
[2., nan, nan, 2., nan, nan, nan, nan, nan, 4.],
[nan, 2., nan, 4., 5., 5., nan, 3., 5., nan],
[2., 2., 2., 4., nan, 4., nan, 1., nan, nan],
[5., 4., 5., 1., nan, 3., 5., 5., 1., 4.],
[2., nan, nan, 2., nan, nan, nan, nan, nan, nan],
[nan, 1., 5., nan, nan, nan, 2., 2., nan, nan],
[2., 3., 4., nan, nan, nan, 1., 4., 4., nan],
[2., nan, 1., 4., 1., 1., nan, 5., nan, 1.],
[5., nan, 3., 1., 5., nan, nan, nan, 2., nan],
[nan, 3., 4., nan, nan, 3., nan, 5., 5., nan],
[4., 3., 2., 3., nan, nan, nan, nan, 1., nan],
[1., 1., nan, 5., 1., 5., nan, nan, 5., nan],
[4., 5., nan, nan, nan, 3., nan, 2., 5., 5.],
[nan, 5., 5., nan, nan, nan, nan, 4., nan, nan],
[5., nan, 4., 2., 3., 3., nan, nan, 2., 1.],
[nan, nan, nan, nan, 5., nan, 2., 2., 4., nan],
[3., 5., 1., 5., 3., nan, 5., nan, 2., 4.],
[nan, 2., 4., nan, 1., 4., nan, 5., nan, 3.],
[5., 5., 2., 5., nan, 4., nan, 2., nan, 3.],
[5., nan, nan, 5., 1., 5., nan, nan, 3., 3.],
[nan, 4., 2., 5., nan, 2., 3., nan, nan, 1.],
[nan, nan, 4., 5., 5., nan, nan, 4., nan, 2.],
[nan, nan, nan, nan, nan, nan, nan, 1., 3., nan],
[3., 3., 1., 1., 1., 4., nan, nan, nan, 2.],
[nan, nan, nan, 4., 5., nan, nan, 3., nan, nan],
[5., nan, 5., 2., 4., nan, nan, nan, 5., nan],
[nan, 1., 1., nan, 3., 1., nan, 1., nan, 1.],
[nan, nan, 1., 2., nan, 3., nan, 2., 2., 4.],
[nan, nan, 5., 1., 3., 2., 2., nan, 1., nan],
[nan, 2., 2., nan, 4., nan, nan, nan, 3., nan],
[nan, nan, nan, nan, 5., 4., nan, 3., nan, nan],
[nan, 2., nan, nan, nan, nan, nan, 3., nan, nan],
[nan, 2., 5., 2., nan, 3., 4., nan, 1., 2.],
[1., nan, 1., nan, nan, 3., 4., 2., nan, 1.],
[nan, 4., 4., 1., nan, nan, 5., nan, nan, 2.],
[nan, nan, 3., nan, nan, 4., 1., nan, 3., nan],
[nan, 1., nan, 3., 2., 3., nan, 2., nan, 4.],
[2., 3., 3., nan, 1., 4., 3., nan, nan, 1.],
[5., nan, 3., 1., 1., nan, 4., nan, 3., nan],
[5., nan, nan, nan, nan, nan, 4., nan, nan, nan],
[2., nan, nan, 2., nan, nan, 2., 3., nan, 4.],
[3., nan, 4., 5., nan, nan, nan, 1., nan, 3.],
[3., 3., 1., nan, 2., 5., 5., nan, 2., nan],
[1., 1., 3., 1., nan, nan, 4., 3., 4., 5.],
[4., nan, nan, 5., 2., nan, 1., nan, nan, nan],
[nan, nan, 2., nan, 5., nan, 2., 1., nan, 4.],
[nan, 5., nan, 4., 3., 2., 2., 1., nan, 4.],
[nan, nan, 1., nan, nan, nan, 2., nan, nan, nan],
[1., nan, nan, 4., 5., 3., nan, 2., nan, nan],
[2., nan, 5., 1., 3., nan, 5., nan, nan, nan],
[3., 5., nan, nan, nan, 5., 2., nan, nan, 2.],
[nan, 2., 5., nan, nan, nan, 2., nan, 1., 2.],
[nan, 3., 2., nan, nan, 4., nan, 2., nan, 5.],
[nan, 4., 5., 2., 5., nan, nan, nan, nan, 5.],
[3., 1., 5., 3., nan, nan, nan, 3., nan, 4.],
[nan, 5., nan, 5., nan, 5., 1., 1., 3., 1.],
[2., nan, nan, nan, 5., nan, nan, 4., nan, 4.],
[nan, 3., 4., 2., nan, 2., nan, 2., nan, nan],
[nan, 5., nan, 2., 2., 4., 5., 5., nan, nan],
[nan, nan, nan, nan, nan, 5., nan, nan, 5., 4.],
[nan, nan, nan, nan, nan, 2., 2., 2., 2., 4.],
[nan, nan, nan, 2., 5., 3., 3., nan, nan, nan],
[1., 2., 4., 2., nan, 2., 5., 4., nan, 5.],
[4., nan, nan, 5., 4., nan, 5., 1., 3., nan],
[4., nan, 1., 4., nan, nan, 2., nan, 4., 3.],
[4., 2., 1., 3., nan, nan, 1., nan, nan, 1.],
[nan, 1., 3., 1., 2., nan, 3., nan, 5., nan],
[nan, 1., 1., nan, 1., 1., nan, 4., nan, nan],
[nan, nan, nan, nan, 1., 5., nan, 5., 3., nan],
[3., 4., nan, 4., 3., nan, 2., 1., nan, nan],
[1., 1., nan, 2., nan, nan, nan, 2., 4., 1.],
[nan, 4., 2., nan, 3., nan, 2., 1., nan, 2.],
[2., nan, 5., nan, 3., nan, 5., 1., nan, nan],
[3., 3., nan, nan, 3., 3., nan, 4., 2., nan],
[2., nan, 5., nan, 5., 3., nan, nan, 5., nan],
[nan, nan, nan, nan, 5., nan, nan, nan, nan, 5.],
[4., nan, nan, 1., 4., 5., 1., nan, 2., 4.],
[1., nan, nan, nan, 2., nan, 4., nan, nan, nan]])
W, H, loss = factorize(A, 2)
loss
tensor(nan, grad_fn=<LinalgVectorNormBackward0>)
As expected, this does not work: the loss comes out as NaN because our current loss function does not handle missing values.
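One standard fix is to restrict the loss to the observed entries. With a mask $M_{ij} = 1$ when $A_{ij}$ is observed and $0$ otherwise, the objective becomes

$$\min_{W, H} \lVert M \odot (A - WH) \rVert_F,$$

where $\odot$ is the element-wise product. The cells below build exactly this mask and index the difference matrix with it.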
mask = ~torch.isnan(A)
mask
tensor([[False, True, True, False, True, True, True, True, False, True],
[False, True, True, True, False, False, False, False, True, False],
[False, False, True, True, False, False, True, False, False, False],
[False, False, False, True, True, True, False, True, False, True],
[False, False, True, False, True, True, False, False, True, True],
[False, True, True, False, True, False, True, False, True, False],
[False, False, False, False, True, True, False, True, False, True],
[False, True, False, True, True, False, True, True, True, False],
[False, False, True, True, False, False, True, False, False, True],
[False, False, True, True, False, True, True, True, True, False],
[False, False, True, False, False, False, True, True, False, True],
[False, True, True, True, False, False, False, True, True, True],
[False, False, True, False, True, True, False, False, True, False],
[False, True, False, False, True, False, False, False, False, True],
[ True, False, True, False, False, False, False, False, True, True],
[ True, False, True, False, True, False, True, True, True, False],
[ True, False, True, False, True, False, False, True, False, False],
[ True, True, False, True, True, True, True, True, True, True],
[ True, True, False, False, False, False, False, False, True, False],
[ True, False, True, True, True, True, False, False, True, False],
[False, False, True, False, False, False, True, False, True, False],
[ True, True, False, False, True, True, True, True, True, False],
[ True, False, False, True, False, False, False, False, False, True],
[False, True, False, True, True, True, False, True, True, False],
[ True, True, True, True, False, True, False, True, False, False],
[ True, True, True, True, False, True, True, True, True, True],
[ True, False, False, True, False, False, False, False, False, False],
[False, True, True, False, False, False, True, True, False, False],
[ True, True, True, False, False, False, True, True, True, False],
[ True, False, True, True, True, True, False, True, False, True],
[ True, False, True, True, True, False, False, False, True, False],
[False, True, True, False, False, True, False, True, True, False],
[ True, True, True, True, False, False, False, False, True, False],
[ True, True, False, True, True, True, False, False, True, False],
[ True, True, False, False, False, True, False, True, True, True],
[False, True, True, False, False, False, False, True, False, False],
[ True, False, True, True, True, True, False, False, True, True],
[False, False, False, False, True, False, True, True, True, False],
[ True, True, True, True, True, False, True, False, True, True],
[False, True, True, False, True, True, False, True, False, True],
[ True, True, True, True, False, True, False, True, False, True],
[ True, False, False, True, True, True, False, False, True, True],
[False, True, True, True, False, True, True, False, False, True],
[False, False, True, True, True, False, False, True, False, True],
[False, False, False, False, False, False, False, True, True, False],
[ True, True, True, True, True, True, False, False, False, True],
[False, False, False, True, True, False, False, True, False, False],
[ True, False, True, True, True, False, False, False, True, False],
[False, True, True, False, True, True, False, True, False, True],
[False, False, True, True, False, True, False, True, True, True],
[False, False, True, True, True, True, True, False, True, False],
[False, True, True, False, True, False, False, False, True, False],
[False, False, False, False, True, True, False, True, False, False],
[False, True, False, False, False, False, False, True, False, False],
[False, True, True, True, False, True, True, False, True, True],
[ True, False, True, False, False, True, True, True, False, True],
[False, True, True, True, False, False, True, False, False, True],
[False, False, True, False, False, True, True, False, True, False],
[False, True, False, True, True, True, False, True, False, True],
[ True, True, True, False, True, True, True, False, False, True],
[ True, False, True, True, True, False, True, False, True, False],
[ True, False, False, False, False, False, True, False, False, False],
[ True, False, False, True, False, False, True, True, False, True],
[ True, False, True, True, False, False, False, True, False, True],
[ True, True, True, False, True, True, True, False, True, False],
[ True, True, True, True, False, False, True, True, True, True],
[ True, False, False, True, True, False, True, False, False, False],
[False, False, True, False, True, False, True, True, False, True],
[False, True, False, True, True, True, True, True, False, True],
[False, False, True, False, False, False, True, False, False, False],
[ True, False, False, True, True, True, False, True, False, False],
[ True, False, True, True, True, False, True, False, False, False],
[ True, True, False, False, False, True, True, False, False, True],
[False, True, True, False, False, False, True, False, True, True],
[False, True, True, False, False, True, False, True, False, True],
[False, True, True, True, True, False, False, False, False, True],
[ True, True, True, True, False, False, False, True, False, True],
[False, True, False, True, False, True, True, True, True, True],
[ True, False, False, False, True, False, False, True, False, True],
[False, True, True, True, False, True, False, True, False, False],
[False, True, False, True, True, True, True, True, False, False],
[False, False, False, False, False, True, False, False, True, True],
[False, False, False, False, False, True, True, True, True, True],
[False, False, False, True, True, True, True, False, False, False],
[ True, True, True, True, False, True, True, True, False, True],
[ True, False, False, True, True, False, True, True, True, False],
[ True, False, True, True, False, False, True, False, True, True],
[ True, True, True, True, False, False, True, False, False, True],
[False, True, True, True, True, False, True, False, True, False],
[False, True, True, False, True, True, False, True, False, False],
[False, False, False, False, True, True, False, True, True, False],
[ True, True, False, True, True, False, True, True, False, False],
[ True, True, False, True, False, False, False, True, True, True],
[False, True, True, False, True, False, True, True, False, True],
[ True, False, True, False, True, False, True, True, False, False],
[ True, True, False, False, True, True, False, True, True, False],
[ True, False, True, False, True, True, False, False, True, False],
[False, False, False, False, True, False, False, False, False, True],
[ True, False, False, True, True, True, True, False, True, True],
[ True, False, False, False, True, False, True, False, False, False]])
mask.sum()
tensor(517)
W = torch.randn(A.shape[0], k, requires_grad=True)
H = torch.randn(k, A.shape[1], requires_grad=True)
diff_matrix = torch.mm(W, H)-A
diff_matrix.shape
torch.Size([100, 10])
# Mask the matrix
diff_matrix[mask].shape
torch.Size([517])
# Modify the loss function to ignore NaN values
def factorize(A, k):
    """Factorize the matrix A into W and H, ignoring the NaN entries of A"""
    # Mask of observed (non-NaN) entries
    mask = ~torch.isnan(A)
    # Randomly initialize W and H
    W = torch.randn(A.shape[0], k, requires_grad=True)
    H = torch.randn(k, A.shape[1], requires_grad=True)
    # Optimizer
    optimizer = optim.Adam([W, H], lr=0.01)
    # Train the model
    for i in range(1000):
        # Compute the loss only over the observed entries
        diff_matrix = torch.mm(W, H) - A
        diff_vector = diff_matrix[mask]
        loss = torch.norm(diff_vector)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the parameters
        optimizer.step()
    return W, H, loss
W, H, loss = factorize(A, 5)
loss
tensor(7.1147, grad_fn=<LinalgVectorNormBackward0>)
torch.mm(W, H)
tensor([[-3.7026e+00, 5.1163e+00, 4.9660e+00, -5.8455e-01, 1.0200e+00,
9.1470e-01, 3.8727e+00, 2.2007e+00, 1.3782e+01, 3.9267e+00],
[ 6.4588e-02, 2.0002e+00, 1.9979e+00, 1.0001e+00, 3.6942e+00,
1.1512e+00, 3.8181e+00, 3.5421e+00, 3.0014e+00, 5.0035e+00],
[-3.3540e-01, -1.0050e+00, 5.0000e+00, 1.0010e+00, 1.7959e-01,
8.6454e-01, 1.9994e+00, 2.9560e+00, 1.0767e+01, -3.4057e-01],
[ 2.6321e+00, 3.5523e+00, 2.1706e+00, 9.9769e-01, 3.9995e+00,
2.0030e+00, 2.8652e+00, 2.0006e+00, 3.5503e-02, 3.9991e+00],
[ 5.3195e-02, -1.3399e+00, 1.9455e+00, 2.2700e+00, 3.0539e+00,
1.0086e+00, 3.8842e+00, 5.1128e+00, 4.0248e+00, 3.9763e+00],
[ 5.5045e+00, 2.0044e+00, 4.9388e+00, 1.6216e+00, 2.0865e+00,
2.8957e+00, 9.2770e-01, 8.8714e-01, 3.0420e+00, -9.2438e-01],
[ 6.1366e-01, -1.3394e+00, -6.3292e+00, 4.0202e+00, 2.0028e+00,
1.0014e+00, 2.1030e+00, 3.0001e+00, -1.0639e+01, 4.9979e+00],
[ 1.6036e+00, 3.6518e+00, 3.7664e+00, 1.2319e+00, 5.1780e+00,
2.1295e+00, 4.7475e+00, 4.0635e+00, 4.0645e+00, 5.9223e+00],
[-4.1593e+00, -1.1450e+01, 2.0064e+00, 5.0104e+00, 1.7877e-01,
-6.8356e-01, 4.9576e+00, 1.0525e+01, 1.2403e+01, 2.0250e+00],
[ 1.7769e+00, -4.9137e+00, 3.1228e+00, 2.2575e+00, 5.8793e-01,
6.6961e-01, 1.2031e+00, 3.7871e+00, 4.9702e+00, -1.4083e+00],
[ 9.1273e+00, 1.4212e+00, 5.0018e+00, -1.0804e+00, 8.9608e+00,
1.6892e+00, 2.0130e+00, 1.9936e+00, -7.2410e+00, 2.9931e+00],
[ 3.2326e+00, 5.2245e+00, 3.8534e+00, 1.8050e+00, 2.7570e+00,
3.2246e+00, 2.7488e+00, 1.4501e+00, 3.9743e+00, 2.7026e+00],
[ 9.2220e-01, 6.5451e+00, 9.9373e-01, 4.2199e+00, 4.0030e+00,
4.0019e+00, 5.8552e+00, 4.3514e+00, 3.0028e+00, 8.0414e+00],
[-2.3552e+00, 1.0013e+00, 9.6726e-01, -4.6631e-01, 3.0005e+00,
-4.1134e-01, 3.3274e+00, 3.0678e+00, 3.0946e+00, 4.9978e+00],
[ 4.0009e+00, 5.9036e+00, 1.9983e+00, 6.0610e+00, 1.0184e+00,
5.6167e+00, 3.5338e+00, 2.5504e+00, 4.0003e+00, 3.0002e+00],
[ 4.1839e+00, 9.1531e+00, 2.5008e+00, 3.1248e+00, 3.4309e+00,
4.8701e+00, 3.6236e+00, 9.7109e-01, 1.3077e+00, 4.9725e+00],
[ 5.0034e+00, -2.6358e+00, 2.9996e+00, -3.1773e-01, 2.9981e+00,
3.3085e-01, -3.0125e-01, 1.0002e+00, -2.6015e+00, -1.5107e+00],
[ 2.1562e+00, 5.1078e+00, 2.3091e+00, 4.3035e+00, 2.3000e+00,
4.1470e+00, 4.2497e+00, 3.3351e+00, 4.7425e+00, 4.5406e+00],
[ 2.9999e+00, 2.9999e+00, 3.0624e+00, 4.4992e+00, 4.5262e+00,
3.8824e+00, 5.4591e+00, 5.5541e+00, 4.0005e+00, 5.9015e+00],
[ 8.1177e-01, -2.9699e+00, 1.2606e+00, 3.0946e+00, 3.9370e+00,
1.0488e+00, 4.3280e+00, 6.3317e+00, 1.8619e+00, 4.5940e+00],
[ 2.1477e+00, 2.4168e+00, 5.0014e+00, -3.6888e+00, 5.9741e+00,
-5.6343e-01, 1.9977e+00, 9.9599e-01, 1.0008e+00, 3.2156e+00],
[ 2.5660e+00, 2.4872e+00, 2.9711e+00, 3.0048e+00, 2.5035e+00,
2.9551e+00, 3.2466e+00, 3.1642e+00, 4.0380e+00, 2.8930e+00],
[ 2.0003e+00, 1.1864e+01, -1.4367e+00, 1.9999e+00, 4.6078e-01,
4.1701e+00, 1.3254e+00, -3.0082e+00, -2.7641e+00, 3.9993e+00],
[ 1.0449e+01, 2.0032e+00, 8.9659e+00, 3.6977e+00, 4.8085e+00,
5.2882e+00, 2.4920e+00, 3.2403e+00, 4.9395e+00, -8.7002e-01],
[ 2.0717e+00, 2.1115e+00, 2.0717e+00, 4.2801e+00, -2.0380e+00,
3.5172e+00, 9.3884e-01, 9.7155e-01, 6.2354e+00, -1.3328e+00],
[ 5.2955e+00, 3.4266e+00, 5.1299e+00, 1.4202e+00, 6.8671e+00,
2.9375e+00, 4.3992e+00, 3.9995e+00, 1.2817e+00, 5.0996e+00],
[ 2.0001e+00, 1.4791e+00, 2.3388e+00, 2.0000e+00, -6.6995e-01,
2.1166e+00, 4.5259e-01, 3.8568e-01, 4.1443e+00, -1.1447e+00],
[ 2.9065e+00, 1.0088e+00, 5.0003e+00, -4.4267e-01, 3.6696e+00,
1.0351e+00, 1.9735e+00, 2.0205e+00, 3.8361e+00, 1.3199e+00],
[ 2.3390e+00, 2.3721e+00, 3.7185e+00, 1.3874e+00, 2.8777e+00,
2.1103e+00, 2.7381e+00, 2.4701e+00, 4.1805e+00, 2.4318e+00],
[ 1.7971e+00, -4.8495e+00, 9.1239e-01, 3.9631e+00, 1.2733e+00,
1.2781e+00, 2.2735e+00, 5.0206e+00, 1.8186e+00, 7.7334e-01],
[ 4.4556e+00, 1.2257e+01, 3.8113e+00, 1.3722e+00, 4.8218e+00,
4.8774e+00, 3.8956e+00, -1.3451e-01, 1.5725e+00, 6.1126e+00],
[ 2.3871e+00, 3.0010e+00, 3.9991e+00, 2.8334e+00, 4.8395e+00,
2.9996e+00, 5.1329e+00, 5.0009e+00, 5.0012e+00, 5.6052e+00],
[ 3.9920e+00, 3.0004e+00, 2.0131e+00, 3.0049e+00, 1.3189e+00,
3.2595e+00, 1.5054e+00, 1.0929e+00, 9.9195e-01, 8.6129e-01],
[ 1.3517e+00, 1.3225e+00, 1.2075e+00, 5.6267e+00, 9.9029e-01,
3.6713e+00, 4.0205e+00, 4.6894e+00, 5.0908e+00, 3.3918e+00],
[ 3.9141e+00, 5.4294e+00, 6.3565e+00, 3.9410e-02, 5.9122e+00,
2.6582e+00, 4.0295e+00, 2.6446e+00, 4.8764e+00, 4.5279e+00],
[ 5.9703e+00, 5.0000e+00, 5.0006e+00, 5.4552e+00, 3.4764e+00,
5.7107e+00, 4.4197e+00, 3.9984e+00, 5.4614e+00, 3.3112e+00],
[ 4.9808e+00, 2.5274e+00, 4.0592e+00, 2.0167e+00, 2.9650e+00,
2.9926e+00, 1.9145e+00, 1.7219e+00, 1.9706e+00, 1.0164e+00],
[ 3.3675e+00, 7.9931e-01, 6.4143e+00, -1.9437e+00, 5.0061e+00,
3.9840e-01, 1.9918e+00, 2.0016e+00, 4.0036e+00, 1.4088e+00],
[ 3.1524e+00, 4.8411e+00, 1.3102e+00, 5.0749e+00, 2.6352e+00,
4.5228e+00, 4.2075e+00, 3.4787e+00, 2.0158e+00, 4.7858e+00],
[ 1.1647e+00, 2.0442e+00, 3.9774e+00, 5.2040e+00, 1.0064e+00,
3.9541e+00, 4.5550e+00, 5.0695e+00, 1.0370e+01, 2.9545e+00],
[ 4.9155e+00, 4.8887e+00, 1.8611e+00, 4.5726e+00, 2.4214e+00,
4.6615e+00, 2.9405e+00, 2.1223e+00, 4.9785e-01, 2.9426e+00],
[ 4.6095e+00, 7.0277e+00, 2.3640e+00, 4.9826e+00, 1.3639e+00,
5.4668e+00, 2.9430e+00, 1.3600e+00, 2.9275e+00, 2.6572e+00],
[-9.4300e+00, 3.9400e+00, 1.8115e+00, 4.6790e+00, -7.7135e+00,
2.3401e+00, 3.3902e+00, 2.5588e+00, 2.3759e+01, 7.8572e-01],
[ 8.8489e+00, 1.3677e+00, 4.0011e+00, 4.9984e+00, 5.0002e+00,
4.8811e+00, 3.1127e+00, 4.0020e+00, -1.6249e+00, 1.9996e+00],
[-5.3638e+00, -1.6611e+01, -1.2622e+00, -2.9465e+00, -5.3122e+00,
-6.6081e+00, -4.5960e+00, 1.0021e+00, 3.0004e+00, -8.7048e+00],
[ 2.9505e+00, 3.4944e+00, 1.0966e+00, 2.0104e+00, 1.4645e+00,
2.5452e+00, 1.3437e+00, 4.9608e-01, -3.2980e-01, 1.5523e+00],
[ 8.8610e+00, 3.4098e+00, 5.4534e+00, 4.0010e+00, 5.0039e+00,
5.0527e+00, 2.9819e+00, 2.9969e+00, 4.8445e-01, 1.7580e+00],
[ 4.7645e+00, 7.2438e+00, 5.3527e+00, 2.1598e+00, 3.9257e+00,
4.3363e+00, 3.5150e+00, 1.6544e+00, 4.8212e+00, 3.4858e+00],
[ 3.6301e+00, 8.2968e-01, 1.0952e+00, 5.1852e-01, 2.9273e+00,
1.1501e+00, 7.9266e-01, 7.6517e-01, -3.4503e+00, 1.2149e+00],
[ 1.2349e+00, 4.7494e+00, 1.1553e+00, 2.2902e+00, 2.1858e+00,
2.6806e+00, 3.0526e+00, 1.8136e+00, 1.9800e+00, 4.0943e+00],
[ 5.4207e+00, 4.6174e-01, 4.6315e+00, 1.0934e+00, 3.5323e+00,
2.1171e+00, 1.4216e+00, 1.9254e+00, 1.2497e+00, 1.5768e-01],
[ 8.0981e-01, 1.9982e+00, 1.9823e+00, 2.3342e+00, 4.0156e+00,
1.9755e+00, 4.5111e+00, 4.4963e+00, 3.0107e+00, 5.5846e+00],
[ 3.4230e+00, 6.7811e+00, 1.3001e+00, 3.2712e+00, 5.0004e+00,
4.0008e+00, 4.7148e+00, 2.9985e+00, -8.9564e-01, 6.9733e+00],
[ 4.3992e+00, 1.9999e+00, 1.7722e+00, 3.3742e+00, 3.4675e+00,
3.1595e+00, 2.8761e+00, 3.0001e+00, -8.2256e-01, 3.0558e+00],
[ 6.3852e+00, 1.8676e+00, 5.1417e+00, 2.1083e+00, 5.4183e+00,
3.1672e+00, 3.1282e+00, 3.3945e+00, 1.0754e+00, 2.5551e+00],
[ 8.1366e-01, 2.6100e+00, 1.2431e+00, 4.6053e+00, -7.3305e-01,
3.4299e+00, 2.5706e+00, 2.4739e+00, 5.9303e+00, 1.5624e+00],
[-1.5185e+01, 3.9601e+00, 4.1178e+00, 1.1159e+00, -7.8194e+00,
-2.6153e-01, 4.5528e+00, 3.2349e+00, 3.2557e+01, 2.2662e+00],
[ 4.1539e+00, 6.2482e+00, 3.0004e+00, 2.5108e+00, 4.1870e-01,
4.0017e+00, 9.9885e-01, -7.3851e-01, 2.9998e+00, 1.7461e-01],
[ 2.4292e+00, 1.3324e+00, -2.4186e+00, 3.4744e+00, 2.2288e+00,
2.2009e+00, 2.0940e+00, 2.1399e+00, -5.7555e+00, 3.6885e+00],
[ 2.2950e+00, 2.9436e+00, 3.1712e+00, 4.1268e+00, 5.4157e-01,
3.7216e+00, 2.8408e+00, 2.7181e+00, 6.7460e+00, 1.4747e+00],
[ 5.0664e+00, 4.5652e+01, 2.7328e+00, 1.0374e+00, 1.2891e+00,
1.3143e+01, 3.7184e+00, -1.3608e+01, 3.1697e+00, 1.1363e+01],
[ 5.0006e+00, 6.9239e+00, 9.1918e-01, 7.2628e+00, 1.5445e+00,
6.5242e+00, 4.0004e+00, 2.7856e+00, 1.5080e+00, 4.0918e+00],
[ 2.0015e+00, -1.7211e+00, -2.6718e+00, 1.9970e+00, 3.6355e+00,
4.8944e-01, 2.0002e+00, 3.0038e+00, -7.9397e+00, 3.9988e+00],
[ 2.9915e+00, 1.0068e+01, 4.0126e+00, 5.0140e+00, 2.1576e-01,
6.2162e+00, 3.6692e+00, 9.8564e-01, 9.0477e+00, 2.9998e+00],
[ 3.1704e+00, 3.0173e+00, 6.3333e-01, 6.8363e+00, 2.2725e+00,
4.9317e+00, 4.7652e+00, 5.0099e+00, 2.2206e+00, 4.9007e+00],
[ 7.5330e-01, 1.4262e+00, 3.0599e+00, 8.5747e-01, 4.0712e+00,
1.1776e+00, 3.8169e+00, 3.8298e+00, 3.8094e+00, 4.5052e+00],
[ 4.0428e+00, -2.0650e+01, 2.4385e+00, 4.9425e+00, 1.9279e+00,
-1.5739e+00, 1.0876e+00, 1.0297e+01, 9.8782e-01, -4.4288e+00],
[ 1.0773e+00, 2.2150e+00, 2.0096e+00, -2.4710e+00, 4.9950e+00,
-5.0779e-01, 1.9650e+00, 1.0100e+00, -1.8899e+00, 4.0205e+00],
[ 3.4996e+00, 4.5594e+00, -2.5208e+00, 3.3402e+00, 2.6908e+00,
3.0928e+00, 2.0393e+00, 8.1359e-01, -7.3751e+00, 4.3822e+00],
[-4.8442e+00, 3.4518e+00, 9.9972e-01, -1.2881e+00, -3.5755e-01,
-5.3701e-01, 1.9997e+00, 4.9337e-01, 7.5069e+00, 3.0513e+00],
[ 1.0026e+00, 7.6105e+00, -7.1104e+00, 3.9509e+00, 4.9733e+00,
3.0444e+00, 4.5575e+00, 2.0356e+00, -1.3359e+01, 1.0958e+01],
[ 1.9708e+00, 2.2459e+01, 5.0115e+00, 1.0357e+00, 3.0386e+00,
7.0322e+00, 4.9501e+00, -3.0939e+00, 8.5776e+00, 8.3022e+00],
[ 3.0000e+00, 5.0024e+00, -1.4123e+00, 6.3346e+00, -9.3349e-01,
4.9944e+00, 2.0182e+00, 1.1012e+00, -1.8768e-01, 1.9884e+00],
[ 4.4226e+00, 2.0069e+00, 5.0098e+00, -8.1680e-01, 5.0611e+00,
1.2995e+00, 1.9966e+00, 1.6425e+00, 9.9608e-01, 1.9906e+00],
[ 1.1858e+01, 3.1806e+00, 1.9399e+00, 2.0242e+00, 9.7272e+00,
3.8295e+00, 2.5867e+00, 2.2628e+00, -1.4185e+01, 4.8150e+00],
[ 3.0139e+00, 4.0015e+00, 4.9996e+00, 2.0000e+00, 5.0004e+00,
3.0423e+00, 4.6969e+00, 4.0977e+00, 5.4217e+00, 4.9981e+00],
[ 3.1070e+00, 1.5608e+00, 4.4961e+00, 2.4194e+00, 3.9092e+00,
2.6510e+00, 3.8237e+00, 4.1656e+00, 5.0091e+00, 3.2677e+00],
[ 3.6123e+00, 5.0130e+00, 1.2513e+00, 5.0780e+00, -6.2890e-01,
4.7738e+00, 1.7050e+00, 7.5555e-01, 2.9369e+00, 7.0034e-01],
[ 1.9988e+00, -1.0755e+00, 2.1010e+00, 4.8154e-01, 5.0034e+00,
4.1238e-01, 3.1198e+00, 4.0017e+00, -6.4504e-01, 3.9979e+00],
[-2.4511e+00, 3.0049e+00, 4.0049e+00, 2.0168e+00, -1.7567e+00,
1.9773e+00, 2.5612e+00, 1.9963e+00, 1.3746e+01, 7.7954e-01],
[ 1.2900e+00, 4.8906e+00, 1.2285e+01, 2.0178e+00, 2.0017e+00,
4.0381e+00, 5.2765e+00, 4.7427e+00, 2.3263e+01, 1.4459e+00],
[ 3.4409e+00, 1.6323e+00, 2.1447e+00, 7.1434e+00, 2.1349e+00,
4.9995e+00, 4.9455e+00, 5.8905e+00, 5.0000e+00, 4.0006e+00],
[-8.3402e+00, 1.6671e+00, -9.7906e+00, 7.5550e+00, -7.1860e+00,
1.9594e+00, 2.1880e+00, 1.9448e+00, 1.9841e+00, 3.9175e+00],
[ 6.1855e+00, 1.0277e+00, 6.0344e+00, 2.0010e+00, 5.0009e+00,
2.9985e+00, 3.0008e+00, 3.6213e+00, 3.0914e+00, 1.7165e+00],
[ 9.9264e-01, 2.2125e+00, 3.9075e+00, 1.8425e+00, 4.0571e+00,
2.0497e+00, 4.5854e+00, 4.5379e+00, 6.0676e+00, 4.9087e+00],
[ 4.1464e+00, 1.4062e+01, 2.4284e+00, 4.7693e+00, 3.6919e+00,
6.9404e+00, 5.3224e+00, 1.0755e+00, 2.9501e+00, 7.6420e+00],
[ 3.7800e+00, 1.9199e+01, 1.3223e+00, 4.2259e+00, -1.6001e+00,
7.9395e+00, 1.7065e+00, -4.9606e+00, 3.8810e+00, 3.1246e+00],
[ 4.0234e+00, 2.0156e+00, 9.3090e-01, 2.9352e+00, 1.4458e+00,
2.8364e+00, 1.1944e+00, 1.0778e+00, -1.3489e+00, 8.8513e-01],
[ 2.9993e-01, 1.0106e+00, 2.8743e+00, 1.0596e+00, 2.1897e+00,
1.1528e+00, 2.7986e+00, 2.9140e+00, 5.0939e+00, 2.6367e+00],
[-2.7522e+00, 1.0011e+00, 9.9939e-01, 2.1638e+00, 9.9803e-01,
9.9864e-01, 3.8559e+00, 4.0029e+00, 6.9205e+00, 4.3317e+00],
[ 4.8687e+00, -1.5672e-01, 1.6703e+00, 7.5356e+00, 9.9502e-01,
4.9990e+00, 3.3880e+00, 5.0098e+00, 2.9938e+00, 1.4576e+00],
[ 3.0192e+00, 3.9819e+00, -5.9198e+00, 3.9832e+00, 2.9753e+00,
2.7502e+00, 2.0750e+00, 9.5908e-01, -1.2959e+01, 5.9251e+00],
[ 1.0148e+00, 9.7130e-01, 1.9768e+00, 2.0122e+00, 6.6718e-01,
1.7201e+00, 1.7755e+00, 1.9466e+00, 4.0138e+00, 1.0400e+00],
[ 5.3930e+00, 4.0298e+00, 2.0109e+00, 2.7548e+00, 2.9962e+00,
3.5641e+00, 1.8677e+00, 1.0778e+00, -1.7539e+00, 2.0255e+00],
[ 1.9953e+00, 1.2843e+01, 5.0056e+00, 2.5516e+00, 2.9983e+00,
5.4422e+00, 5.0090e+00, 9.8768e-01, 8.9960e+00, 6.2394e+00],
[ 2.8324e+00, 2.8434e+00, 1.9018e+00, 3.8000e+00, 3.1922e+00,
3.3458e+00, 3.8883e+00, 3.7678e+00, 2.0126e+00, 4.2686e+00],
[ 2.0652e+00, 8.7485e+00, 4.9003e+00, -7.1472e-02, 5.0456e+00,
2.9448e+00, 4.2957e+00, 1.4602e+00, 5.0622e+00, 6.0046e+00],
[ 1.2368e+00, -8.0124e-01, 4.5406e-01, 5.8132e-01, 5.0025e+00,
1.8401e-01, 3.2316e+00, 3.9075e+00, -2.7376e+00, 4.9980e+00],
[ 4.4707e+00, 1.3100e+01, 3.8913e+00, 5.0670e-01, 3.1494e+00,
4.7096e+00, 2.2617e+00, -2.3287e+00, 1.8589e+00, 3.9039e+00],
[ 9.9929e-01, 1.5253e+01, 2.1834e-03, 2.6820e+00, 2.0005e+00,
5.3245e+00, 3.9995e+00, -1.3442e+00, 1.4326e+00, 7.5636e+00]],
grad_fn=<MmBackward0>)
# Now use matrix factorization to predict the ratings
import torch
import torch.nn as nn
import torch.nn.functional as F
# Create a class for the model
class MatrixFactorization(nn.Module):
    def __init__(self, n_users, n_movies, n_factors=20):
        super().__init__()
        self.user_factors = nn.Embedding(n_users, n_factors)
        self.movie_factors = nn.Embedding(n_movies, n_factors)

    def forward(self, user, movie):
        return (self.user_factors(user) * self.movie_factors(movie)).sum(1)
model = MatrixFactorization(n_users, n_movies, 2)
model
MatrixFactorization(
(user_factors): Embedding(100, 2)
(movie_factors): Embedding(10, 2)
)
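This embedding formulation is the same factorization in disguise: for user $u$ and movie $m$ the model predicts

$$\hat{r}_{um} = \mathbf{w}_u^\top \mathbf{h}_m,$$

where $\mathbf{w}_u$ and $\mathbf{h}_m$ are the corresponding rows of the two embedding tables, so stacking all predictions recovers the product $WH$.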
model(torch.tensor([0]), torch.tensor([2]))
tensor([-0.0271], grad_fn=<SumBackward1>)
A[0, 2]
tensor(5.)
type(A)
torch.Tensor
mask = ~torch.isnan(A)
# Get the indices of the non-NaN values
i, j = torch.where(mask)
# Get the values of the non-NaN values
v = A[mask]
# Store in PyTorch tensors
users = i.to(torch.int64)
movies = j.to(torch.int64)
ratings = v.to(torch.float32)
pd.DataFrame({'user': users, 'movie': movies, 'rating': ratings})
|   | user | movie | rating |
|---|---|---|---|
| 0 | 0 | 1 | 5.0 |
| 1 | 0 | 2 | 5.0 |
| 2 | 0 | 4 | 1.0 |
| 3 | 0 | 5 | 1.0 |
| 4 | 0 | 6 | 4.0 |
| ... | ... | ... | ... |
| 512 | 98 | 8 | 2.0 |
| 513 | 98 | 9 | 4.0 |
| 514 | 99 | 0 | 1.0 |
| 515 | 99 | 4 | 2.0 |
| 516 | 99 | 6 | 4.0 |
517 rows × 3 columns
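Before fitting, one could hold out a fraction of the observed ratings to monitor generalisation rather than only the training loss (a sketch; this split is not used in the rest of the notebook and the 80/20 ratio is arbitrary):

# Random 80/20 split of the observed (user, movie, rating) triplets
n_obs = ratings.shape[0]
perm = torch.randperm(n_obs)
n_train = int(0.8 * n_obs)
train_idx, val_idx = perm[:n_train], perm[n_train:]
train_users, val_users = users[train_idx], users[val_idx]
train_movies, val_movies = movies[train_idx], movies[val_idx]
train_ratings, val_ratings = ratings[train_idx], ratings[val_idx]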
# Fit the Matrix Factorization model
model = MatrixFactorization(n_users, n_movies, 4)
optimizer = optim.Adam(model.parameters(), lr=0.01)
for i in range(1000):
    # Compute the loss
    pred = model(users, movies)
    loss = F.mse_loss(pred, ratings)
    # Zero the gradients
    optimizer.zero_grad()
    # Backpropagate
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Print the loss
    if i % 100 == 0:
        print(loss.item())
14.604362487792969
4.332712650299072
1.0960761308670044
0.6966323852539062
0.5388827919960022
0.45243579149246216
0.4012693464756012
0.3728969395160675
0.35568001866340637
0.34289655089378357
model(users, movies)
tensor([3.5693, 4.5338, 2.6934, 1.8316, 4.8915, 2.0194, 2.7778, 1.8601, 2.1124,
1.1378, 2.9079, 5.0470, 0.9911, 1.9791, 1.0050, 3.9618, 2.0085, 2.0034,
4.0113, 1.9218, 2.9801, 1.0432, 3.8993, 4.1292, 1.9357, 3.8285, 3.5266,
1.3640, 2.3989, 2.4166, 0.8559, 3.3685, 4.4493, 3.4018, 1.4722, 4.8378,
4.6684, 4.4473, 4.3097, 2.0022, 5.0147, 5.0113, 1.9599, 2.8305, 1.4493,
1.6750, 0.9520, 3.8460, 5.1279, 4.7453, 2.1484, 1.8009, 3.2104, 4.2068,
4.5473, 2.9229, 1.1817, 3.1108, 3.2157, 1.2238, 3.6272, 3.9029, 3.2554,
0.9945, 3.0062, 5.0030, 4.1144, 1.6314, 3.7945, 3.3091, 4.0727, 3.6212,
2.4359, 3.5707, 1.2826, 0.9663, 4.9973, 3.0163, 2.9916, 1.0014, 3.1734,
3.8712, 4.3364, 3.7119, 4.5313, 2.3875, 4.0274, 4.7121, 4.3851, 2.8072,
3.2066, 3.9684, 0.9307, 1.4160, 2.7484, 2.8771, 1.2753, 2.5825, 4.9857,
2.0109, 1.0080, 2.2618, 2.7936, 2.5859, 3.0972, 3.1443, 3.8655, 3.2023,
2.0061, 1.9866, 4.0068, 2.7861, 4.2803, 4.5452, 3.8177, 2.9951, 5.6040,
1.9013, 2.4498, 1.9155, 4.3239, 3.2993, 1.1275, 4.1460, 4.3619, 4.5913,
1.1365, 3.0865, 6.0992, 3.9207, 1.8325, 3.8526, 2.0000, 2.0000, 1.0255,
5.0484, 1.9858, 1.9716, 2.0854, 2.7726, 3.5053, 1.7365, 3.2607, 4.7006,
1.8830, 0.3378, 2.6222, 0.6491, 3.1144, 3.8404, 2.0632, 4.3311, 4.6138,
1.6119, 3.0444, 2.4498, 2.7860, 4.0832, 3.2681, 5.0069, 4.8745, 3.6048,
3.5065, 1.8871, 2.9108, 0.9982, 0.6787, 1.4787, 5.6890, 1.5117, 4.0205,
4.6624, 4.1051, 4.5437, 3.5410, 1.8761, 4.6653, 5.2577, 5.0311, 4.9531,
4.0022, 4.7229, 4.0406, 2.5056, 3.2014, 2.9312, 1.3051, 1.3258, 4.9233,
2.0871, 1.8761, 4.1304, 3.9279, 4.5547, 1.6639, 4.3163, 1.8149, 4.7431,
3.4398, 3.5483, 3.2541, 2.8282, 1.5230, 2.8214, 4.9214, 3.5371, 4.4522,
4.7928, 2.4932, 4.8619, 4.7693, 2.3121, 2.3021, 4.6609, 4.7478, 1.7768,
5.3910, 2.8929, 2.4920, 4.1654, 1.8382, 3.4456, 3.7108, 2.3989, 1.0480,
4.2870, 5.2730, 3.8976, 3.2956, 3.1010, 0.9666, 3.0227, 2.9491, 3.2218,
1.4033, 1.7050, 0.5457, 3.1223, 2.1794, 3.9939, 5.0190, 2.9944, 4.5670,
5.4131, 2.4720, 4.0239, 4.5689, 0.9386, 1.1827, 2.1553, 1.4232, 0.4127,
1.7502, 1.5601, 2.0953, 2.3530, 2.7085, 2.7157, 2.7900, 4.6793, 1.1711,
3.4493, 1.7701, 2.1576, 0.8294, 1.9908, 2.1975, 3.6396, 3.1952, 4.9882,
4.0091, 2.9998, 1.9985, 3.0063, 3.7149, 4.2018, 1.1992, 2.2550, 3.5132,
1.3794, 2.4434, 2.0220, 0.6691, 1.9927, 3.4666, 2.8342, 1.0393, 4.1834,
4.0819, 0.8652, 4.8585, 2.0225, 2.9662, 3.9794, 1.0412, 3.0180, 1.0117,
3.1728, 2.6027, 2.5184, 2.5120, 3.3483, 3.0602, 3.2139, 1.9144, 1.4899,
2.8253, 2.9160, 1.6410, 3.8484, 2.9637, 2.1272, 2.0490, 4.4915, 1.6853,
5.0641, 3.9315, 1.8812, 2.1028, 2.1203, 2.8898, 4.0215, 3.3032, 3.6131,
4.8380, 0.7513, 3.4182, 3.0527, 3.6200, 1.3040, 1.1657, 4.2705, 4.8569,
2.7037, 0.9142, 1.5261, 3.3970, 0.6980, 3.1131, 3.8348, 3.4478, 4.9488,
4.0077, 4.9673, 2.0816, 0.9995, 3.4965, 3.2899, 1.0624, 1.4977, 4.3810,
4.2225, 3.1540, 4.0191, 3.4060, 2.0160, 1.2814, 2.9608, 1.1513, 1.9530,
0.9872, 3.9914, 4.9865, 3.0477, 1.9701, 2.5473, 4.6173, 0.5252, 3.4813,
4.7750, 3.5971, 4.1422, 5.2524, 2.0691, 2.0039, 2.1420, 5.0156, 1.8481,
0.7781, 2.2183, 2.6443, 2.2889, 4.3373, 2.2449, 4.5594, 3.8624, 5.2333,
2.1537, 4.7373, 5.0298, 2.1322, 2.7047, 4.2308, 2.4907, 2.3259, 4.7678,
5.0266, 5.2597, 4.7697, 0.9911, 1.0440, 2.7818, 1.1585, 2.1967, 4.4422,
3.4151, 4.8498, 2.8759, 4.0001, 1.8111, 2.3611, 1.9162, 4.3893, 2.4916,
2.0815, 3.9944, 5.5013, 4.6856, 4.9904, 4.9865, 4.0133, 1.7178, 2.2895,
1.9270, 2.6042, 3.5529, 2.0969, 4.9558, 2.9138, 3.0347, 1.6650, 2.2754,
3.7474, 0.9570, 2.3175, 4.1652, 4.4857, 5.1715, 5.4181, 3.7923, 4.0080,
3.5960, 2.1456, 2.9088, 3.2796, 1.2589, 4.6664, 2.3447, 3.6577, 2.8478,
2.8383, 2.9852, 0.9179, 3.0814, 1.2291, 0.7621, 0.8185, 2.3877, 1.1364,
2.8636, 3.2811, 4.5790, 0.7829, 1.2192, 0.8060, 1.2341, 3.9681, 1.1422,
4.9521, 5.1383, 2.8294, 3.6377, 3.8734, 3.5638, 2.9749, 1.5051, 1.4258,
0.5846, 1.0630, 2.7671, 1.7912, 2.8734, 1.8916, 3.9858, 2.5439, 2.6786,
1.6704, 1.3576, 1.7701, 2.0514, 5.1632, 2.6842, 4.8280, 1.1577, 2.9727,
3.3256, 2.7693, 2.6430, 4.0140, 2.2755, 2.1973, 4.6792, 5.3522, 2.8058,
4.9664, 5.2983, 4.7749, 3.8326, 2.4944, 4.2989, 2.8784, 2.0605, 2.8428,
3.1081, 1.0411, 1.9078, 3.9594], grad_fn=<SumBackward1>)
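A common refinement of this model, sketched here but not used below, adds per-user and per-movie bias terms on top of the dot product (the class name and structure are illustrative); it could be trained with the exact same loop and `F.mse_loss` objective.

class MatrixFactorizationWithBias(nn.Module):
    def __init__(self, n_users, n_movies, n_factors=20):
        super().__init__()
        self.user_factors = nn.Embedding(n_users, n_factors)
        self.movie_factors = nn.Embedding(n_movies, n_factors)
        # One scalar bias per user and per movie
        self.user_bias = nn.Embedding(n_users, 1)
        self.movie_bias = nn.Embedding(n_movies, 1)

    def forward(self, user, movie):
        dot = (self.user_factors(user) * self.movie_factors(movie)).sum(1)
        return dot + self.user_bias(user).squeeze(1) + self.movie_bias(movie).squeeze(1)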
# Now, let's predict the ratings for our df dataframe
A = torch.from_numpy(df.values)
A.shape
torch.Size([45, 10])
mask = ~torch.isnan(A)
# Get the indices of the non-NaN values
i, j = torch.where(mask)
# Get the values of the non-NaN values
v = A[mask]
# Store in PyTorch tensors
users = i.to(torch.int64)
movies = j.to(torch.int64)
ratings = v.to(torch.float32)
pd.DataFrame({'user': users, 'movie': movies, 'rating': ratings})
|   | user | movie | rating |
|---|---|---|---|
| 0 | 0 | 0 | 4.0 |
| 1 | 0 | 1 | 5.0 |
| 2 | 0 | 2 | 4.0 |
| 3 | 0 | 3 | 4.0 |
| 4 | 0 | 4 | 5.0 |
| ... | ... | ... | ... |
| 371 | 44 | 3 | 4.0 |
| 372 | 44 | 4 | 4.0 |
| 373 | 44 | 5 | 4.0 |
| 374 | 44 | 6 | 4.0 |
| 375 | 44 | 7 | 5.0 |
376 rows × 3 columns
# Fit the Matrix Factorization model
n_users = A.shape[0]
n_movies = A.shape[1]
model = MatrixFactorization(n_users, n_movies, 4)
optimizer = optim.Adam(model.parameters(), lr=0.01)
for i in range(1000):
    # Compute the loss
    pred = model(users, movies)
    loss = F.mse_loss(pred, ratings)
    # Zero the gradients
    optimizer.zero_grad()
    # Backpropagate
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Print the loss
    if i % 100 == 0:
        print(loss.item())
19.889324188232422
3.1148574352264404
0.6727441549301147
0.5543633103370667
0.5081750750541687
0.4629250764846802
0.4147825837135315
0.36878159642219543
0.32987719774246216
0.29975879192352295
# Now, let us predict the ratings for any user and movie from df for which we already have the ratings
username = 'Dhruv'
movie = 'The Dark Knight'
# Get the user and movie indices
user_idx = df.index.get_loc(username)
movie_idx = df.columns.get_loc(movie)
# Predict the rating
pred = model(torch.tensor([user_idx]), torch.tensor([movie_idx]))
pred.item(), df.loc[username, movie]
(5.259384632110596, 5.0)
df.loc[username]
Sholay NaN
Swades (We The People) NaN
The Matrix (I) 5.0
Interstellar 5.0
Dangal 3.0
Taare Zameen Par NaN
Shawshank Redemption 5.0
The Dark Knight 5.0
Notting Hill 4.0
Uri: The Surgical Strike 5.0
Name: Dhruv, dtype: float64
# Now, let us predict the ratings for any user and movie from df for which we do not have the ratings
username = 'Dhruv'
movie = 'Sholay'
# Get the user and movie indices
user_idx = df.index.get_loc(username)
movie_idx = df.columns.get_loc(movie)
# Predict the rating
pred = model(torch.tensor([user_idx]), torch.tensor([movie_idx]))
pred, df.loc[username, movie]
(tensor([3.7885], grad_fn=<SumBackward1>), nan)
# Complete the matrix
with torch.no_grad():
    completed_matrix = pd.DataFrame(model.user_factors.weight @ model.movie_factors.weight.t(),
                                    index=df.index, columns=df.columns)
# round to nearest integer
completed_matrix = completed_matrix.round()
completed_matrix.head()
|   | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
|---|---|---|---|---|---|---|---|---|---|---|
| Your name | ||||||||||
| Nipun | 5.0 | 4.0 | 4.0 | 5.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 |
| Gautam Vashishtha | 3.0 | 3.0 | 4.0 | 4.0 | 2.0 | 3.0 | 4.0 | 5.0 | 4.0 | 3.0 |
| Eshan Gujarathi | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 4.0 |
| Sai Krishna Avula | 4.0 | 4.0 | 3.0 | 4.0 | 4.0 | 6.0 | 4.0 | 3.0 | 3.0 | 4.0 |
| Ankit Yadav | 3.0 | 2.0 | 3.0 | 4.0 | 3.0 | 5.0 | 4.0 | 3.0 | 3.0 | 4.0 |
df.head()
|   | Sholay | Swades (We The People) | The Matrix (I) | Interstellar | Dangal | Taare Zameen Par | Shawshank Redemption | The Dark Knight | Notting Hill | Uri: The Surgical Strike |
|---|---|---|---|---|---|---|---|---|---|---|
| Your name | ||||||||||
| Nipun | 4.0 | 5.0 | 4.0 | 4.0 | 5.0 | 5.0 | 4.0 | 5.0 | 4.0 | 5.0 |
| Gautam Vashishtha | 3.0 | 4.0 | 4.0 | 5.0 | 3.0 | 1.0 | 5.0 | 5.0 | 4.0 | 3.0 |
| Eshan Gujarathi | 4.0 | NaN | 5.0 | 5.0 | 4.0 | 5.0 | 5.0 | 5.0 | NaN | 4.0 |
| Sai Krishna Avula | 5.0 | 3.0 | 3.0 | 4.0 | 4.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
| Ankit Yadav | 3.0 | 3.0 | 2.0 | 5.0 | 2.0 | 5.0 | 5.0 | 3.0 | 3.0 | 4.0 |
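As a final sanity check (a sketch, assuming the variables defined above), the completed matrix can be compared with the original survey on the entries that were actually observed:

# RMSE between the rounded completed matrix and the survey, on observed entries only
observed = df.notna()
squared_errors = (completed_matrix[observed] - df[observed]) ** 2
rmse = float(np.sqrt(np.nanmean(squared_errors.values)))
rmse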