Movie Recommendation using KNN and Matrix Factorization
Author: Nipun Batra

Published: April 14, 2023

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Generate some toy user and movie data

# Number of users
n_users = 100

# Number of movies
n_movies = 10

# Number of ratings
n_ratings = 1000

# Generate random user ids
user_ids = np.random.randint(0, n_users, n_ratings)

# Generate random movie ids
movie_ids = np.random.randint(0, n_movies, n_ratings)

# Generate random ratings
ratings = np.random.randint(1, 6, n_ratings)

# Create a dataframe with the data
df = pd.DataFrame({'user_id': user_ids, 'movie_id': movie_ids, 'rating': ratings})

# We should not have any duplicate ratings for the same user and movie
# Drop any rows that have duplicate user_id and movie_id pairs
df = df.drop_duplicates(['user_id', 'movie_id'])
df
user_id movie_id rating
0 66 7 2
1 54 7 4
2 38 5 3
3 56 6 1
4 4 0 4
... ... ... ...
987 77 8 3
992 99 3 3
994 8 5 3
998 22 2 3
999 88 9 1

642 rows × 3 columns
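If we wanted to keep the information in duplicate ratings instead of discarding it, one alternative (not used here, just a sketch) is to average the ratings per (user_id, movie_id) pair:

# Alternative (sketch): average duplicate ratings instead of dropping them
df_avg = df.groupby(['user_id', 'movie_id'], as_index=False)['rating'].mean()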

# Create a user-item matrix

A = df.pivot(index='user_id', columns='movie_id', values='rating')
A
movie_id 0 1 2 3 4 5 6 7 8 9
user_id
0 3.0 4.0 NaN 5.0 5.0 1.0 2.0 5.0 2.0 4.0
1 4.0 2.0 NaN 1.0 NaN 3.0 3.0 5.0 3.0 1.0
2 3.0 NaN NaN 4.0 1.0 5.0 NaN 2.0 NaN 2.0
3 4.0 NaN 4.0 2.0 NaN 3.0 NaN 2.0 NaN 4.0
4 4.0 3.0 3.0 3.0 NaN NaN 4.0 NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ...
95 1.0 3.0 1.0 2.0 3.0 NaN 1.0 NaN 5.0 NaN
96 1.0 2.0 NaN NaN NaN NaN 3.0 2.0 2.0 5.0
97 3.0 1.0 4.0 NaN 3.0 1.0 NaN NaN 5.0 NaN
98 5.0 3.0 NaN NaN 2.0 NaN 1.0 3.0 4.0 3.0
99 1.0 5.0 2.0 3.0 NaN 2.0 4.0 3.0 3.0 NaN

100 rows × 10 columns

# Fill in the missing values with zeros
A = A.fillna(0)

A
movie_id 0 1 2 3 4 5 6 7 8 9
user_id
0 3.0 4.0 0.0 5.0 5.0 1.0 2.0 5.0 2.0 4.0
1 4.0 2.0 0.0 1.0 0.0 3.0 3.0 5.0 3.0 1.0
2 3.0 0.0 0.0 4.0 1.0 5.0 0.0 2.0 0.0 2.0
3 4.0 0.0 4.0 2.0 0.0 3.0 0.0 2.0 0.0 4.0
4 4.0 3.0 3.0 3.0 0.0 0.0 4.0 0.0 0.0 0.0
... ... ... ... ... ... ... ... ... ... ...
95 1.0 3.0 1.0 2.0 3.0 0.0 1.0 0.0 5.0 0.0
96 1.0 2.0 0.0 0.0 0.0 0.0 3.0 2.0 2.0 5.0
97 3.0 1.0 4.0 0.0 3.0 1.0 0.0 0.0 5.0 0.0
98 5.0 3.0 0.0 0.0 2.0 0.0 1.0 3.0 4.0 3.0
99 1.0 5.0 2.0 3.0 0.0 2.0 4.0 3.0 3.0 0.0

100 rows × 10 columns

# Cosine similarity between the first two users, computed by hand

# User with id 0
u1 = A.loc[0]

# User with id 1
u2 = A.loc[1]

# Compute the dot product
dot = np.dot(u1, u2)

# Compute the L2 norm
norm_u1 = np.linalg.norm(u1)
norm_u2 = np.linalg.norm(u2)

# Compute the cosine similarity
cos_sim = dot / (norm_u1 * norm_u2)
cos_sim
0.7174278379758501
# Calculate the cosine similarity between users
from sklearn.metrics.pairwise import cosine_similarity

sim_matrix = cosine_similarity(A)

pd.DataFrame(sim_matrix)
0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
0 1.000000 0.717428 0.663734 0.565794 0.547289 0.742468 0.539472 0.599021 0.748964 0.523832 ... 0.876593 0.593722 0.615664 0.722496 0.783190 0.657754 0.665375 0.446627 0.774667 0.703313
1 0.717428 1.000000 0.650769 0.591169 0.559964 0.585314 0.300491 0.255039 0.693395 0.500193 ... 0.609392 0.313893 0.510454 0.550459 0.624622 0.493197 0.644346 0.476288 0.802740 0.781611
2 0.663734 0.650769 1.000000 0.758954 0.406780 0.372046 0.416654 0.270593 0.746685 0.560180 ... 0.519274 0.322243 0.772529 0.253842 0.712485 0.257761 0.322830 0.283373 0.441886 0.459929
3 0.565794 0.591169 0.758954 1.000000 0.549030 0.540128 0.671775 0.572892 0.611794 0.489225 ... 0.424052 0.586110 0.456327 0.506719 0.715831 0.210494 0.506585 0.492312 0.551652 0.424052
4 0.547289 0.559964 0.406780 0.549030 1.000000 0.354329 0.304478 0.541185 0.821353 0.202287 ... 0.667638 0.527306 0.339913 0.435159 0.582943 0.478699 0.417780 0.450063 0.502836 0.741820
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
95 0.657754 0.493197 0.257761 0.210494 0.478699 0.519615 0.452602 0.571548 0.551553 0.405674 ... 0.564076 0.556890 0.604211 0.682793 0.661382 1.000000 0.412568 0.796715 0.678637 0.693008
96 0.665375 0.644346 0.322830 0.506585 0.417780 0.853538 0.359095 0.538977 0.351369 0.191776 ... 0.498686 0.640033 0.190421 0.690704 0.682163 0.412568 1.000000 0.280141 0.734105 0.581800
97 0.446627 0.476288 0.283373 0.492312 0.450063 0.400743 0.709211 0.532239 0.469979 0.566223 ... 0.262641 0.446564 0.501441 0.689500 0.637007 0.796715 0.280141 1.000000 0.659366 0.481508
98 0.774667 0.802740 0.441886 0.551652 0.502836 0.844146 0.446610 0.473016 0.617575 0.335738 ... 0.546861 0.500390 0.305585 0.673754 0.687116 0.678637 0.734105 0.659366 1.000000 0.600213
99 0.703313 0.781611 0.459929 0.424052 0.741820 0.527274 0.294579 0.578997 0.732042 0.435869 ... 0.818182 0.615435 0.581559 0.645439 0.623673 0.693008 0.581800 0.481508 0.600213 1.000000

100 rows × 100 columns

import seaborn as sns

sns.heatmap(sim_matrix, cmap='Greys')
(Figure: heatmap of the 100 × 100 user-user cosine-similarity matrix)

# Find the most similar users to user u 

def k_nearest_neighbors(A, u, k):
    """Find the k nearest neighbors for user u"""
    # Find the index of the user in the matrix
    u_index = A.index.get_loc(u)
    
    # Compute the similarity between the user and all other users
    sim_matrix = cosine_similarity(A)

    # Find the k most similar users
    k_nearest = np.argsort(sim_matrix[u_index])[::-1][1:k+1]
    
    # Return the user ids
    return A.index[k_nearest]
k_nearest_neighbors(A, 0, 5)
Int64Index([28, 46, 90, 32, 87], dtype='int64', name='user_id')
# Show matrix of movie ratings for u and k nearest neighbors

def show_neighbors(A, u, k):
    """Show the movie ratings for user u and k nearest neighbors"""
    # Get the user ids of the k nearest neighbors
    neighbors = k_nearest_neighbors(A, u, k)
    
    # Get the movie ratings for user u and the k nearest neighbors
    df = A.loc[[u] + list(neighbors)]
    
    # Return the dataframe
    return df
show_neighbors(A, 0, 5)
movie_id 0 1 2 3 4 5 6 7 8 9
user_id
0 3.0 4.0 0.0 5.0 5.0 1.0 2.0 5.0 2.0 4.0
28 5.0 0.0 0.0 5.0 4.0 2.0 1.0 4.0 0.0 5.0
46 3.0 0.0 2.0 5.0 5.0 1.0 0.0 4.0 2.0 2.0
90 1.0 5.0 1.0 5.0 2.0 0.0 2.0 4.0 0.0 1.0
32 3.0 2.0 2.0 5.0 4.0 5.0 3.0 5.0 0.0 3.0
87 1.0 0.0 0.0 5.0 4.0 0.0 3.0 4.0 4.0 2.0
# Predict a rating by averaging the neighbors' ratings for that movie, discarding 0s (unrated).
# For example, if the neighbors' ratings were 4.0, 3.0, 0, 0, 0, the prediction is (4.0 + 3.0) / 2 = 3.5

def predict_rating(A, u, m, k=5):
    """Predict the rating for user u for movie m"""
    # Get the user ids of the k nearest neighbors
    neighbors = k_nearest_neighbors(A, u, k)
    
    # Get the movie ratings for user u and the k nearest neighbors
    df = A.loc[[u] + list(neighbors)]
    
    # Get the neighbors' ratings for movie m (drop user u's own row)
    ratings = df[m].iloc[1:]
    
    # Average the neighbors' ratings, discarding 0s (movies they have not rated)
    mean = ratings[ratings != 0].mean()
    
    # Return the mean
    return mean
predict_rating(A, 0, 5)
2.6666666666666665
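Since predict_rating scores one movie at a time, recommending movies for a user is just a loop over the movies that user has not rated. A minimal sketch (the helper recommend_movies is ours, not part of the original notebook):

def recommend_movies(A, u, k=5, top_n=3):
    """Recommend the top_n unrated movies for user u using the k-NN predictor (sketch)."""
    # Movies the user has not rated (0 after fillna)
    unrated = A.columns[A.loc[u] == 0]
    # Predict a rating for each unrated movie and keep the highest ones
    preds = {m: predict_rating(A, u, m, k) for m in unrated}
    return pd.Series(preds).sort_values(ascending=False).head(top_n)

recommend_movies(A, 0)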
# Now working with real data

# Load the data

df = pd.read_excel("mov-rec.xlsx")
df.head()
Timestamp Your name Sholay Swades (We The People) The Matrix (I) Interstellar Dangal Taare Zameen Par Shawshank Redemption The Dark Knight Notting Hill Uri: The Surgical Strike
0 2023-04-11 10:58:44.990 Nipun 4.0 5.0 4.0 4.0 5.0 5.0 4.0 5.0 4.0 5.0
1 2023-04-11 10:59:49.617 Gautam Vashishtha 3.0 4.0 4.0 5.0 3.0 1.0 5.0 5.0 4.0 3.0
2 2023-04-11 11:12:44.033 Eshan Gujarathi 4.0 NaN 5.0 5.0 4.0 5.0 5.0 5.0 NaN 4.0
3 2023-04-11 11:13:48.674 Sai Krishna Avula 5.0 3.0 3.0 4.0 4.0 5.0 5.0 3.0 3.0 4.0
4 2023-04-11 11:13:55.658 Ankit Yadav 3.0 3.0 2.0 5.0 2.0 5.0 5.0 3.0 3.0 4.0
# Discard the timestamp column

df = df.drop('Timestamp', axis=1)

# Make the "Your Name" column the index

df = df.set_index('Your name')
df
Sholay Swades (We The People) The Matrix (I) Interstellar Dangal Taare Zameen Par Shawshank Redemption The Dark Knight Notting Hill Uri: The Surgical Strike
Your name
Nipun 4.0 5.0 4.0 4.0 5.0 5.0 4.0 5.0 4.0 5.0
Gautam Vashishtha 3.0 4.0 4.0 5.0 3.0 1.0 5.0 5.0 4.0 3.0
Eshan Gujarathi 4.0 NaN 5.0 5.0 4.0 5.0 5.0 5.0 NaN 4.0
Sai Krishna Avula 5.0 3.0 3.0 4.0 4.0 5.0 5.0 3.0 3.0 4.0
Ankit Yadav 3.0 3.0 2.0 5.0 2.0 5.0 5.0 3.0 3.0 4.0
Dhruv NaN NaN 5.0 5.0 3.0 NaN 5.0 5.0 4.0 5.0
Saatvik Rao 4.0 3.0 4.0 5.0 2.0 2.0 4.0 5.0 3.0 5.0
Zeel B Patel 5.0 4.0 5.0 4.0 4.0 4.0 NaN 2.0 NaN 5.0
Neel 4.0 NaN 5.0 5.0 3.0 3.0 5.0 5.0 NaN 4.0
Sachin Jalan 4.0 NaN 5.0 5.0 3.0 4.0 4.0 5.0 NaN 3.0
Ayush Shrivastava 5.0 4.0 5.0 5.0 3.0 3.0 4.0 4.0 NaN 4.0
.... 4.0 4.0 NaN 4.0 4.0 4.0 NaN 4.0 NaN 4.0
Hari Hara Sudhan 4.0 3.0 5.0 4.0 4.0 5.0 4.0 5.0 3.0 5.0
Etikikota Hrushikesh NaN NaN 1.0 1.0 1.0 2.0 1.0 1.0 NaN NaN
Chirag 5.0 3.0 4.0 5.0 5.0 2.0 3.0 4.0 2.0 5.0
Aaryan Darad 5.0 4.0 4.0 5.0 3.0 4.0 3.0 3.0 NaN 4.0
Hetvi Patel 4.0 3.0 2.0 5.0 4.0 4.0 5.0 3.0 3.0 5.0
Kalash Kankaria 4.0 NaN 4.0 5.0 3.0 4.0 NaN NaN NaN 3.0
Rachit Verma NaN NaN 4.0 5.0 3.0 5.0 5.0 5.0 NaN 4.0
shriraj 3.0 2.0 5.0 4.0 2.0 3.0 4.0 5.0 4.0 5.0
Bhavini Korthi NaN NaN NaN NaN 4.0 5.0 NaN NaN NaN 5.0
Hitarth Gandhi 3.0 NaN 4.0 5.0 3.0 4.0 5.0 5.0 NaN NaN
Radhika Joglekar 3.0 3.0 3.0 4.0 5.0 5.0 2.0 1.0 2.0 5.0
Medhansh Singh 4.0 3.0 5.0 5.0 3.0 5.0 5.0 5.0 5.0 5.0
Arun Mani NaN NaN 4.0 5.0 4.0 5.0 5.0 5.0 4.0 NaN
Satyam 3.0 5.0 5.0 5.0 4.0 3.0 5.0 5.0 NaN 5.0
Karan Kumar 4.0 3.0 5.0 4.0 5.0 5.0 3.0 5.0 5.0 4.0
R Yeeshu Dhurandhar 5.0 NaN 4.0 5.0 4.0 4.0 5.0 5.0 NaN NaN
Satyam Gupta 5.0 5.0 NaN 5.0 4.0 4.0 NaN 4.0 NaN 2.0
rushali NaN NaN NaN 5.0 4.0 3.0 NaN NaN NaN NaN
shridhar 5.0 4.0 5.0 5.0 4.0 4.0 5.0 4.0 3.0 3.0
Tanvi Jain 4.0 3.0 NaN NaN 4.0 5.0 NaN NaN NaN 5.0
Manish Prabhubhai Salvi 4.0 5.0 4.0 5.0 5.0 5.0 NaN 4.0 NaN 5.0
Varun Barala 5.0 5.0 5.0 5.0 4.0 4.0 5.0 4.0 3.0 4.0
Kevin Shah 3.0 4.0 5.0 5.0 4.0 5.0 5.0 4.0 3.0 5.0
Inderjeet 4.0 NaN 4.0 5.0 4.0 3.0 5.0 5.0 NaN 3.0
Gangaram Siddam 4.0 4.0 3.0 3.0 5.0 5.0 4.0 4.0 3.0 5.0
Aditi 4.0 4.0 NaN 5.0 1.0 3.0 5.0 4.0 NaN 4.0
Madhuri Awachar 5.0 4.0 5.0 4.0 5.0 3.0 5.0 5.0 4.0 5.0
Anupam 5.0 5.0 NaN 5.0 5.0 5.0 NaN NaN NaN 5.0
Jinay 3.0 1.0 4.0 3.0 4.0 3.0 5.0 5.0 4.0 3.0
Shrutimoy 5.0 5.0 5.0 5.0 4.0 5.0 5.0 5.0 NaN 2.0
Aadesh Desai 4.0 4.0 3.0 5.0 3.0 3.0 5.0 5.0 4.0 5.0
Dhairya 5.0 4.0 4.0 5.0 3.0 5.0 NaN 4.0 NaN 4.0
Rahul C 3.0 3.0 4.0 4.0 4.0 4.0 4.0 5.0 NaN NaN
df.index
Index(['Nipun', 'Gautam Vashishtha', 'Eshan Gujarathi', 'Sai Krishna Avula',
       'Ankit Yadav ', 'Dhruv', 'Saatvik Rao ', 'Zeel B Patel', 'Neel ',
       'Sachin Jalan ', 'Ayush Shrivastava', '....', 'Hari Hara Sudhan',
       'Etikikota Hrushikesh', 'Chirag', 'Aaryan Darad', 'Hetvi Patel',
       'Kalash Kankaria', 'Rachit Verma', 'shriraj', 'Bhavini Korthi ',
       'Hitarth Gandhi ', 'Radhika Joglekar ', 'Medhansh Singh', 'Arun Mani',
       'Satyam ', 'Karan Kumar ', 'R Yeeshu Dhurandhar', 'Satyam Gupta',
       'rushali', 'shridhar', 'Tanvi Jain ', 'Manish Prabhubhai Salvi ',
       'Varun Barala', 'Kevin Shah ', 'Inderjeet', 'Gangaram Siddam ', 'Aditi',
       'Madhuri Awachar', 'Anupam', 'Jinay', 'Shrutimoy', 'Aadesh Desai',
       'Dhairya', 'Rahul C'],
      dtype='object', name='Your name')
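Note that several names carry trailing spaces (e.g. 'Ankit Yadav ' and 'Satyam '), which makes exact lookups by name brittle. A small optional cleanup sketch (done on a copy so the outputs below are unchanged):

# Optional: strip leading/trailing whitespace from the names used as the index
df_clean = df.copy()
df_clean.index = df_clean.index.str.strip()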
# Pick a user and check that they appear in the index
user = 'Rahul C'

print(user in df.index)

# Get the movie ratings for user
user_ratings = df.loc[user]
user_ratings
True
Sholay                      3.0
Swades (We The People)      3.0
The Matrix (I)              4.0
Interstellar                4.0
Dangal                      4.0
Taare Zameen Par            4.0
Shawshank Redemption        4.0
The Dark Knight             5.0
Notting Hill                NaN
Uri: The Surgical Strike    NaN
Name: Rahul C, dtype: float64
df_copy = df.copy()
df_copy.fillna(0, inplace=True)
show_neighbors(df_copy, user, 5)
Sholay Swades (We The People) The Matrix (I) Interstellar Dangal Taare Zameen Par Shawshank Redemption The Dark Knight Notting Hill Uri: The Surgical Strike
Your name
Rahul C 3.0 3.0 4.0 4.0 4.0 4.0 4.0 5.0 0.0 0.0
Shrutimoy 5.0 5.0 5.0 5.0 4.0 5.0 5.0 5.0 0.0 2.0
Hitarth Gandhi 3.0 0.0 4.0 5.0 3.0 4.0 5.0 5.0 0.0 0.0
R Yeeshu Dhurandhar 5.0 0.0 4.0 5.0 4.0 4.0 5.0 5.0 0.0 0.0
shridhar 5.0 4.0 5.0 5.0 4.0 4.0 5.0 4.0 3.0 3.0
Sachin Jalan 4.0 0.0 5.0 5.0 3.0 4.0 4.0 5.0 0.0 3.0
df.describe()
Sholay Swades (We The People) The Matrix (I) Interstellar Dangal Taare Zameen Par Shawshank Redemption The Dark Knight Notting Hill Uri: The Surgical Strike
count 39.000000 32.000000 38.000000 43.000000 45.000000 44.000000 35.000000 40.000000 21.000000 39.000000
mean 4.102564 3.718750 4.131579 4.581395 3.644444 3.977273 4.400000 4.250000 3.476190 4.230769
std 0.753758 0.958304 0.991070 0.793802 1.003529 1.067242 0.976187 1.080123 0.813575 0.902089
min 3.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 2.000000 2.000000
25% 4.000000 3.000000 4.000000 4.000000 3.000000 3.000000 4.000000 4.000000 3.000000 4.000000
50% 4.000000 4.000000 4.000000 5.000000 4.000000 4.000000 5.000000 5.000000 3.000000 4.000000
75% 5.000000 4.000000 5.000000 5.000000 4.000000 5.000000 5.000000 5.000000 4.000000 5.000000
max 5.000000 5.000000 5.000000 5.000000 5.000000 5.000000 5.000000 5.000000 5.000000 5.000000
# Predict the rating for user u for movie m

predict_rating(df_copy, user, 'The Dark Knight')
4.8
predict_rating(df_copy, user, 'Sholay')
4.4
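The same recipe gives recommendations on the survey data: predict the user's missing entries (Notting Hill and Uri: The Surgical Strike are unrated for this user) and rank them. A sketch using df_copy and predict_rating from above:

# Predict the user's unrated movies (0 after fillna) and sort by predicted rating
unrated = df_copy.columns[df_copy.loc[user] == 0]
pd.Series({m: predict_rating(df_copy, user, m) for m in unrated}).sort_values(ascending=False)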
# Generic Matrix Factorization (without missing values)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# A is a matrix of size (n_users, n_movies) with random integer ratings between 1 and 5
A = torch.randint(1, 6, (n_users, n_movies), dtype=torch.float)
A
tensor([[1., 5., 4., 3., 4., 3., 1., 3., 3., 4.],
        [3., 3., 4., 4., 4., 2., 5., 2., 2., 4.],
        [2., 5., 5., 3., 3., 1., 4., 5., 2., 5.],
        [4., 2., 1., 5., 1., 5., 1., 1., 5., 1.],
        [4., 3., 5., 3., 4., 1., 2., 4., 1., 5.],
        [1., 2., 2., 1., 2., 4., 1., 1., 3., 3.],
        [4., 2., 2., 3., 5., 3., 3., 2., 5., 2.],
        [2., 2., 1., 1., 2., 4., 5., 1., 4., 3.],
        [1., 3., 2., 1., 3., 4., 1., 5., 5., 1.],
        [4., 4., 2., 4., 1., 4., 2., 1., 4., 4.],
        [1., 3., 1., 2., 3., 1., 3., 2., 4., 4.],
        [5., 3., 5., 1., 4., 5., 4., 2., 1., 5.],
        [5., 1., 5., 3., 4., 4., 4., 2., 2., 2.],
        [1., 1., 2., 3., 5., 4., 3., 1., 1., 1.],
        [3., 2., 5., 4., 3., 5., 2., 4., 4., 3.],
        [2., 2., 4., 1., 2., 3., 2., 4., 3., 5.],
        [4., 2., 4., 5., 2., 2., 1., 2., 1., 1.],
        [4., 3., 2., 1., 1., 1., 5., 2., 1., 1.],
        [2., 1., 2., 1., 1., 1., 5., 5., 1., 4.],
        [1., 3., 5., 1., 5., 5., 5., 5., 2., 2.],
        [1., 2., 2., 5., 1., 1., 5., 4., 2., 5.],
        [5., 1., 2., 2., 1., 5., 2., 3., 5., 2.],
        [2., 2., 1., 5., 5., 5., 5., 3., 4., 3.],
        [1., 1., 4., 1., 2., 5., 5., 5., 2., 1.],
        [5., 3., 3., 1., 2., 3., 3., 2., 2., 3.],
        [2., 4., 5., 1., 2., 3., 5., 3., 1., 5.],
        [2., 1., 1., 1., 2., 2., 5., 5., 2., 1.],
        [3., 1., 1., 3., 3., 3., 2., 1., 3., 2.],
        [4., 4., 4., 5., 1., 3., 4., 3., 2., 2.],
        [4., 4., 4., 5., 2., 5., 2., 1., 4., 3.],
        [3., 3., 5., 1., 4., 3., 3., 3., 3., 2.],
        [2., 4., 2., 3., 4., 2., 4., 3., 3., 3.],
        [5., 4., 1., 5., 4., 3., 5., 1., 3., 4.],
        [3., 5., 3., 2., 4., 2., 5., 1., 2., 5.],
        [2., 5., 4., 1., 1., 5., 1., 5., 2., 4.],
        [1., 3., 3., 4., 4., 5., 2., 5., 3., 5.],
        [3., 3., 5., 3., 4., 1., 3., 1., 1., 3.],
        [2., 4., 1., 3., 5., 1., 5., 2., 4., 1.],
        [5., 2., 2., 3., 1., 4., 5., 5., 4., 2.],
        [5., 3., 5., 1., 5., 4., 3., 1., 1., 3.],
        [4., 2., 2., 2., 2., 1., 2., 1., 1., 3.],
        [2., 2., 1., 4., 1., 4., 5., 2., 5., 1.],
        [3., 4., 2., 1., 3., 1., 2., 5., 3., 5.],
        [1., 3., 3., 5., 3., 2., 1., 2., 5., 1.],
        [3., 4., 2., 4., 2., 3., 4., 1., 1., 1.],
        [1., 1., 5., 1., 3., 2., 5., 5., 5., 4.],
        [4., 4., 4., 4., 4., 3., 1., 4., 1., 1.],
        [3., 4., 5., 4., 1., 5., 2., 3., 1., 3.],
        [5., 2., 5., 5., 2., 4., 5., 4., 4., 5.],
        [3., 5., 5., 4., 1., 5., 1., 2., 5., 1.],
        [2., 4., 3., 5., 4., 5., 2., 5., 3., 3.],
        [3., 2., 3., 1., 1., 4., 1., 1., 2., 1.],
        [4., 1., 2., 4., 4., 3., 2., 2., 2., 4.],
        [2., 2., 3., 2., 3., 2., 2., 5., 4., 3.],
        [4., 2., 3., 4., 4., 1., 1., 4., 2., 3.],
        [2., 2., 3., 2., 2., 1., 4., 3., 5., 4.],
        [5., 5., 5., 4., 2., 1., 3., 1., 3., 3.],
        [2., 4., 5., 1., 2., 2., 5., 3., 3., 1.],
        [1., 4., 3., 2., 3., 5., 3., 2., 4., 4.],
        [1., 3., 2., 1., 2., 3., 2., 2., 5., 2.],
        [4., 5., 5., 3., 1., 2., 1., 3., 4., 4.],
        [4., 2., 3., 2., 3., 4., 5., 1., 4., 3.],
        [3., 2., 1., 3., 2., 2., 5., 5., 5., 1.],
        [2., 5., 5., 2., 3., 5., 3., 4., 3., 2.],
        [1., 2., 2., 4., 4., 5., 5., 2., 5., 1.],
        [4., 4., 1., 5., 2., 4., 4., 2., 4., 2.],
        [4., 1., 5., 3., 1., 4., 5., 2., 2., 2.],
        [3., 2., 2., 1., 1., 1., 2., 4., 2., 2.],
        [1., 1., 3., 2., 2., 4., 3., 3., 2., 2.],
        [2., 4., 3., 1., 1., 2., 4., 2., 5., 2.],
        [1., 4., 3., 5., 4., 2., 4., 2., 2., 5.],
        [4., 4., 5., 5., 4., 3., 4., 1., 4., 5.],
        [5., 2., 2., 1., 4., 4., 3., 5., 5., 5.],
        [3., 1., 3., 1., 2., 5., 5., 4., 3., 1.],
        [5., 5., 5., 3., 5., 3., 3., 3., 3., 3.],
        [5., 3., 3., 3., 4., 3., 5., 4., 5., 1.],
        [5., 5., 3., 5., 4., 1., 2., 1., 3., 2.],
        [5., 3., 4., 1., 5., 2., 4., 4., 2., 1.],
        [1., 2., 1., 2., 1., 3., 4., 1., 2., 1.],
        [5., 3., 3., 1., 5., 1., 4., 5., 5., 1.],
        [2., 2., 5., 2., 3., 5., 1., 5., 4., 2.],
        [3., 3., 1., 2., 5., 5., 4., 4., 4., 5.],
        [5., 3., 5., 3., 5., 5., 3., 4., 1., 4.],
        [5., 3., 2., 5., 1., 2., 4., 3., 1., 4.],
        [5., 2., 5., 2., 4., 3., 2., 2., 4., 5.],
        [2., 3., 5., 2., 4., 5., 3., 1., 4., 3.],
        [3., 3., 1., 4., 5., 2., 1., 3., 4., 4.],
        [5., 2., 3., 4., 2., 2., 5., 4., 2., 1.],
        [3., 1., 1., 2., 5., 1., 5., 1., 5., 3.],
        [3., 5., 1., 5., 3., 4., 5., 5., 4., 3.],
        [2., 4., 1., 5., 4., 1., 4., 5., 1., 4.],
        [5., 5., 4., 1., 2., 2., 2., 5., 1., 3.],
        [5., 2., 3., 3., 4., 3., 3., 4., 3., 2.],
        [1., 5., 4., 4., 5., 3., 4., 2., 4., 1.],
        [2., 4., 1., 1., 4., 2., 1., 1., 3., 4.],
        [4., 4., 2., 3., 2., 2., 3., 5., 4., 1.],
        [5., 4., 3., 4., 5., 4., 5., 1., 5., 5.],
        [1., 1., 1., 1., 4., 4., 5., 2., 4., 2.],
        [2., 4., 1., 2., 4., 3., 5., 1., 4., 4.],
        [2., 5., 2., 1., 3., 5., 5., 4., 1., 4.]])
A.shape
torch.Size([100, 10])

Let us decompose A as a product of two low-rank matrices, A ≈ WH, where W has shape (n_users, k) and H has shape (k, n_movies). We learn W and H by minimising the reconstruction error ||A - WH|| (the Frobenius norm), which is exactly the loss used below.

# Randomly initialize W and H

W = torch.randn(n_users, 2, requires_grad=True)
H = torch.randn(2, n_movies, requires_grad=True)

# Compute the loss

loss = torch.norm(torch.mm(W, H) - A)
loss
tensor(110.7991, grad_fn=<LinalgVectorNormBackward0>)
pd.DataFrame(torch.mm(W, H).detach().numpy())
0 1 2 3 4 5 6 7 8 9
0 -1.733831 2.962563 -0.009936 -0.591927 2.442282 -0.533001 -0.500535 -0.777075 -0.427938 -0.050505
1 -1.605388 2.875087 0.171515 -0.770741 2.594427 -0.652586 -0.512620 -0.953935 -0.329276 -0.023146
2 0.159360 -0.289060 -0.022038 0.082685 -0.266777 0.069192 0.052250 0.101196 0.030828 0.001642
3 -3.637741 4.346515 -2.580031 1.911339 0.407374 1.134387 -0.353923 1.689442 -1.846117 -0.440421
4 1.706123 -2.207651 0.978520 -0.611154 -0.617788 -0.328235 0.228982 -0.492014 0.780051 0.176302
... ... ... ... ... ... ... ... ... ... ...
95 1.362280 -2.751892 -0.572962 1.180665 -2.989312 0.929992 0.551276 1.363935 0.121039 -0.036218
96 1.432943 -1.514237 1.287248 -1.086741 0.338909 -0.685343 0.065700 -1.016967 0.827599 0.208897
97 0.381086 0.345669 1.366953 -1.551476 1.978570 -1.084161 -0.261282 -1.599606 0.599750 0.189461
98 -3.364117 6.450801 0.942667 -2.333749 6.511645 -1.880903 -1.232884 -2.755594 -0.473888 0.027722
99 0.877697 -1.631227 -0.175042 0.521516 -1.568212 0.428319 0.302370 0.626961 0.149908 0.002033

100 rows × 10 columns

pd.DataFrame(A)
0 1 2 3 4 5 6 7 8 9
0 1.0 5.0 4.0 3.0 4.0 3.0 1.0 3.0 3.0 4.0
1 3.0 3.0 4.0 4.0 4.0 2.0 5.0 2.0 2.0 4.0
2 2.0 5.0 5.0 3.0 3.0 1.0 4.0 5.0 2.0 5.0
3 4.0 2.0 1.0 5.0 1.0 5.0 1.0 1.0 5.0 1.0
4 4.0 3.0 5.0 3.0 4.0 1.0 2.0 4.0 1.0 5.0
... ... ... ... ... ... ... ... ... ... ...
95 4.0 4.0 2.0 3.0 2.0 2.0 3.0 5.0 4.0 1.0
96 5.0 4.0 3.0 4.0 5.0 4.0 5.0 1.0 5.0 5.0
97 1.0 1.0 1.0 1.0 4.0 4.0 5.0 2.0 4.0 2.0
98 2.0 4.0 1.0 2.0 4.0 3.0 5.0 1.0 4.0 4.0
99 2.0 5.0 2.0 1.0 3.0 5.0 5.0 4.0 1.0 4.0

100 rows × 10 columns

# Optimizer

optimizer = optim.Adam([W, H], lr=0.01)

# Train the model

for i in range(1000):
    # Compute the loss
    loss = torch.norm(torch.mm(W, H) - A)
    
    # Zero the gradients
    optimizer.zero_grad()
    
    # Backpropagate
    loss.backward()
    
    # Update the parameters
    optimizer.step()
    
    # Print the loss
    if i % 10 == 0:
        print(loss.item())
110.79912567138672
108.5261001586914
106.6722412109375
104.86959075927734
102.57525634765625
99.17333984375
94.17224884033203
87.41124725341797
79.23120880126953
70.51470184326172
62.338897705078125
55.442466735839844
50.24897003173828
46.78135681152344
44.567237854003906
43.125022888183594
42.18255615234375
41.56866455078125
41.16780471801758
40.899253845214844
40.70838928222656
40.56161117553711
40.439727783203125
40.33248519897461
40.23463821411133
40.143455505371094
40.05741500854492
39.9755744934082
39.89726638793945
39.82197952270508
39.749305725097656
39.678890228271484
39.610443115234375
39.54371643066406
39.478515625
39.4146842956543
39.35213088989258
39.290809631347656
39.23072052001953
39.17192077636719
39.114505767822266
39.058589935302734
39.00433349609375
38.95188903808594
38.90142059326172
38.85308837890625
38.8070182800293
38.76333999633789
38.722137451171875
38.6834716796875
38.64735412597656
38.61379623413086
38.5827522277832
38.554161071777344
38.527931213378906
38.503971099853516
38.482154846191406
38.46236038208008
38.444454193115234
38.42829132080078
38.41374588012695
38.40068435668945
38.38897705078125
38.378501892089844
38.36914825439453
38.36080551147461
38.35336685180664
38.34674835205078
38.34086608886719
38.33564376831055
38.330997467041016
38.326881408691406
38.3232307434082
38.31999206542969
38.31712341308594
38.314579010009766
38.31232833862305
38.310325622558594
38.308555603027344
38.30698013305664
38.30558776855469
38.30434799194336
38.30324935913086
38.30227279663086
38.301395416259766
38.300621032714844
38.2999267578125
38.29930114746094
38.29874801635742
38.298248291015625
38.29779815673828
38.297393798828125
38.297027587890625
38.29669189453125
38.29639434814453
38.296119689941406
38.295867919921875
38.29563903808594
38.29542922973633
38.29523468017578
pd.DataFrame(torch.mm(W, H).detach().numpy()).head(2)
0 1 2 3 4 5 6 7 8 9
0 3.622862 3.536414 3.945925 2.985436 2.999869 2.749094 2.605484 2.758482 2.209353 3.361388
1 3.622746 3.547434 3.798462 3.133817 3.271342 3.176287 3.221505 3.073355 2.856927 3.383140
pd.DataFrame(A).head(2)
0 1 2 3 4 5 6 7 8 9
0 1.0 5.0 4.0 3.0 4.0 3.0 1.0 3.0 3.0 4.0
1 3.0 3.0 4.0 4.0 4.0 2.0 5.0 2.0 2.0 4.0
def factorize(A, k):
    """Factorize the matrix A into W and H
    A: input matrix of size (n_users, n_movies)
    k: number of latent features
    
    Returns W and H
    W: matrix of size (n_users, k)
    H: matrix of size (k, n_movies)
    """
    # Randomly initialize W and H
    W = torch.randn(A.shape[0], k, requires_grad=True)
    H = torch.randn(k, A.shape[1], requires_grad=True)
    
    # Optimizer
    optimizer = optim.Adam([W, H], lr=0.01)
    
    # Train the model
    for i in range(1000):
        # Compute the loss
        loss = torch.norm(torch.mm(W, H) - A)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Backpropagate
        loss.backward()
        
        # Update the parameters
        optimizer.step()
        
    return W, H, loss
for k in [1, 2, 3, 4, 5, 6, 9, 10, 11]:
    W, H, loss = factorize(A, k)
    print(k, loss.item())
1 42.103797912597656
2 38.35541534423828
3 34.45906448364258
4 30.72266387939453
5 27.430004119873047
6 23.318540573120117
9 9.963604927062988
10 0.16205351054668427
11 0.18895108997821808
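The reconstruction error falls steadily with k and collapses to almost zero at k = 10, which is expected since a 100 × 10 matrix has rank at most 10. A small sketch to visualise the sweep (it re-runs factorize, so the exact numbers will vary):

# Plot reconstruction loss as a function of the number of latent factors
ks = [1, 2, 3, 4, 5, 6, 9, 10, 11]
losses = [factorize(A, k)[2].item() for k in ks]

plt.plot(ks, losses, marker='o')
plt.xlabel('k (number of latent factors)')
plt.ylabel('Reconstruction loss')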
pd.DataFrame(torch.mm(W,H).detach().numpy()).head(2)
0 1 2 3 4 5 6 7 8 9
0 1.014621 5.012642 4.012176 3.013239 4.012656 3.010292 1.012815 3.011410 3.015494 4.012996
1 2.987797 2.986945 3.989305 3.987498 3.988829 1.989375 4.989910 1.987138 1.986990 3.987026
pd.DataFrame(A).head(2)
0 1 2 3 4 5 6 7 8 9
0 1.0 5.0 4.0 3.0 4.0 3.0 1.0 3.0 3.0 4.0
1 3.0 3.0 4.0 4.0 4.0 2.0 5.0 2.0 2.0 4.0
# With missing values

# Randomly replace some entries with NaN

A = torch.randint(1, 6, (n_users, n_movies), dtype=torch.float)
A[torch.rand(A.shape) < 0.5] = float('nan')
A
tensor([[nan, 5., 5., nan, 1., 1., 4., 2., nan, 4.],
        [nan, 2., 2., 1., nan, nan, nan, nan, 3., nan],
        [nan, nan, 5., 1., nan, nan, 2., nan, nan, nan],
        [nan, nan, nan, 1., 4., 2., nan, 2., nan, 4.],
        [nan, nan, 2., nan, 3., 1., nan, nan, 4., 4.],
        [nan, 2., 5., nan, 2., nan, 1., nan, 3., nan],
        [nan, nan, nan, nan, 2., 1., nan, 3., nan, 5.],
        [nan, 4., nan, 1., 5., nan, 4., 5., 4., nan],
        [nan, nan, 2., 5., nan, nan, 5., nan, nan, 2.],
        [nan, nan, 3., 2., nan, 1., 1., 4., 5., nan],
        [nan, nan, 5., nan, nan, nan, 2., 2., nan, 3.],
        [nan, 5., 4., 2., nan, nan, nan, 1., 4., 3.],
        [nan, nan, 1., nan, 4., 4., nan, nan, 3., nan],
        [nan, 1., nan, nan, 3., nan, nan, nan, nan, 5.],
        [4., nan, 2., nan, nan, nan, nan, nan, 4., 3.],
        [4., nan, 3., nan, 3., nan, 4., 1., 1., nan],
        [5., nan, 3., nan, 3., nan, nan, 1., nan, nan],
        [2., 5., nan, 5., 3., 4., 3., 3., 5., 5.],
        [3., 3., nan, nan, nan, nan, nan, nan, 4., nan],
        [1., nan, 1., 3., 4., 1., nan, nan, 2., nan],
        [nan, nan, 5., nan, nan, nan, 2., nan, 1., nan],
        [3., 3., nan, nan, 2., 2., 3., 4., 4., nan],
        [2., nan, nan, 2., nan, nan, nan, nan, nan, 4.],
        [nan, 2., nan, 4., 5., 5., nan, 3., 5., nan],
        [2., 2., 2., 4., nan, 4., nan, 1., nan, nan],
        [5., 4., 5., 1., nan, 3., 5., 5., 1., 4.],
        [2., nan, nan, 2., nan, nan, nan, nan, nan, nan],
        [nan, 1., 5., nan, nan, nan, 2., 2., nan, nan],
        [2., 3., 4., nan, nan, nan, 1., 4., 4., nan],
        [2., nan, 1., 4., 1., 1., nan, 5., nan, 1.],
        [5., nan, 3., 1., 5., nan, nan, nan, 2., nan],
        [nan, 3., 4., nan, nan, 3., nan, 5., 5., nan],
        [4., 3., 2., 3., nan, nan, nan, nan, 1., nan],
        [1., 1., nan, 5., 1., 5., nan, nan, 5., nan],
        [4., 5., nan, nan, nan, 3., nan, 2., 5., 5.],
        [nan, 5., 5., nan, nan, nan, nan, 4., nan, nan],
        [5., nan, 4., 2., 3., 3., nan, nan, 2., 1.],
        [nan, nan, nan, nan, 5., nan, 2., 2., 4., nan],
        [3., 5., 1., 5., 3., nan, 5., nan, 2., 4.],
        [nan, 2., 4., nan, 1., 4., nan, 5., nan, 3.],
        [5., 5., 2., 5., nan, 4., nan, 2., nan, 3.],
        [5., nan, nan, 5., 1., 5., nan, nan, 3., 3.],
        [nan, 4., 2., 5., nan, 2., 3., nan, nan, 1.],
        [nan, nan, 4., 5., 5., nan, nan, 4., nan, 2.],
        [nan, nan, nan, nan, nan, nan, nan, 1., 3., nan],
        [3., 3., 1., 1., 1., 4., nan, nan, nan, 2.],
        [nan, nan, nan, 4., 5., nan, nan, 3., nan, nan],
        [5., nan, 5., 2., 4., nan, nan, nan, 5., nan],
        [nan, 1., 1., nan, 3., 1., nan, 1., nan, 1.],
        [nan, nan, 1., 2., nan, 3., nan, 2., 2., 4.],
        [nan, nan, 5., 1., 3., 2., 2., nan, 1., nan],
        [nan, 2., 2., nan, 4., nan, nan, nan, 3., nan],
        [nan, nan, nan, nan, 5., 4., nan, 3., nan, nan],
        [nan, 2., nan, nan, nan, nan, nan, 3., nan, nan],
        [nan, 2., 5., 2., nan, 3., 4., nan, 1., 2.],
        [1., nan, 1., nan, nan, 3., 4., 2., nan, 1.],
        [nan, 4., 4., 1., nan, nan, 5., nan, nan, 2.],
        [nan, nan, 3., nan, nan, 4., 1., nan, 3., nan],
        [nan, 1., nan, 3., 2., 3., nan, 2., nan, 4.],
        [2., 3., 3., nan, 1., 4., 3., nan, nan, 1.],
        [5., nan, 3., 1., 1., nan, 4., nan, 3., nan],
        [5., nan, nan, nan, nan, nan, 4., nan, nan, nan],
        [2., nan, nan, 2., nan, nan, 2., 3., nan, 4.],
        [3., nan, 4., 5., nan, nan, nan, 1., nan, 3.],
        [3., 3., 1., nan, 2., 5., 5., nan, 2., nan],
        [1., 1., 3., 1., nan, nan, 4., 3., 4., 5.],
        [4., nan, nan, 5., 2., nan, 1., nan, nan, nan],
        [nan, nan, 2., nan, 5., nan, 2., 1., nan, 4.],
        [nan, 5., nan, 4., 3., 2., 2., 1., nan, 4.],
        [nan, nan, 1., nan, nan, nan, 2., nan, nan, nan],
        [1., nan, nan, 4., 5., 3., nan, 2., nan, nan],
        [2., nan, 5., 1., 3., nan, 5., nan, nan, nan],
        [3., 5., nan, nan, nan, 5., 2., nan, nan, 2.],
        [nan, 2., 5., nan, nan, nan, 2., nan, 1., 2.],
        [nan, 3., 2., nan, nan, 4., nan, 2., nan, 5.],
        [nan, 4., 5., 2., 5., nan, nan, nan, nan, 5.],
        [3., 1., 5., 3., nan, nan, nan, 3., nan, 4.],
        [nan, 5., nan, 5., nan, 5., 1., 1., 3., 1.],
        [2., nan, nan, nan, 5., nan, nan, 4., nan, 4.],
        [nan, 3., 4., 2., nan, 2., nan, 2., nan, nan],
        [nan, 5., nan, 2., 2., 4., 5., 5., nan, nan],
        [nan, nan, nan, nan, nan, 5., nan, nan, 5., 4.],
        [nan, nan, nan, nan, nan, 2., 2., 2., 2., 4.],
        [nan, nan, nan, 2., 5., 3., 3., nan, nan, nan],
        [1., 2., 4., 2., nan, 2., 5., 4., nan, 5.],
        [4., nan, nan, 5., 4., nan, 5., 1., 3., nan],
        [4., nan, 1., 4., nan, nan, 2., nan, 4., 3.],
        [4., 2., 1., 3., nan, nan, 1., nan, nan, 1.],
        [nan, 1., 3., 1., 2., nan, 3., nan, 5., nan],
        [nan, 1., 1., nan, 1., 1., nan, 4., nan, nan],
        [nan, nan, nan, nan, 1., 5., nan, 5., 3., nan],
        [3., 4., nan, 4., 3., nan, 2., 1., nan, nan],
        [1., 1., nan, 2., nan, nan, nan, 2., 4., 1.],
        [nan, 4., 2., nan, 3., nan, 2., 1., nan, 2.],
        [2., nan, 5., nan, 3., nan, 5., 1., nan, nan],
        [3., 3., nan, nan, 3., 3., nan, 4., 2., nan],
        [2., nan, 5., nan, 5., 3., nan, nan, 5., nan],
        [nan, nan, nan, nan, 5., nan, nan, nan, nan, 5.],
        [4., nan, nan, 1., 4., 5., 1., nan, 2., 4.],
        [1., nan, nan, nan, 2., nan, 4., nan, nan, nan]])
W, H, loss = factorize(A, 2)
loss
tensor(nan, grad_fn=<LinalgVectorNormBackward0>)

As expected, this fails: any NaN entry in A makes the whole Frobenius-norm loss NaN, and so the gradients are NaN too. We need a loss that only looks at the observed entries.

mask = ~torch.isnan(A)
mask
tensor([[False,  True,  True, False,  True,  True,  True,  True, False,  True],
        [False,  True,  True,  True, False, False, False, False,  True, False],
        [False, False,  True,  True, False, False,  True, False, False, False],
        [False, False, False,  True,  True,  True, False,  True, False,  True],
        [False, False,  True, False,  True,  True, False, False,  True,  True],
        [False,  True,  True, False,  True, False,  True, False,  True, False],
        [False, False, False, False,  True,  True, False,  True, False,  True],
        [False,  True, False,  True,  True, False,  True,  True,  True, False],
        [False, False,  True,  True, False, False,  True, False, False,  True],
        [False, False,  True,  True, False,  True,  True,  True,  True, False],
        [False, False,  True, False, False, False,  True,  True, False,  True],
        [False,  True,  True,  True, False, False, False,  True,  True,  True],
        [False, False,  True, False,  True,  True, False, False,  True, False],
        [False,  True, False, False,  True, False, False, False, False,  True],
        [ True, False,  True, False, False, False, False, False,  True,  True],
        [ True, False,  True, False,  True, False,  True,  True,  True, False],
        [ True, False,  True, False,  True, False, False,  True, False, False],
        [ True,  True, False,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True, False, False, False, False, False, False,  True, False],
        [ True, False,  True,  True,  True,  True, False, False,  True, False],
        [False, False,  True, False, False, False,  True, False,  True, False],
        [ True,  True, False, False,  True,  True,  True,  True,  True, False],
        [ True, False, False,  True, False, False, False, False, False,  True],
        [False,  True, False,  True,  True,  True, False,  True,  True, False],
        [ True,  True,  True,  True, False,  True, False,  True, False, False],
        [ True,  True,  True,  True, False,  True,  True,  True,  True,  True],
        [ True, False, False,  True, False, False, False, False, False, False],
        [False,  True,  True, False, False, False,  True,  True, False, False],
        [ True,  True,  True, False, False, False,  True,  True,  True, False],
        [ True, False,  True,  True,  True,  True, False,  True, False,  True],
        [ True, False,  True,  True,  True, False, False, False,  True, False],
        [False,  True,  True, False, False,  True, False,  True,  True, False],
        [ True,  True,  True,  True, False, False, False, False,  True, False],
        [ True,  True, False,  True,  True,  True, False, False,  True, False],
        [ True,  True, False, False, False,  True, False,  True,  True,  True],
        [False,  True,  True, False, False, False, False,  True, False, False],
        [ True, False,  True,  True,  True,  True, False, False,  True,  True],
        [False, False, False, False,  True, False,  True,  True,  True, False],
        [ True,  True,  True,  True,  True, False,  True, False,  True,  True],
        [False,  True,  True, False,  True,  True, False,  True, False,  True],
        [ True,  True,  True,  True, False,  True, False,  True, False,  True],
        [ True, False, False,  True,  True,  True, False, False,  True,  True],
        [False,  True,  True,  True, False,  True,  True, False, False,  True],
        [False, False,  True,  True,  True, False, False,  True, False,  True],
        [False, False, False, False, False, False, False,  True,  True, False],
        [ True,  True,  True,  True,  True,  True, False, False, False,  True],
        [False, False, False,  True,  True, False, False,  True, False, False],
        [ True, False,  True,  True,  True, False, False, False,  True, False],
        [False,  True,  True, False,  True,  True, False,  True, False,  True],
        [False, False,  True,  True, False,  True, False,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True, False,  True, False],
        [False,  True,  True, False,  True, False, False, False,  True, False],
        [False, False, False, False,  True,  True, False,  True, False, False],
        [False,  True, False, False, False, False, False,  True, False, False],
        [False,  True,  True,  True, False,  True,  True, False,  True,  True],
        [ True, False,  True, False, False,  True,  True,  True, False,  True],
        [False,  True,  True,  True, False, False,  True, False, False,  True],
        [False, False,  True, False, False,  True,  True, False,  True, False],
        [False,  True, False,  True,  True,  True, False,  True, False,  True],
        [ True,  True,  True, False,  True,  True,  True, False, False,  True],
        [ True, False,  True,  True,  True, False,  True, False,  True, False],
        [ True, False, False, False, False, False,  True, False, False, False],
        [ True, False, False,  True, False, False,  True,  True, False,  True],
        [ True, False,  True,  True, False, False, False,  True, False,  True],
        [ True,  True,  True, False,  True,  True,  True, False,  True, False],
        [ True,  True,  True,  True, False, False,  True,  True,  True,  True],
        [ True, False, False,  True,  True, False,  True, False, False, False],
        [False, False,  True, False,  True, False,  True,  True, False,  True],
        [False,  True, False,  True,  True,  True,  True,  True, False,  True],
        [False, False,  True, False, False, False,  True, False, False, False],
        [ True, False, False,  True,  True,  True, False,  True, False, False],
        [ True, False,  True,  True,  True, False,  True, False, False, False],
        [ True,  True, False, False, False,  True,  True, False, False,  True],
        [False,  True,  True, False, False, False,  True, False,  True,  True],
        [False,  True,  True, False, False,  True, False,  True, False,  True],
        [False,  True,  True,  True,  True, False, False, False, False,  True],
        [ True,  True,  True,  True, False, False, False,  True, False,  True],
        [False,  True, False,  True, False,  True,  True,  True,  True,  True],
        [ True, False, False, False,  True, False, False,  True, False,  True],
        [False,  True,  True,  True, False,  True, False,  True, False, False],
        [False,  True, False,  True,  True,  True,  True,  True, False, False],
        [False, False, False, False, False,  True, False, False,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True, False,  True,  True,  True, False,  True],
        [ True, False, False,  True,  True, False,  True,  True,  True, False],
        [ True, False,  True,  True, False, False,  True, False,  True,  True],
        [ True,  True,  True,  True, False, False,  True, False, False,  True],
        [False,  True,  True,  True,  True, False,  True, False,  True, False],
        [False,  True,  True, False,  True,  True, False,  True, False, False],
        [False, False, False, False,  True,  True, False,  True,  True, False],
        [ True,  True, False,  True,  True, False,  True,  True, False, False],
        [ True,  True, False,  True, False, False, False,  True,  True,  True],
        [False,  True,  True, False,  True, False,  True,  True, False,  True],
        [ True, False,  True, False,  True, False,  True,  True, False, False],
        [ True,  True, False, False,  True,  True, False,  True,  True, False],
        [ True, False,  True, False,  True,  True, False, False,  True, False],
        [False, False, False, False,  True, False, False, False, False,  True],
        [ True, False, False,  True,  True,  True,  True, False,  True,  True],
        [ True, False, False, False,  True, False,  True, False, False, False]])
mask.sum()
tensor(517)
W = torch.randn(A.shape[0], k, requires_grad=True)
H = torch.randn(k, A.shape[1],  requires_grad=True)

diff_matrix = torch.mm(W, H)-A
diff_matrix.shape
torch.Size([100, 10])
# Mask the matrix
diff_matrix[mask].shape
torch.Size([517])
# Modify the loss function to ignore NaN values

def factorize(A, k):
    """Factorize the matrix A into W and H, ignoring the NaN (missing) entries"""
    # Mask of observed (non-NaN) entries
    mask = ~torch.isnan(A)

    # Randomly initialize W and H
    W = torch.randn(A.shape[0], k, requires_grad=True)
    H = torch.randn(k, A.shape[1], requires_grad=True)
    # Optimizer
    optimizer = optim.Adam([W, H], lr=0.01)
    
    # Train the model
    for i in range(1000):
        # Compute the loss
        diff_matrix = torch.mm(W, H) - A
        diff_vector = diff_matrix[mask]
        loss = torch.norm(diff_vector)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Backpropagate
        loss.backward()
        
        # Update the parameters
        optimizer.step()
        
    return W, H, loss
W, H, loss = factorize(A, 5)
loss
tensor(7.1147, grad_fn=<LinalgVectorNormBackward0>)
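The loss above is the Frobenius norm of the error over the 517 observed entries only; dividing by the square root of that count gives a per-entry RMSE. A quick sketch:

# RMSE over the observed (non-NaN) entries only
mask = ~torch.isnan(A)
rmse = torch.sqrt(((torch.mm(W, H) - A)[mask] ** 2).mean())
rmse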
torch.mm(W, H)
tensor([[-3.7026e+00,  5.1163e+00,  4.9660e+00, -5.8455e-01,  1.0200e+00,
          9.1470e-01,  3.8727e+00,  2.2007e+00,  1.3782e+01,  3.9267e+00],
        [ 6.4588e-02,  2.0002e+00,  1.9979e+00,  1.0001e+00,  3.6942e+00,
          1.1512e+00,  3.8181e+00,  3.5421e+00,  3.0014e+00,  5.0035e+00],
        [-3.3540e-01, -1.0050e+00,  5.0000e+00,  1.0010e+00,  1.7959e-01,
          8.6454e-01,  1.9994e+00,  2.9560e+00,  1.0767e+01, -3.4057e-01],
        [ 2.6321e+00,  3.5523e+00,  2.1706e+00,  9.9769e-01,  3.9995e+00,
          2.0030e+00,  2.8652e+00,  2.0006e+00,  3.5503e-02,  3.9991e+00],
        [ 5.3195e-02, -1.3399e+00,  1.9455e+00,  2.2700e+00,  3.0539e+00,
          1.0086e+00,  3.8842e+00,  5.1128e+00,  4.0248e+00,  3.9763e+00],
        [ 5.5045e+00,  2.0044e+00,  4.9388e+00,  1.6216e+00,  2.0865e+00,
          2.8957e+00,  9.2770e-01,  8.8714e-01,  3.0420e+00, -9.2438e-01],
        [ 6.1366e-01, -1.3394e+00, -6.3292e+00,  4.0202e+00,  2.0028e+00,
          1.0014e+00,  2.1030e+00,  3.0001e+00, -1.0639e+01,  4.9979e+00],
        [ 1.6036e+00,  3.6518e+00,  3.7664e+00,  1.2319e+00,  5.1780e+00,
          2.1295e+00,  4.7475e+00,  4.0635e+00,  4.0645e+00,  5.9223e+00],
        [-4.1593e+00, -1.1450e+01,  2.0064e+00,  5.0104e+00,  1.7877e-01,
         -6.8356e-01,  4.9576e+00,  1.0525e+01,  1.2403e+01,  2.0250e+00],
        [ 1.7769e+00, -4.9137e+00,  3.1228e+00,  2.2575e+00,  5.8793e-01,
          6.6961e-01,  1.2031e+00,  3.7871e+00,  4.9702e+00, -1.4083e+00],
        [ 9.1273e+00,  1.4212e+00,  5.0018e+00, -1.0804e+00,  8.9608e+00,
          1.6892e+00,  2.0130e+00,  1.9936e+00, -7.2410e+00,  2.9931e+00],
        [ 3.2326e+00,  5.2245e+00,  3.8534e+00,  1.8050e+00,  2.7570e+00,
          3.2246e+00,  2.7488e+00,  1.4501e+00,  3.9743e+00,  2.7026e+00],
        [ 9.2220e-01,  6.5451e+00,  9.9373e-01,  4.2199e+00,  4.0030e+00,
          4.0019e+00,  5.8552e+00,  4.3514e+00,  3.0028e+00,  8.0414e+00],
        [-2.3552e+00,  1.0013e+00,  9.6726e-01, -4.6631e-01,  3.0005e+00,
         -4.1134e-01,  3.3274e+00,  3.0678e+00,  3.0946e+00,  4.9978e+00],
        [ 4.0009e+00,  5.9036e+00,  1.9983e+00,  6.0610e+00,  1.0184e+00,
          5.6167e+00,  3.5338e+00,  2.5504e+00,  4.0003e+00,  3.0002e+00],
        [ 4.1839e+00,  9.1531e+00,  2.5008e+00,  3.1248e+00,  3.4309e+00,
          4.8701e+00,  3.6236e+00,  9.7109e-01,  1.3077e+00,  4.9725e+00],
        [ 5.0034e+00, -2.6358e+00,  2.9996e+00, -3.1773e-01,  2.9981e+00,
          3.3085e-01, -3.0125e-01,  1.0002e+00, -2.6015e+00, -1.5107e+00],
        [ 2.1562e+00,  5.1078e+00,  2.3091e+00,  4.3035e+00,  2.3000e+00,
          4.1470e+00,  4.2497e+00,  3.3351e+00,  4.7425e+00,  4.5406e+00],
        [ 2.9999e+00,  2.9999e+00,  3.0624e+00,  4.4992e+00,  4.5262e+00,
          3.8824e+00,  5.4591e+00,  5.5541e+00,  4.0005e+00,  5.9015e+00],
        [ 8.1177e-01, -2.9699e+00,  1.2606e+00,  3.0946e+00,  3.9370e+00,
          1.0488e+00,  4.3280e+00,  6.3317e+00,  1.8619e+00,  4.5940e+00],
        [ 2.1477e+00,  2.4168e+00,  5.0014e+00, -3.6888e+00,  5.9741e+00,
         -5.6343e-01,  1.9977e+00,  9.9599e-01,  1.0008e+00,  3.2156e+00],
        [ 2.5660e+00,  2.4872e+00,  2.9711e+00,  3.0048e+00,  2.5035e+00,
          2.9551e+00,  3.2466e+00,  3.1642e+00,  4.0380e+00,  2.8930e+00],
        [ 2.0003e+00,  1.1864e+01, -1.4367e+00,  1.9999e+00,  4.6078e-01,
          4.1701e+00,  1.3254e+00, -3.0082e+00, -2.7641e+00,  3.9993e+00],
        [ 1.0449e+01,  2.0032e+00,  8.9659e+00,  3.6977e+00,  4.8085e+00,
          5.2882e+00,  2.4920e+00,  3.2403e+00,  4.9395e+00, -8.7002e-01],
        [ 2.0717e+00,  2.1115e+00,  2.0717e+00,  4.2801e+00, -2.0380e+00,
          3.5172e+00,  9.3884e-01,  9.7155e-01,  6.2354e+00, -1.3328e+00],
        [ 5.2955e+00,  3.4266e+00,  5.1299e+00,  1.4202e+00,  6.8671e+00,
          2.9375e+00,  4.3992e+00,  3.9995e+00,  1.2817e+00,  5.0996e+00],
        [ 2.0001e+00,  1.4791e+00,  2.3388e+00,  2.0000e+00, -6.6995e-01,
          2.1166e+00,  4.5259e-01,  3.8568e-01,  4.1443e+00, -1.1447e+00],
        [ 2.9065e+00,  1.0088e+00,  5.0003e+00, -4.4267e-01,  3.6696e+00,
          1.0351e+00,  1.9735e+00,  2.0205e+00,  3.8361e+00,  1.3199e+00],
        [ 2.3390e+00,  2.3721e+00,  3.7185e+00,  1.3874e+00,  2.8777e+00,
          2.1103e+00,  2.7381e+00,  2.4701e+00,  4.1805e+00,  2.4318e+00],
        [ 1.7971e+00, -4.8495e+00,  9.1239e-01,  3.9631e+00,  1.2733e+00,
          1.2781e+00,  2.2735e+00,  5.0206e+00,  1.8186e+00,  7.7334e-01],
        [ 4.4556e+00,  1.2257e+01,  3.8113e+00,  1.3722e+00,  4.8218e+00,
          4.8774e+00,  3.8956e+00, -1.3451e-01,  1.5725e+00,  6.1126e+00],
        [ 2.3871e+00,  3.0010e+00,  3.9991e+00,  2.8334e+00,  4.8395e+00,
          2.9996e+00,  5.1329e+00,  5.0009e+00,  5.0012e+00,  5.6052e+00],
        [ 3.9920e+00,  3.0004e+00,  2.0131e+00,  3.0049e+00,  1.3189e+00,
          3.2595e+00,  1.5054e+00,  1.0929e+00,  9.9195e-01,  8.6129e-01],
        [ 1.3517e+00,  1.3225e+00,  1.2075e+00,  5.6267e+00,  9.9029e-01,
          3.6713e+00,  4.0205e+00,  4.6894e+00,  5.0908e+00,  3.3918e+00],
        [ 3.9141e+00,  5.4294e+00,  6.3565e+00,  3.9410e-02,  5.9122e+00,
          2.6582e+00,  4.0295e+00,  2.6446e+00,  4.8764e+00,  4.5279e+00],
        [ 5.9703e+00,  5.0000e+00,  5.0006e+00,  5.4552e+00,  3.4764e+00,
          5.7107e+00,  4.4197e+00,  3.9984e+00,  5.4614e+00,  3.3112e+00],
        [ 4.9808e+00,  2.5274e+00,  4.0592e+00,  2.0167e+00,  2.9650e+00,
          2.9926e+00,  1.9145e+00,  1.7219e+00,  1.9706e+00,  1.0164e+00],
        [ 3.3675e+00,  7.9931e-01,  6.4143e+00, -1.9437e+00,  5.0061e+00,
          3.9840e-01,  1.9918e+00,  2.0016e+00,  4.0036e+00,  1.4088e+00],
        [ 3.1524e+00,  4.8411e+00,  1.3102e+00,  5.0749e+00,  2.6352e+00,
          4.5228e+00,  4.2075e+00,  3.4787e+00,  2.0158e+00,  4.7858e+00],
        [ 1.1647e+00,  2.0442e+00,  3.9774e+00,  5.2040e+00,  1.0064e+00,
          3.9541e+00,  4.5550e+00,  5.0695e+00,  1.0370e+01,  2.9545e+00],
        [ 4.9155e+00,  4.8887e+00,  1.8611e+00,  4.5726e+00,  2.4214e+00,
          4.6615e+00,  2.9405e+00,  2.1223e+00,  4.9785e-01,  2.9426e+00],
        [ 4.6095e+00,  7.0277e+00,  2.3640e+00,  4.9826e+00,  1.3639e+00,
          5.4668e+00,  2.9430e+00,  1.3600e+00,  2.9275e+00,  2.6572e+00],
        [-9.4300e+00,  3.9400e+00,  1.8115e+00,  4.6790e+00, -7.7135e+00,
          2.3401e+00,  3.3902e+00,  2.5588e+00,  2.3759e+01,  7.8572e-01],
        [ 8.8489e+00,  1.3677e+00,  4.0011e+00,  4.9984e+00,  5.0002e+00,
          4.8811e+00,  3.1127e+00,  4.0020e+00, -1.6249e+00,  1.9996e+00],
        [-5.3638e+00, -1.6611e+01, -1.2622e+00, -2.9465e+00, -5.3122e+00,
         -6.6081e+00, -4.5960e+00,  1.0021e+00,  3.0004e+00, -8.7048e+00],
        [ 2.9505e+00,  3.4944e+00,  1.0966e+00,  2.0104e+00,  1.4645e+00,
          2.5452e+00,  1.3437e+00,  4.9608e-01, -3.2980e-01,  1.5523e+00],
        [ 8.8610e+00,  3.4098e+00,  5.4534e+00,  4.0010e+00,  5.0039e+00,
          5.0527e+00,  2.9819e+00,  2.9969e+00,  4.8445e-01,  1.7580e+00],
        [ 4.7645e+00,  7.2438e+00,  5.3527e+00,  2.1598e+00,  3.9257e+00,
          4.3363e+00,  3.5150e+00,  1.6544e+00,  4.8212e+00,  3.4858e+00],
        [ 3.6301e+00,  8.2968e-01,  1.0952e+00,  5.1852e-01,  2.9273e+00,
          1.1501e+00,  7.9266e-01,  7.6517e-01, -3.4503e+00,  1.2149e+00],
        [ 1.2349e+00,  4.7494e+00,  1.1553e+00,  2.2902e+00,  2.1858e+00,
          2.6806e+00,  3.0526e+00,  1.8136e+00,  1.9800e+00,  4.0943e+00],
        [ 5.4207e+00,  4.6174e-01,  4.6315e+00,  1.0934e+00,  3.5323e+00,
          2.1171e+00,  1.4216e+00,  1.9254e+00,  1.2497e+00,  1.5768e-01],
        [ 8.0981e-01,  1.9982e+00,  1.9823e+00,  2.3342e+00,  4.0156e+00,
          1.9755e+00,  4.5111e+00,  4.4963e+00,  3.0107e+00,  5.5846e+00],
        [ 3.4230e+00,  6.7811e+00,  1.3001e+00,  3.2712e+00,  5.0004e+00,
          4.0008e+00,  4.7148e+00,  2.9985e+00, -8.9564e-01,  6.9733e+00],
        [ 4.3992e+00,  1.9999e+00,  1.7722e+00,  3.3742e+00,  3.4675e+00,
          3.1595e+00,  2.8761e+00,  3.0001e+00, -8.2256e-01,  3.0558e+00],
        [ 6.3852e+00,  1.8676e+00,  5.1417e+00,  2.1083e+00,  5.4183e+00,
          3.1672e+00,  3.1282e+00,  3.3945e+00,  1.0754e+00,  2.5551e+00],
        [ 8.1366e-01,  2.6100e+00,  1.2431e+00,  4.6053e+00, -7.3305e-01,
          3.4299e+00,  2.5706e+00,  2.4739e+00,  5.9303e+00,  1.5624e+00],
        [-1.5185e+01,  3.9601e+00,  4.1178e+00,  1.1159e+00, -7.8194e+00,
         -2.6153e-01,  4.5528e+00,  3.2349e+00,  3.2557e+01,  2.2662e+00],
        [ 4.1539e+00,  6.2482e+00,  3.0004e+00,  2.5108e+00,  4.1870e-01,
          4.0017e+00,  9.9885e-01, -7.3851e-01,  2.9998e+00,  1.7461e-01],
        [ 2.4292e+00,  1.3324e+00, -2.4186e+00,  3.4744e+00,  2.2288e+00,
          2.2009e+00,  2.0940e+00,  2.1399e+00, -5.7555e+00,  3.6885e+00],
        [ 2.2950e+00,  2.9436e+00,  3.1712e+00,  4.1268e+00,  5.4157e-01,
          3.7216e+00,  2.8408e+00,  2.7181e+00,  6.7460e+00,  1.4747e+00],
        [ 5.0664e+00,  4.5652e+01,  2.7328e+00,  1.0374e+00,  1.2891e+00,
          1.3143e+01,  3.7184e+00, -1.3608e+01,  3.1697e+00,  1.1363e+01],
        [ 5.0006e+00,  6.9239e+00,  9.1918e-01,  7.2628e+00,  1.5445e+00,
          6.5242e+00,  4.0004e+00,  2.7856e+00,  1.5080e+00,  4.0918e+00],
        [ 2.0015e+00, -1.7211e+00, -2.6718e+00,  1.9970e+00,  3.6355e+00,
          4.8944e-01,  2.0002e+00,  3.0038e+00, -7.9397e+00,  3.9988e+00],
        [ 2.9915e+00,  1.0068e+01,  4.0126e+00,  5.0140e+00,  2.1576e-01,
          6.2162e+00,  3.6692e+00,  9.8564e-01,  9.0477e+00,  2.9998e+00],
        [ 3.1704e+00,  3.0173e+00,  6.3333e-01,  6.8363e+00,  2.2725e+00,
          4.9317e+00,  4.7652e+00,  5.0099e+00,  2.2206e+00,  4.9007e+00],
        [ 7.5330e-01,  1.4262e+00,  3.0599e+00,  8.5747e-01,  4.0712e+00,
          1.1776e+00,  3.8169e+00,  3.8298e+00,  3.8094e+00,  4.5052e+00],
        [ 4.0428e+00, -2.0650e+01,  2.4385e+00,  4.9425e+00,  1.9279e+00,
         -1.5739e+00,  1.0876e+00,  1.0297e+01,  9.8782e-01, -4.4288e+00],
        [ 1.0773e+00,  2.2150e+00,  2.0096e+00, -2.4710e+00,  4.9950e+00,
         -5.0779e-01,  1.9650e+00,  1.0100e+00, -1.8899e+00,  4.0205e+00],
        [ 3.4996e+00,  4.5594e+00, -2.5208e+00,  3.3402e+00,  2.6908e+00,
          3.0928e+00,  2.0393e+00,  8.1359e-01, -7.3751e+00,  4.3822e+00],
        [-4.8442e+00,  3.4518e+00,  9.9972e-01, -1.2881e+00, -3.5755e-01,
         -5.3701e-01,  1.9997e+00,  4.9337e-01,  7.5069e+00,  3.0513e+00],
        [ 1.0026e+00,  7.6105e+00, -7.1104e+00,  3.9509e+00,  4.9733e+00,
          3.0444e+00,  4.5575e+00,  2.0356e+00, -1.3359e+01,  1.0958e+01],
        [ 1.9708e+00,  2.2459e+01,  5.0115e+00,  1.0357e+00,  3.0386e+00,
          7.0322e+00,  4.9501e+00, -3.0939e+00,  8.5776e+00,  8.3022e+00],
        [ 3.0000e+00,  5.0024e+00, -1.4123e+00,  6.3346e+00, -9.3349e-01,
          4.9944e+00,  2.0182e+00,  1.1012e+00, -1.8768e-01,  1.9884e+00],
        [ 4.4226e+00,  2.0069e+00,  5.0098e+00, -8.1680e-01,  5.0611e+00,
          1.2995e+00,  1.9966e+00,  1.6425e+00,  9.9608e-01,  1.9906e+00],
        [ 1.1858e+01,  3.1806e+00,  1.9399e+00,  2.0242e+00,  9.7272e+00,
          3.8295e+00,  2.5867e+00,  2.2628e+00, -1.4185e+01,  4.8150e+00],
        [ 3.0139e+00,  4.0015e+00,  4.9996e+00,  2.0000e+00,  5.0004e+00,
          3.0423e+00,  4.6969e+00,  4.0977e+00,  5.4217e+00,  4.9981e+00],
        [ 3.1070e+00,  1.5608e+00,  4.4961e+00,  2.4194e+00,  3.9092e+00,
          2.6510e+00,  3.8237e+00,  4.1656e+00,  5.0091e+00,  3.2677e+00],
        [ 3.6123e+00,  5.0130e+00,  1.2513e+00,  5.0780e+00, -6.2890e-01,
          4.7738e+00,  1.7050e+00,  7.5555e-01,  2.9369e+00,  7.0034e-01],
        [ 1.9988e+00, -1.0755e+00,  2.1010e+00,  4.8154e-01,  5.0034e+00,
          4.1238e-01,  3.1198e+00,  4.0017e+00, -6.4504e-01,  3.9979e+00],
        [-2.4511e+00,  3.0049e+00,  4.0049e+00,  2.0168e+00, -1.7567e+00,
          1.9773e+00,  2.5612e+00,  1.9963e+00,  1.3746e+01,  7.7954e-01],
        [ 1.2900e+00,  4.8906e+00,  1.2285e+01,  2.0178e+00,  2.0017e+00,
          4.0381e+00,  5.2765e+00,  4.7427e+00,  2.3263e+01,  1.4459e+00],
        [ 3.4409e+00,  1.6323e+00,  2.1447e+00,  7.1434e+00,  2.1349e+00,
          4.9995e+00,  4.9455e+00,  5.8905e+00,  5.0000e+00,  4.0006e+00],
        [-8.3402e+00,  1.6671e+00, -9.7906e+00,  7.5550e+00, -7.1860e+00,
          1.9594e+00,  2.1880e+00,  1.9448e+00,  1.9841e+00,  3.9175e+00],
        [ 6.1855e+00,  1.0277e+00,  6.0344e+00,  2.0010e+00,  5.0009e+00,
          2.9985e+00,  3.0008e+00,  3.6213e+00,  3.0914e+00,  1.7165e+00],
        [ 9.9264e-01,  2.2125e+00,  3.9075e+00,  1.8425e+00,  4.0571e+00,
          2.0497e+00,  4.5854e+00,  4.5379e+00,  6.0676e+00,  4.9087e+00],
        [ 4.1464e+00,  1.4062e+01,  2.4284e+00,  4.7693e+00,  3.6919e+00,
          6.9404e+00,  5.3224e+00,  1.0755e+00,  2.9501e+00,  7.6420e+00],
        [ 3.7800e+00,  1.9199e+01,  1.3223e+00,  4.2259e+00, -1.6001e+00,
          7.9395e+00,  1.7065e+00, -4.9606e+00,  3.8810e+00,  3.1246e+00],
        [ 4.0234e+00,  2.0156e+00,  9.3090e-01,  2.9352e+00,  1.4458e+00,
          2.8364e+00,  1.1944e+00,  1.0778e+00, -1.3489e+00,  8.8513e-01],
        [ 2.9993e-01,  1.0106e+00,  2.8743e+00,  1.0596e+00,  2.1897e+00,
          1.1528e+00,  2.7986e+00,  2.9140e+00,  5.0939e+00,  2.6367e+00],
        [-2.7522e+00,  1.0011e+00,  9.9939e-01,  2.1638e+00,  9.9803e-01,
          9.9864e-01,  3.8559e+00,  4.0029e+00,  6.9205e+00,  4.3317e+00],
        [ 4.8687e+00, -1.5672e-01,  1.6703e+00,  7.5356e+00,  9.9502e-01,
          4.9990e+00,  3.3880e+00,  5.0098e+00,  2.9938e+00,  1.4576e+00],
        [ 3.0192e+00,  3.9819e+00, -5.9198e+00,  3.9832e+00,  2.9753e+00,
          2.7502e+00,  2.0750e+00,  9.5908e-01, -1.2959e+01,  5.9251e+00],
        [ 1.0148e+00,  9.7130e-01,  1.9768e+00,  2.0122e+00,  6.6718e-01,
          1.7201e+00,  1.7755e+00,  1.9466e+00,  4.0138e+00,  1.0400e+00],
        [ 5.3930e+00,  4.0298e+00,  2.0109e+00,  2.7548e+00,  2.9962e+00,
          3.5641e+00,  1.8677e+00,  1.0778e+00, -1.7539e+00,  2.0255e+00],
        [ 1.9953e+00,  1.2843e+01,  5.0056e+00,  2.5516e+00,  2.9983e+00,
          5.4422e+00,  5.0090e+00,  9.8768e-01,  8.9960e+00,  6.2394e+00],
        [ 2.8324e+00,  2.8434e+00,  1.9018e+00,  3.8000e+00,  3.1922e+00,
          3.3458e+00,  3.8883e+00,  3.7678e+00,  2.0126e+00,  4.2686e+00],
        [ 2.0652e+00,  8.7485e+00,  4.9003e+00, -7.1472e-02,  5.0456e+00,
          2.9448e+00,  4.2957e+00,  1.4602e+00,  5.0622e+00,  6.0046e+00],
        [ 1.2368e+00, -8.0124e-01,  4.5406e-01,  5.8132e-01,  5.0025e+00,
          1.8401e-01,  3.2316e+00,  3.9075e+00, -2.7376e+00,  4.9980e+00],
        [ 4.4707e+00,  1.3100e+01,  3.8913e+00,  5.0670e-01,  3.1494e+00,
          4.7096e+00,  2.2617e+00, -2.3287e+00,  1.8589e+00,  3.9039e+00],
        [ 9.9929e-01,  1.5253e+01,  2.1834e-03,  2.6820e+00,  2.0005e+00,
          5.3245e+00,  3.9995e+00, -1.3442e+00,  1.4326e+00,  7.5636e+00]],
       grad_fn=<MmBackward0>)
# Now use matrix factorization to predict the ratings

import torch
import torch.nn as nn
import torch.nn.functional as F

# Create a class for the model

class MatrixFactorization(nn.Module):
    def __init__(self, n_users, n_movies, n_factors=20):
        super().__init__()
        # One learnable latent vector per user and per movie
        self.user_factors = nn.Embedding(n_users, n_factors)
        self.movie_factors = nn.Embedding(n_movies, n_factors)

    def forward(self, user, movie):
        # Predicted rating = dot product of the user and movie latent vectors
        return (self.user_factors(user) * self.movie_factors(movie)).sum(1)      
model = MatrixFactorization(n_users, n_movies, 2)
model
MatrixFactorization(
  (user_factors): Embedding(100, 2)
  (movie_factors): Embedding(10, 2)
)
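This embedding formulation is the same factorization as before, just indexed by ids: the user embedding table plays the role of W and the movie embedding table (transposed) the role of H. A minimal sketch of how to read the learned matrices:

# The embedding weight matrices are the latent factor matrices (sketch)
W_emb = model.user_factors.weight        # shape (n_users, n_factors)
H_emb = model.movie_factors.weight.T     # shape (n_factors, n_movies)
full_pred = W_emb @ H_emb                # predicted rating for every (user, movie) pair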
model(torch.tensor([0]), torch.tensor([2]))
tensor([-0.0271], grad_fn=<SumBackward1>)
A[0, 2]
tensor(5.)
type(A)
torch.Tensor
mask = ~torch.isnan(A)

# Get the indices of the non-NaN values
i, j = torch.where(mask)

# Get the values of the non-NaN values
v = A[mask]

# Store in PyTorch tensors
users = i.to(torch.int64)
movies = j.to(torch.int64)
ratings = v.to(torch.float32)
pd.DataFrame({'user': users, 'movie': movies, 'rating': ratings})
user movie rating
0 0 1 5.0
1 0 2 5.0
2 0 4 1.0
3 0 5 1.0
4 0 6 4.0
... ... ... ...
512 98 8 2.0
513 98 9 4.0
514 99 0 1.0
515 99 4 2.0
516 99 6 4.0

517 rows × 3 columns

# Fit the Matrix Factorization model
model = MatrixFactorization(n_users, n_movies, 4)
optimizer = optim.Adam(model.parameters(), lr=0.01)

for i in range(1000):
    # Compute the loss
    pred = model(users, movies)
    loss = F.mse_loss(pred, ratings)
    
    # Zero the gradients
    optimizer.zero_grad()
    
    # Backpropagate
    loss.backward()
    
    # Update the parameters
    optimizer.step()
    
    # Print the loss
    if i % 100 == 0:
        print(loss.item())
14.604362487792969
4.332712650299072
1.0960761308670044
0.6966323852539062
0.5388827919960022
0.45243579149246216
0.4012693464756012
0.3728969395160675
0.35568001866340637
0.34289655089378357
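Once trained, the model can score any (user, movie) pair, including the ones that were NaN in A (for example entry (0, 0) above). A minimal sketch that imputes all the missing entries:

# Predict ratings for the (user, movie) pairs that were missing in A
missing_i, missing_j = torch.where(torch.isnan(A))
with torch.no_grad():
    missing_pred = model(missing_i, missing_j)
missing_pred[:5]  # first few imputed ratings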
model(users, movies)
tensor([3.5693, 4.5338, 2.6934, 1.8316, 4.8915, 2.0194, 2.7778, 1.8601, 2.1124,
        1.1378, 2.9079, 5.0470, 0.9911, 1.9791, 1.0050, 3.9618, 2.0085, 2.0034,
        4.0113, 1.9218, 2.9801, 1.0432, 3.8993, 4.1292, 1.9357, 3.8285, 3.5266,
        1.3640, 2.3989, 2.4166, 0.8559, 3.3685, 4.4493, 3.4018, 1.4722, 4.8378,
        4.6684, 4.4473, 4.3097, 2.0022, 5.0147, 5.0113, 1.9599, 2.8305, 1.4493,
        1.6750, 0.9520, 3.8460, 5.1279, 4.7453, 2.1484, 1.8009, 3.2104, 4.2068,
        4.5473, 2.9229, 1.1817, 3.1108, 3.2157, 1.2238, 3.6272, 3.9029, 3.2554,
        0.9945, 3.0062, 5.0030, 4.1144, 1.6314, 3.7945, 3.3091, 4.0727, 3.6212,
        2.4359, 3.5707, 1.2826, 0.9663, 4.9973, 3.0163, 2.9916, 1.0014, 3.1734,
        3.8712, 4.3364, 3.7119, 4.5313, 2.3875, 4.0274, 4.7121, 4.3851, 2.8072,
        3.2066, 3.9684, 0.9307, 1.4160, 2.7484, 2.8771, 1.2753, 2.5825, 4.9857,
        2.0109, 1.0080, 2.2618, 2.7936, 2.5859, 3.0972, 3.1443, 3.8655, 3.2023,
        2.0061, 1.9866, 4.0068, 2.7861, 4.2803, 4.5452, 3.8177, 2.9951, 5.6040,
        1.9013, 2.4498, 1.9155, 4.3239, 3.2993, 1.1275, 4.1460, 4.3619, 4.5913,
        1.1365, 3.0865, 6.0992, 3.9207, 1.8325, 3.8526, 2.0000, 2.0000, 1.0255,
        5.0484, 1.9858, 1.9716, 2.0854, 2.7726, 3.5053, 1.7365, 3.2607, 4.7006,
        1.8830, 0.3378, 2.6222, 0.6491, 3.1144, 3.8404, 2.0632, 4.3311, 4.6138,
        1.6119, 3.0444, 2.4498, 2.7860, 4.0832, 3.2681, 5.0069, 4.8745, 3.6048,
        3.5065, 1.8871, 2.9108, 0.9982, 0.6787, 1.4787, 5.6890, 1.5117, 4.0205,
        4.6624, 4.1051, 4.5437, 3.5410, 1.8761, 4.6653, 5.2577, 5.0311, 4.9531,
        4.0022, 4.7229, 4.0406, 2.5056, 3.2014, 2.9312, 1.3051, 1.3258, 4.9233,
        2.0871, 1.8761, 4.1304, 3.9279, 4.5547, 1.6639, 4.3163, 1.8149, 4.7431,
        3.4398, 3.5483, 3.2541, 2.8282, 1.5230, 2.8214, 4.9214, 3.5371, 4.4522,
        4.7928, 2.4932, 4.8619, 4.7693, 2.3121, 2.3021, 4.6609, 4.7478, 1.7768,
        5.3910, 2.8929, 2.4920, 4.1654, 1.8382, 3.4456, 3.7108, 2.3989, 1.0480,
        4.2870, 5.2730, 3.8976, 3.2956, 3.1010, 0.9666, 3.0227, 2.9491, 3.2218,
        1.4033, 1.7050, 0.5457, 3.1223, 2.1794, 3.9939, 5.0190, 2.9944, 4.5670,
        5.4131, 2.4720, 4.0239, 4.5689, 0.9386, 1.1827, 2.1553, 1.4232, 0.4127,
        1.7502, 1.5601, 2.0953, 2.3530, 2.7085, 2.7157, 2.7900, 4.6793, 1.1711,
        3.4493, 1.7701, 2.1576, 0.8294, 1.9908, 2.1975, 3.6396, 3.1952, 4.9882,
        4.0091, 2.9998, 1.9985, 3.0063, 3.7149, 4.2018, 1.1992, 2.2550, 3.5132,
        1.3794, 2.4434, 2.0220, 0.6691, 1.9927, 3.4666, 2.8342, 1.0393, 4.1834,
        4.0819, 0.8652, 4.8585, 2.0225, 2.9662, 3.9794, 1.0412, 3.0180, 1.0117,
        3.1728, 2.6027, 2.5184, 2.5120, 3.3483, 3.0602, 3.2139, 1.9144, 1.4899,
        2.8253, 2.9160, 1.6410, 3.8484, 2.9637, 2.1272, 2.0490, 4.4915, 1.6853,
        5.0641, 3.9315, 1.8812, 2.1028, 2.1203, 2.8898, 4.0215, 3.3032, 3.6131,
        4.8380, 0.7513, 3.4182, 3.0527, 3.6200, 1.3040, 1.1657, 4.2705, 4.8569,
        2.7037, 0.9142, 1.5261, 3.3970, 0.6980, 3.1131, 3.8348, 3.4478, 4.9488,
        4.0077, 4.9673, 2.0816, 0.9995, 3.4965, 3.2899, 1.0624, 1.4977, 4.3810,
        4.2225, 3.1540, 4.0191, 3.4060, 2.0160, 1.2814, 2.9608, 1.1513, 1.9530,
        0.9872, 3.9914, 4.9865, 3.0477, 1.9701, 2.5473, 4.6173, 0.5252, 3.4813,
        4.7750, 3.5971, 4.1422, 5.2524, 2.0691, 2.0039, 2.1420, 5.0156, 1.8481,
        0.7781, 2.2183, 2.6443, 2.2889, 4.3373, 2.2449, 4.5594, 3.8624, 5.2333,
        2.1537, 4.7373, 5.0298, 2.1322, 2.7047, 4.2308, 2.4907, 2.3259, 4.7678,
        5.0266, 5.2597, 4.7697, 0.9911, 1.0440, 2.7818, 1.1585, 2.1967, 4.4422,
        3.4151, 4.8498, 2.8759, 4.0001, 1.8111, 2.3611, 1.9162, 4.3893, 2.4916,
        2.0815, 3.9944, 5.5013, 4.6856, 4.9904, 4.9865, 4.0133, 1.7178, 2.2895,
        1.9270, 2.6042, 3.5529, 2.0969, 4.9558, 2.9138, 3.0347, 1.6650, 2.2754,
        3.7474, 0.9570, 2.3175, 4.1652, 4.4857, 5.1715, 5.4181, 3.7923, 4.0080,
        3.5960, 2.1456, 2.9088, 3.2796, 1.2589, 4.6664, 2.3447, 3.6577, 2.8478,
        2.8383, 2.9852, 0.9179, 3.0814, 1.2291, 0.7621, 0.8185, 2.3877, 1.1364,
        2.8636, 3.2811, 4.5790, 0.7829, 1.2192, 0.8060, 1.2341, 3.9681, 1.1422,
        4.9521, 5.1383, 2.8294, 3.6377, 3.8734, 3.5638, 2.9749, 1.5051, 1.4258,
        0.5846, 1.0630, 2.7671, 1.7912, 2.8734, 1.8916, 3.9858, 2.5439, 2.6786,
        1.6704, 1.3576, 1.7701, 2.0514, 5.1632, 2.6842, 4.8280, 1.1577, 2.9727,
        3.3256, 2.7693, 2.6430, 4.0140, 2.2755, 2.1973, 4.6792, 5.3522, 2.8058,
        4.9664, 5.2983, 4.7749, 3.8326, 2.4944, 4.2989, 2.8784, 2.0605, 2.8428,
        3.1081, 1.0411, 1.9078, 3.9594], grad_fn=<SumBackward1>)
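Note that a plain dot-product model is not constrained to the 1-5 rating scale (values such as 6.10 and 0.34 appear above). If needed, raw predictions can be clipped before reporting them, e.g.:

# Clip raw predictions to the valid rating range
with torch.no_grad():
    clipped = model(users, movies).clamp(1, 5)
clipped[:10]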
# Now apply matrix factorization to the real ratings dataframe df and predict its missing entries

A = torch.from_numpy(df.values)
A.shape
torch.Size([45, 10])
mask = ~torch.isnan(A)

# Get the indices of the non-NaN values
i, j = torch.where(mask)

# Get the values of the non-NaN values
v = A[mask]

# Store in PyTorch tensors
users = i.to(torch.int64)
movies = j.to(torch.int64)
ratings = v.to(torch.float32)
pd.DataFrame({'user': users, 'movie': movies, 'rating': ratings})
user movie rating
0 0 0 4.0
1 0 1 5.0
2 0 2 4.0
3 0 3 4.0
4 0 4 5.0
... ... ... ...
371 44 3 4.0
372 44 4 4.0
373 44 5 4.0
374 44 6 4.0
375 44 7 5.0

376 rows × 3 columns

# Fit the Matrix Factorization model
n_users = A.shape[0]
n_movies = A.shape[1]
model = MatrixFactorization(n_users, n_movies, 4)
optimizer = optim.Adam(model.parameters(), lr=0.01)

for i in range(1000):
    # Compute the loss
    pred = model(users, movies)
    loss = F.mse_loss(pred, ratings)
    
    # Zero the gradients
    optimizer.zero_grad()
    
    # Backpropagate
    loss.backward()
    
    # Update the parameters
    optimizer.step()
    
    # Print the loss
    if i % 100 == 0:
        print(loss.item())
19.889324188232422
3.1148574352264404
0.6727441549301147
0.5543633103370667
0.5081750750541687
0.4629250764846802
0.4147825837135315
0.36878159642219543
0.32987719774246216
0.29975879192352295
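One optional extension (not part of the original notebook): hold out a small fraction of the observed ratings before fitting, to get a rough estimate of how well the learned factors generalise rather than just how well they fit. A minimal sketch, reusing the tensors and class defined above:

# Hold out ~10% of the observed ratings as a validation set
n_obs = len(ratings)
perm = torch.randperm(n_obs)
n_val = n_obs // 10
val_idx, train_idx = perm[:n_val], perm[n_val:]

val_model = MatrixFactorization(n_users, n_movies, 4)
val_opt = optim.Adam(val_model.parameters(), lr=0.01)
for _ in range(1000):
    loss = F.mse_loss(val_model(users[train_idx], movies[train_idx]), ratings[train_idx])
    val_opt.zero_grad()
    loss.backward()
    val_opt.step()

# RMSE on the held-out ratings
with torch.no_grad():
    val_rmse = torch.sqrt(F.mse_loss(val_model(users[val_idx], movies[val_idx]), ratings[val_idx]))
print(val_rmse.item())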
# First, predict the rating for a (user, movie) pair whose rating is already present in df, to check the fit

username = 'Dhruv'
movie = 'The Dark Knight'

# Get the user and movie indices
user_idx = df.index.get_loc(username)
movie_idx = df.columns.get_loc(movie)

# Predict the rating
pred = model(torch.tensor([user_idx]), torch.tensor([movie_idx]))
pred.item(), df.loc[username, movie]
(5.259384632110596, 5.0)
df.loc[username]
Sholay                      NaN
Swades (We The People)      NaN
The Matrix (I)              5.0
Interstellar                5.0
Dangal                      3.0
Taare Zameen Par            NaN
Shawshank Redemption        5.0
The Dark Knight             5.0
Notting Hill                4.0
Uri: The Surgical Strike    5.0
Name: Dhruv, dtype: float64
# Now predict the rating for a (user, movie) pair whose rating is missing in df

username = 'Dhruv'
movie = 'Sholay'

# Get the user and movie indices
user_idx = df.index.get_loc(username)
movie_idx = df.columns.get_loc(movie)

# Predict the rating
pred = model(torch.tensor([user_idx]), torch.tensor([movie_idx]))
pred, df.loc[username, movie]
(tensor([3.7885], grad_fn=<SumBackward1>), nan)
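To turn these point predictions into actual recommendations, one can rank a user's unrated movies by their predicted scores. A minimal sketch, assuming the fitted `model` and the `df` above (the helper `recommend` is hypothetical, not part of the original notebook):

# Rank the movies a given user has not yet rated, by predicted rating
def recommend(username, k=3):
    user_idx = df.index.get_loc(username)
    unseen = [m for m in df.columns if pd.isna(df.loc[username, m])]
    movie_idx = torch.tensor([df.columns.get_loc(m) for m in unseen])
    user_t = torch.full_like(movie_idx, user_idx)
    with torch.no_grad():
        scores = model(user_t, movie_idx)
    order = scores.argsort(descending=True)[:k]
    return [(unseen[int(i)], scores[int(i)].item()) for i in order]

recommend('Dhruv')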
# Complete the matrix
with torch.no_grad():
    completed_matrix = pd.DataFrame(model.user_factors.weight @ model.movie_factors.weight.t(), index=df.index, columns=df.columns)
    # round to nearest integer
    completed_matrix = completed_matrix.round()
completed_matrix.head()
Sholay Swades (We The People) The Matrix (I) Interstellar Dangal Taare Zameen Par Shawshank Redemption The Dark Knight Notting Hill Uri: The Surgical Strike
Your name
Nipun 5.0 4.0 4.0 5.0 4.0 5.0 5.0 4.0 4.0 5.0
Gautam Vashishtha 3.0 3.0 4.0 4.0 2.0 3.0 4.0 5.0 4.0 3.0
Eshan Gujarathi 4.0 4.0 5.0 5.0 4.0 4.0 5.0 5.0 4.0 4.0
Sai Krishna Avula 4.0 4.0 3.0 4.0 4.0 6.0 4.0 3.0 3.0 4.0
Ankit Yadav 3.0 2.0 3.0 4.0 3.0 5.0 4.0 3.0 3.0 4.0
df.head()
Sholay Swades (We The People) The Matrix (I) Interstellar Dangal Taare Zameen Par Shawshank Redemption The Dark Knight Notting Hill Uri: The Surgical Strike
Your name
Nipun 4.0 5.0 4.0 4.0 5.0 5.0 4.0 5.0 4.0 5.0
Gautam Vashishtha 3.0 4.0 4.0 5.0 3.0 1.0 5.0 5.0 4.0 3.0
Eshan Gujarathi 4.0 NaN 5.0 5.0 4.0 5.0 5.0 5.0 NaN 4.0
Sai Krishna Avula 5.0 3.0 3.0 4.0 4.0 5.0 5.0 3.0 3.0 4.0
Ankit Yadav 3.0 3.0 2.0 5.0 2.0 5.0 5.0 3.0 3.0 4.0
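How close is the completed matrix to the ratings we actually observed? A quick check using pandas alignment (unobserved entries are NaN in df and simply drop out of the mean):

# Mean absolute difference between the rounded reconstruction and the observed ratings
(completed_matrix - df).abs().mean().mean()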