import torch
import torchvision
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from latexify import latexify
import seaborn as sns
%matplotlib inline
# config retina
%config InlineBackend.figure_format = 'retina'
Traditional Programming vs Machine Learning
ML
Tutorial
# Set device (CPU or GPU): prefer CUDA when torch reports it as available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
# Load MNIST dataset (downloaded into ../datasets on first use); images are
# converted with ToTensor, the train/test split is selected by `train=`
mnist_train = torchvision.datasets.MNIST(
    '../datasets',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
mnist_test = torchvision.datasets.MNIST(
    '../datasets',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Function to show a digit with a 28x28 pixel grid drawn over it
def show_digit_with_arrows(digit, label=None):
    """Plot a digit in grayscale with white lines separating its 28x28 pixels.

    Despite the name, no arrows are drawn: the figure contains only the
    image, the pixel grid and (optionally) a title.

    Args:
        digit: tensor holding 784 values, i.e. anything for which
            ``digit.numpy().reshape(28, 28)`` succeeds (callers pass one
            dataset image).
        label: optional label; when not None it is shown as the title
            ``Label: <label>``.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots()``.
    """
    fig, ax = plt.subplots()
    # Keep the parameter untouched; work on a 2-D view of the pixel values.
    pixels = digit.numpy().reshape(28, 28)

    # Display the digit
    ax.imshow(pixels, cmap='gray')

    # Add gridlines corresponding to 28 rows and columns.
    # imshow centres pixel k on coordinate k, so the boundary between pixel
    # k-1 and pixel k lies at k - 0.5; drawing at k would cut through pixels.
    for i in range(1, 28):
        boundary = i - 0.5
        ax.axhline(boundary, color='white', linewidth=0.5)
        ax.axvline(boundary, color='white', linewidth=0.5)

    # Display label if available
    if label is not None:
        ax.set_title(f'Label: {label}')
    return fig, ax
index = 2
# Show the training digit at `index` with the pixel grid overlaid
# (mnist_train[index] is an (image, label) pair, unpacked into the call)
fig, ax = show_digit_with_arrows(*mnist_train[index])
# save figure
fig.savefig("../figures/mnist.pdf", bbox_inches='tight')
# Find indices of digit 4 in the training and test sets.
# `.targets` is already a tensor, so compare it directly: wrapping it in
# torch.tensor(...) copies it and raises a copy-construct UserWarning.
digit_4_indices_train = torch.where(mnist_train.targets == 4)[0]
digit_4_indices_test = torch.where(mnist_test.targets == 4)[0]

print(f"Indices of digit 4 in Train dataset: {digit_4_indices_train}")
print(f"Number of digit 4 images in training set: {len(digit_4_indices_train)}\n")
Indices of digit 4 in Train dataset: tensor([ 2, 9, 20, ..., 59943, 59951, 59975])
Number of digit 4 images in training set: 5842
/tmp/ipykernel_1361527/214778730.py:2: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
digit_4_indices_train = torch.where(torch.tensor(mnist_train.targets) == 4)[0]
/tmp/ipykernel_1361527/214778730.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
digit_4_indices_test = torch.where(torch.tensor(mnist_test.targets) == 4)[0]
latexify(fig_width=7, fig_height=5)

# Show the first 15 training images of the digit 4 in a 3x5 grid, each titled
# with its index in the training set. Slicing (rather than indexing 0..14)
# also copes with fewer than 15 available indices.
for position, train_idx in enumerate(digit_4_indices_train[:15], start=1):
    plt.subplot(3, 5, position)
    plt.imshow(mnist_train.data[train_idx], cmap='gray')
    plt.title(f"idx: {train_idx}")
    plt.axis('off')
# Select a sample from the training set; this image is the reference from
# which the rule-based classifier's regions are derived below
sample_idx_1 = 60
image, label = mnist_train[sample_idx_1]

# squeeze() drops the singleton dimension so imshow receives a 2-D array
plt.imshow(image.squeeze().numpy(), cmap='gray')
plt.title(f"Label: {label}")
plt.show()
# Function to extract edges based on intensity threshold
def extract_edges(image, threshold=0.1):
    """Binarize an image: 1.0 where intensity exceeds `threshold`, else 0.

    No edge detection is performed; the result is a plain threshold mask of
    the bright (stroke) pixels.

    Args:
        image: torch.Tensor of any shape; callers in this file pass a
            single-channel image of shape (1, 28, 28).
        threshold: (float) intensities strictly greater than this value are
            treated as white pixels.

    Returns:
        torch.Tensor with the same shape and dtype as `image`, holding 1 for
        white pixels and 0 everywhere else.
    """
    edges = torch.zeros_like(image)

    # converting all the pixels with intensity greater than threshold to white
    edges[image > threshold] = 1.0
    return edges
# Creating rules based upon one image (the sample selected above)
edges = extract_edges(image)

plt.imshow(edges[0, :, :], cmap='gray')

# finding areas of edges: each region is edges[:, row_start:row_stop, col_start:col_stop]
left_edge_train = edges[:, 4:15, 3:12]
upper_right_edge_train = edges[:, 4:19, 17:24]
middle_edge_train = edges[:, 14:20, 5:25]
lower_right_edge_train = edges[:, 17:24, 18:24]

# Outline each region on the image. Rectangle takes (x, y) = (col_start, row_start),
# then width = number of columns and height = number of rows, e.g.
# R1 (rows 4-15, cols 3-12) -> xy=(3, 4), width 9, height 11.
# NOTE(review): imshow centres pixels on integer coordinates, so these
# outlines sit half a pixel right/below the exact slice boundaries — confirm
# whether a -0.5 shift is wanted.
r1 = plt.Rectangle((3, 4), 9, 11, linewidth=1, edgecolor='r', facecolor='none')
r2 = plt.Rectangle((17, 4), 7, 15, linewidth=1, edgecolor='g', facecolor='none')
r3 = plt.Rectangle((5, 14), 20, 6, linewidth=1, edgecolor='b', facecolor='none')
r4 = plt.Rectangle((18, 17), 6, 7, linewidth=1, edgecolor='y', facecolor='none')
for rect in [r1, r2, r3, r4]:
    plt.gca().add_patch(rect)
# create a subplot 2 rows by 2 columns
fig, axs = plt.subplots(2, 2, figsize=(8, 10))

# One panel per reference region, filled row by row (axs.flat order):
# left, upper right, middle, lower right.
region_panels = [
    ("Left Edge", left_edge_train),
    ("Upper Right Edge", upper_right_edge_train),
    ("Middle Edge", middle_edge_train),
    ("Lower Right Edge", lower_right_edge_train),
]

for panel, (region_name, region) in zip(axs.flat, region_panels):
    # plotting the image of the region
    panel.imshow(region.squeeze().numpy(), cmap='gray')
    # title reports how many pixels of the region are white
    panel.set_title(f"{region_name}\nWhite pixels: {int(region.sum())}/{region.numel()}")

plt.show()
# Rule-based digit classifier for digit 4
def rule_based_classifier(image, sub=10):
    """Classify an image as the digit 4 (returns 4) or not (returns -1).

    The image is thresholded with `extract_edges`, and the white pixels are
    counted in four fixed regions. Each count is compared with the count of
    the same region in the reference image (the module-level `*_edge_train`
    tensors). The image is called a 4 only if every region has more than
    `reference_count - sub` white pixels.

    Args:
        image: tensor indexed as [channel, row, col]; callers pass one
            dataset image.
        sub: slack, in pixels, by which a region's white-pixel count may fall
            short of the reference count. Defaults to 10.

    Returns:
        4 if all four regions pass the check, otherwise -1 (not a 4).
    """
    # Extract edges
    edges = extract_edges(image)

    # Define rules for digit 4 based on the edges of the digit: pair every
    # region of this image with the matching region of the reference image.
    region_pairs = (
        (edges[:, 4:15, 3:12], left_edge_train),
        (edges[:, 4:19, 17:24], upper_right_edge_train),
        (edges[:, 14:20, 5:25], middle_edge_train),
        (edges[:, 17:24, 18:24], lower_right_edge_train),
    )

    # Check if all required edges are present by checking the number of white
    # pixels for each edge against the reference count minus the slack `sub`.
    if all(torch.sum(region) > reference.sum() - sub for region, reference in region_pairs):
        return 4
    return -1  # -1 indicates that the digit is not 4
# Display some wrongly classified images
indices = [6, 19, 25, 200]
# define image size
plt.figure(figsize=(14, 3))

# Dedicated loop names keep the reference sample in `image`/`label` intact,
# so re-running the rule-building cell still uses the intended image.
for position, test_idx in enumerate(indices, start=1):
    plt.subplot(1, 4, position)
    test_image, test_label = mnist_test[test_idx]
    pred = rule_based_classifier(test_image)
    pred = pred if pred != -1 else "Not 4"
    plt.title(f"Label: {test_label}, Predicted: {pred}")
    plt.imshow(test_image.squeeze().numpy(), cmap='gray')
# Evaluating the rule-based classifier
count = 0    # test images classified correctly (4 as 4, non-4 as not 4)
count_4 = 0  # test images of digit 4 that were classified as 4
for test_image, test_label in mnist_test:
    # the classifier returns either 4 or -1, so "== 4" fully describes it
    predicted_four = rule_based_classifier(test_image) == 4
    actual_four = test_label == 4
    if predicted_four == actual_four:
        count += 1
    if predicted_four and actual_four:
        count_4 += 1

accuracy_rule = count * 100 / len(mnist_test)
# share of the actual 4s that were recognised as 4
percentage_TP_rule = count_4 * 100 / len(digit_4_indices_test)
print(f"Accuracy of the rule-based classifier: {accuracy_rule} %")
print(f"Percentage of 4s actually classified as 4 (percentage of True Positives): {percentage_TP_rule:.3} %")
Accuracy of the rule-based classifier: 88.56 %
Percentage of 4s actually classified as 4 (percentage of True Positives): 4.28 %
Note: With these rules, the classifier predicts "not 4" for almost every image. Because non-4 digits far outnumber instances of the digit 4, the overall accuracy is still high. However, this is not a good model: it fails to recognise most of the actual 4s, as the low true-positive percentage above shows.
ML based approach
# Flatten the images (one row per image) and convert the labels to 4 and -1
# for the binary classification problem "is this digit a 4?"
X_train = mnist_train.data.numpy().reshape((len(mnist_train), -1))
y_train = np.where(mnist_train.targets.numpy() == 4, 4, -1)

X_test = mnist_test.data.numpy().reshape((len(mnist_test), -1))
y_test = np.where(mnist_test.targets.numpy() == 4, 4, -1)
# Create and train the MLP model: one hidden layer of 100 units, fixed seed.
# max_iter=20 caps training; the recorded run stops at this cap with a
# ConvergenceWarning.
mlp_model = MLPClassifier(hidden_layer_sizes=(100,), max_iter=20, random_state=42)
mlp_model.fit(X_train, y_train)
/home/nipun.batra/miniforge3/lib/python3.9/site-packages/sklearn/neural_network/_multilayer_perceptron.py:691: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (20) reached and the optimization hasn't converged yet.
warnings.warn(
MLPClassifier(max_iter=20, random_state=42)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
MLPClassifier(max_iter=20, random_state=42)
# Evaluate the model
y_pred = mlp_model.predict(X_test)

# overall accuracy, expressed as a percentage
accuracy_ML = accuracy_score(y_test, y_pred) * 100
# share of the actual 4s that were predicted as 4
percentage_TP_ML = np.sum((y_test == 4) & (y_pred == 4)) * 100 / len(digit_4_indices_test)
print(f'Test Accuracy: {accuracy_ML:.2f}%')
print(f"Percentage of 4s actually classified as 4 (percentage of True Positives): {percentage_TP_ML:.3} %")
Test Accuracy: 99.47%
Percentage of 4s actually classified as 4 (percentage of True Positives): 97.1 %
Comparison of Rule-based system and ML based system
# Categories for the bar plot
categories = ['Accuracy', 'True Positive Percentage']

# Values for the rule-based classifier
rule_based_values = [accuracy_rule, percentage_TP_rule]

# Values for the MLP classifier
mlp_values = [accuracy_ML, percentage_TP_ML]

# Bar width
bar_width = 0.35

# X-axis positions for the bars (named so it does not overwrite the
# sample `index` defined earlier in the notebook)
x_positions = range(len(categories))

# Plotting the bar plot: rule-based bars at x, MLP bars shifted by one bar width
fig, ax = plt.subplots(figsize=(9, 5))
bar1 = ax.bar(x_positions, rule_based_values, bar_width, label='Rule-Based Classifier')
bar2 = ax.bar([x + bar_width for x in x_positions], mlp_values, bar_width, label='MLP Classifier')

# Adding labels, title, and legend
ax.set_xlabel('Metrics')
ax.set_ylabel('Percentage / Accuracy')
ax.set_title('Comparison of Classifiers')
# centre each tick between the two bars of its category
ax.set_xticks([x + bar_width / 2 for x in x_positions])
ax.set_xticklabels(categories)
ax.legend()

# Display the values on top of the bars
for bar in [*bar1, *bar2]:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2, yval, round(yval, 2), ha='center', va='bottom')

# Show the plot
plt.show()