import numpy as np
import time
import matplotlib.pyplot as plt
import pandas as pd
# Retina display
%config InlineBackend.figure_format = 'retina'
# Basic Imports
# Matrix side length. A power of two halves evenly at every level of the
# Strassen recursion, so the quadrants stay square and equally sized.
log_size = 8
size = 1 << log_size
def create_data(size=2**10, random_seed=0):
    """Return a pair of reproducible random ``size x size`` matrices.

    The global NumPy random generator is reseeded with ``random_seed`` (a
    side effect visible to other code using ``np.random``). Then the first
    matrix and then the second are drawn with ``np.random.rand``, whose
    entries are uniform on [0, 1).

    Returns:
        A tuple ``(A, B)`` of float arrays, each of shape ``(size, size)``.
    """
    np.random.seed(random_seed)
    # The generator is consumed in order: the first draw becomes A, the second B.
    return tuple(np.random.rand(size, size) for _ in range(2))
# Build the two operands shared by every benchmark below (seed 0 is the default).
A, B = create_data(size, random_seed=0)
# Naive implementation
def naive_multiply(A, B):
    """Multiply two 2-D arrays with the textbook triple loop.

    Computes ``C[i, j] = sum_k A[i, k] * B[k, j]`` one scalar at a time. It is
    intentionally slow: it is the baseline that the Strassen and NumPy
    timings are compared against, and the base case of ``strassen_multiply``.

    Args:
        A: 2-D array of shape ``(m, p)``.
        B: 2-D array of shape ``(p, n)``.

    Returns:
        Array of shape ``(m, n)`` whose dtype is ``np.result_type(A, B)``.

    Raises:
        ValueError: if the inner dimensions of ``A`` and ``B`` differ, or if
            either operand is not 2-D.
    """
    rows, inner = A.shape
    inner_b, cols = B.shape
    if inner != inner_b:
        raise ValueError(
            f"shapes {A.shape} and {B.shape} not aligned: {inner} != {inner_b}"
        )
    # Size the output from both operands (m x n) and promote the dtype, so that
    # rectangular inputs work and int @ float products are not truncated.
    C = np.zeros((rows, cols), dtype=np.result_type(A, B))
    for i in range(rows):
        for j in range(cols):
            for k in range(inner):
                C[i, j] += A[i, k] * B[k, j]
    return C
# Context manager that times the body of a ``with`` statement; the duration is
# recorded even when the body raises, and the exception still propagates.
class Timer:
    """Measure how long the body of a ``with`` statement takes, in seconds.

    After the block exits, ``start_time``, ``end_time`` and ``elapsed`` are set.
    ``elapsed`` is the duration in seconds. ``__exit__`` always records the
    timing and returns ``False``, so an exception raised inside the block is
    re-raised to the caller rather than swallowed.

    ``time.perf_counter`` is used instead of ``time.time``. It is a monotonic,
    high-resolution clock meant for measuring intervals, whereas the
    wall clock can jump when the system time is adjusted. Its absolute
    values are arbitrary, so only ``elapsed`` is meaningful.
    """

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        self.end_time = time.perf_counter()
        self.elapsed = self.end_time - self.start_time
        # Do not suppress exceptions raised by the timed code.
        return False
# Run a multiplication routine repeatedly and record how long each run takes.
def run_and_time_multiplication(function, n_times, *args, **kwargs):
    """Call ``function(*args, **kwargs)`` ``n_times`` times, timing each call.

    Args:
        function: the callable to benchmark.
        n_times: number of repetitions; must be at least 1.
        *args, **kwargs: forwarded unchanged to ``function`` on every call.

    Returns:
        A tuple ``(elapsed_times, result)``. ``elapsed_times`` is a NumPy array
        holding one duration (seconds, measured with ``Timer``) per call.
        ``result`` is the return value of the last call.

    Raises:
        ValueError: if ``n_times`` is less than 1. Without this check there is
            no call whose result could be returned, and the function would
            fail with ``UnboundLocalError``.
    """
    if n_times < 1:
        raise ValueError(f"n_times must be at least 1, got {n_times}")
    elapsed_times = []
    for _ in range(n_times):
        with Timer() as timer:
            result = function(*args, **kwargs)
        # timer.elapsed only exists once the with-block has exited.
        elapsed_times.append(timer.elapsed)
    return np.array(elapsed_times), result
# Number of repetitions for every timing experiment below.
n_times = 10

# Time the naive triple-loop multiplication; its product is kept in C_naive.
naive_times, C_naive = run_and_time_multiplication(
    naive_multiply,
    n_times,
    A,
    B,
)
print(f"Naive Multiplication Times: {naive_times.mean():0.3f} +/- {naive_times.std():0.3f}")
# Output recorded when the notebook was last run:
# Naive Multiplication Times: 8.541 +/- 0.016
def divide_matrix_four_parts(A):
    """Split a 2-D array into its four quadrants.

    Rows are split at ``A.shape[0] // 2`` and columns at ``A.shape[1] // 2``,
    independently. Therefore ``np.block([[A11, A12], [A21, A22]])`` rebuilds
    ``A`` for any 2-D shape, not only square ones. For an odd size, the bottom
    and right quadrants get the extra row or column.

    Returns:
        ``(A11, A12, A21, A22)``: the top-left, top-right, bottom-left and
        bottom-right quadrants. For a NumPy array these are basic slices,
        i.e. views into ``A`` rather than copies.
    """
    row_mid = A.shape[0] // 2
    col_mid = A.shape[1] // 2
    A11 = A[:row_mid, :col_mid]
    A12 = A[:row_mid, col_mid:]
    A21 = A[row_mid:, :col_mid]
    A22 = A[row_mid:, col_mid:]
    return A11, A12, A21, A22
