import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
# Retina mode
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
CDF for discrete random variables
# Set random seed
torch.manual_seed(42)
<torch._C.Generator at 0x135180ab0>
probs = [0.1, 0.4, 0.2, 0.3]
dist = torch.distributions.Categorical(probs=torch.tensor(probs))
dist
Categorical(probs: torch.Size([4]))
pd.Series(dist.sample(torch.Size([1000]))).value_counts().sort_index()
0 90
1 409
2 187
3 314
Name: count, dtype: int64
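The counts above match the PMF [0.1, 0.4, 0.2, 0.3] up to sampling noise. The CDF of this discrete distribution is just the running sum of the PMF, and inverting it gives another way to sample; a minimal sketch, reusing probs from above:

# CDF of a discrete random variable: cumulative sum of the PMF
pmf = torch.tensor(probs)
cdf = torch.cumsum(pmf, dim=0)  # tensor([0.1000, 0.5000, 0.7000, 1.0000])
# Inverse-CDF sampling: draw U ~ Uniform(0, 1) and take the first index
# whose CDF value is >= U; matches Categorical sampling in distribution
u = torch.rand(1000)
samples = torch.searchsorted(cdf, u)
pd.Series(samples.numpy()).value_counts().sort_index()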
# Generate 2D input features
n_samples = 1000
X1 = torch.distributions.Uniform(0, 5).sample((n_samples, 1))  # Feature 1
X2 = torch.distributions.Uniform(0, 5).sample((n_samples, 1))  # Feature 2
# True weights and bias
w1, w2, b = 1.2, -0.8, -2.5
# Compute logits and apply sigmoid
logits = w1 * X1 + w2 * X2 + b
prob_Y = torch.sigmoid(logits)  # Probabilities
prob_Y
tensor([[0.4690],
        [0.8697],
        [0.2377],
        ...,
        [0.0870],
        [0.4037],
        [0.0158]])
# Sample class labels
Y = torch.distributions.Bernoulli(prob_Y).sample()
# Convert to NumPy for visualization
X1_np, X2_np, Y_np = X1.numpy(), X2.numpy(), Y.numpy()
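As a quick sanity check (a sketch, not part of the original notebook), the fraction of class-1 labels should be close to the mean of the underlying probabilities, since E[Y] = E[prob_Y]:

# Label frequency should roughly match the average probability
print(f"mean(prob_Y) = {prob_Y.mean().item():.4f}, mean(Y) = {Y.mean().item():.4f}")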
# Plot data points
plt.scatter(X1_np[Y_np == 0], X2_np[Y_np == 0], color="blue", label="Class 0", alpha=0.5)
plt.scatter(X1_np[Y_np == 1], X2_np[Y_np == 1], color="red", label="Class 1", alpha=0.5)
plt.xlabel("Feature 1 (X1)")
plt.ylabel("Feature 2 (X2)")
plt.legend()
plt.title("Generated 2D Logistic Regression Data")
Text(0.5, 1.0, 'Generated 2D Logistic Regression Data')
import torch.nn as nn
import torch.optim as optim
# Stack X1 and X2 into a single tensor
X_train = torch.cat((X1, X2), dim=1)
# Define logistic regression model
class LogisticRegression2D(nn.Module):
    def __init__(self):
        super(LogisticRegression2D, self).__init__()
        self.linear = nn.Linear(2, 1)  # Two inputs, one output

    def forward(self, x):
        return torch.sigmoid(self.linear(x))  # Sigmoid activation

# Initialize model
model = LogisticRegression2D()
loss_fn = nn.BCELoss()  # Binary cross-entropy loss
optimizer = optim.SGD(model.parameters(), lr=0.01)
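Because the model applies the sigmoid itself, nn.BCELoss on probabilities is appropriate here. A common, more numerically stable alternative is to return raw logits and use nn.BCEWithLogitsLoss, which fuses the sigmoid into the loss; a minimal sketch of that variant (not used below):

# Hypothetical logits-based variant: the linear layer outputs raw logits
logit_model = nn.Linear(2, 1)
logit_loss_fn = nn.BCEWithLogitsLoss()  # sigmoid + BCE in one stable op
# loss = logit_loss_fn(logit_model(X_train), Y.view(-1, 1))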
# Training loop
epochs = 1000
for epoch in range(epochs):
    optimizer.zero_grad()
    Y_pred = model(X_train)  # Forward pass
    loss = loss_fn(Y_pred, Y.view(-1, 1))  # Compute loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights

    if epoch % 200 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Extract learned parameters
w1_learned, w2_learned = model.linear.weight[0].detach().numpy()
b_learned = model.linear.bias[0].detach().numpy()
print(f"Learned Parameters: w1 = {w1_learned:.4f}, w2 = {w2_learned:.4f}, b = {b_learned:.4f}")
Epoch 0, Loss: 2.0864
Epoch 200, Loss: 0.4884
Epoch 400, Loss: 0.4536
Epoch 600, Loss: 0.4395
Epoch 800, Loss: 0.4318
Learned Parameters: w1 = 0.6339, w2 = -0.8716, b = -0.4170
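The raw learned weights need not equal the true ones after 1000 SGD epochs, but the decision boundary w1*x1 + w2*x2 + b = 0 can be compared directly via its slope and intercept, x2 = -(w1/w2)*x1 - b/w2. A small sketch using the values printed above:

# Compare boundary geometry rather than raw parameter values
slope_true = -(1.2 / -0.8)
intercept_true = -(-2.5 / -0.8)
slope_learned = -(w1_learned / w2_learned)
intercept_learned = -(b_learned / w2_learned)
print(f"true:    slope = {slope_true:.3f}, intercept = {intercept_true:.3f}")
print(f"learned: slope = {slope_learned:.3f}, intercept = {intercept_learned:.3f}")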
def plot_decision_boundary(model, X1_np, X2_np, Y_np, w1_true, w2_true, b_true):
    """Plots the true and learned decision boundaries."""
    # Generate mesh grid
    x1_vals = np.linspace(0, 5, 100)
    x2_vals = np.linspace(0, 5, 100)
    X1_grid, X2_grid = np.meshgrid(x1_vals, x2_vals)

    # Compute model's learned decision boundary
    with torch.no_grad():
        Z = model(torch.tensor(np.c_[X1_grid.ravel(), X2_grid.ravel()], dtype=torch.float32))
        Z = Z.view(X1_grid.shape).numpy()

    # Compute true decision boundary: w1*x1 + w2*x2 + b = 0
    x1_boundary = np.linspace(0, 5, 100)
    x2_boundary_true = -(w1_true / w2_true) * x1_boundary - (b_true / w2_true)

    # Plot data points
    plt.scatter(X1_np[Y_np == 0], X2_np[Y_np == 0], color="blue", label="Class 0", alpha=0.5)
    plt.scatter(X1_np[Y_np == 1], X2_np[Y_np == 1], color="red", label="Class 1", alpha=0.5)

    # Plot learned decision boundary (contour at p = 0.5);
    # note that plt.contour does not accept a `label` kwarg
    plt.contour(X1_grid, X2_grid, Z, levels=[0.5], colors="black", linestyles="dashed")

    # Plot true decision boundary
    plt.plot(x1_boundary, x2_boundary_true, color="green", linestyle="solid", label="True Boundary")

    plt.xlabel("Feature 1 (X1)")
    plt.ylabel("Feature 2 (X2)")
    plt.legend()
    plt.title("Logistic Regression Decision Boundary")
    plt.ylim([0, 5])

# Call the function with true and learned parameters
plot_decision_boundary(model, X1_np, X2_np, Y_np, w1_true=1.2, w2_true=-0.8, b_true=-2.5)
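Note that plt.contour silently ignores a label keyword (the original call here raised a UserWarning for exactly that), so the dashed learned boundary gets no legend entry of its own. A workaround sketch using a proxy artist:

from matplotlib.lines import Line2D
# Proxy artist: a dashed black line stands in for the contour in the legend
proxy = Line2D([0], [0], color="black", linestyle="dashed", label="Learned Boundary")
handles, _ = plt.gca().get_legend_handles_labels()
plt.legend(handles=handles + [proxy])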