# Inspect a single split of the benchmark matrix (notebook cell outputs shown
# as comments).
A11, A12, A21, A22 = divide_matrix_four_parts(A)
A.shape  # -> (256, 256)
A11.shape  # -> (128, 128)
def strassen_multiply(A, B, threshold=32):
    """Multiply two square matrices with Strassen's algorithm.

    Each operand is split into four quadrants. The product is then assembled
    from seven recursive block products (M1..M7) instead of the eight that a
    plain block multiplication needs. Once a block has at most ``threshold``
    rows, the recursion stops and ``naive_multiply`` computes it directly.

    Args:
        A: square 2-D array of shape ``(n, n)``.
        B: square 2-D array with the same shape as ``A``.
        threshold: block size at or below which ``naive_multiply`` is used.
            It is forwarded to every recursive call, so it sets the base-case
            size for the whole recursion. Values below 1 act like 1.

    Returns:
        The ``(n, n)`` product ``A @ B``.

    Raises:
        ValueError: if ``A`` and ``B`` are not square 2-D arrays of the same
            shape.
    """
    if A.ndim != 2 or A.shape[0] != A.shape[1] or A.shape != B.shape:
        raise ValueError(
            "expected two square matrices of the same shape, "
            f"got {A.shape} and {B.shape}"
        )
    n = A.shape[0]

    # Base case: multiply small blocks directly. The floor of 1 guarantees
    # termination; otherwise a 1x1 block would be padded to 2x2 and split
    # back into 1x1 blocks forever.
    if n <= max(threshold, 1):
        return naive_multiply(A, B)

    if n % 2:
        # Strassen needs four equally sized square quadrants. Pad odd sizes
        # with one zero row and column, which contribute nothing to the
        # product, and crop the padding off the result.
        padding = ((0, 1), (0, 1))
        C = strassen_multiply(np.pad(A, padding), np.pad(B, padding), threshold=threshold)
        return C[:n, :n]

    A11, A12, A21, A22 = divide_matrix_four_parts(A)
    B11, B12, B21, B22 = divide_matrix_four_parts(B)

    # The seven Strassen products. threshold must be passed down explicitly;
    # otherwise every level below the top silently falls back to the default.
    M1 = strassen_multiply(A11 + A22, B11 + B22, threshold=threshold)
    M2 = strassen_multiply(A21 + A22, B11, threshold=threshold)
    M3 = strassen_multiply(A11, B12 - B22, threshold=threshold)
    M4 = strassen_multiply(A22, B21 - B11, threshold=threshold)
    M5 = strassen_multiply(A11 + A12, B22, threshold=threshold)
    M6 = strassen_multiply(A21 - A11, B11 + B12, threshold=threshold)
    M7 = strassen_multiply(A12 - A22, B21 + B22, threshold=threshold)

    # Recombine the products into the four quadrants of C = A @ B.
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6
    return np.vstack((np.hstack((C11, C12)), np.hstack((C21, C22))))
# Time Strassen multiplication for each base-case threshold (8, 16, 32),
# keeping both the per-run durations and the resulting product per threshold.
thresholds = [8, 16, 32]
strassen_times, strassen_results = {}, {}
for threshold in thresholds:
    (
        strassen_times[threshold],
        strassen_results[threshold],
    ) = run_and_time_multiplication(
        strassen_multiply, n_times, A, B, threshold=threshold
    )
    print(f"Strassen Multiplication Times (threshold={threshold}): {strassen_times[threshold].mean():0.3f} +/- {strassen_times[threshold].std():0.3f}")
# Output recorded when the notebook was last run:
# Strassen Multiplication Times (threshold=8): 5.716 +/- 0.082
# Strassen Multiplication Times (threshold=16): 5.767 +/- 0.016
# Strassen Multiplication Times (threshold=32): 5.774 +/- 0.005
# Bar plot of the mean run time of each method, with the standard deviation
# across repetitions shown as error bars.
plt.figure(figsize=(10, 5))
# One row per method: the naive baseline first, then each Strassen threshold.
df = {"naive": {"mean": naive_times.mean(), "std": naive_times.std()}}
for threshold in thresholds:
    df[f"strassen \n(threshold={threshold})"] = {
        "mean": strassen_times[threshold].mean(),
        "std": strassen_times[threshold].std(),
    }
df = pd.DataFrame(df).T
df["mean"].plot(kind="bar", yerr=df["std"], capsize=5, rot=0)
plt.ylabel("Time (s)")
# Cell output: Text(0, 0.5, 'Time (s)')
# Directly multiply A and B using NumPy. This is both the speed baseline and
# the reference result that the hand-written implementations are checked against.
numpy_times, C_numpy = run_and_time_multiplication(np.matmul, n_times, A, B)
print(f"NumPy Multiplication Times: {numpy_times.mean():0.3f} +/- {numpy_times.std():0.3f}")
# Output recorded when the notebook was last run:
# NumPy Multiplication Times: 0.001 +/- 0.000

# A fast but wrong implementation is worthless, so verify the timed results.
# Use allclose rather than ==, because Strassen's extra additions and
# subtractions round differently from a direct product.
if not np.allclose(C_naive, C_numpy):
    raise AssertionError("naive_multiply result differs from np.matmul")
bad_thresholds = [t for t, C in strassen_results.items() if not np.allclose(C, C_numpy)]
if bad_thresholds:
    raise AssertionError(
        f"strassen_multiply result differs from np.matmul for thresholds {bad_thresholds}"
    )