import sklearn
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
This is a work in progress. I will be adding more content to this post in the coming days.
Reference: https://scikit-learn.org/stable/auto_examples/calibration/plot_calibration_multiclass.html#sphx-glr-auto-examples-calibration-plot-calibration-multiclass-py
import numpy as np
from sklearn.datasets import make_blobs
# Build a 3-class synthetic blob dataset and split it into
# train (first 600), validation (next 400), and test (last 1000) samples.
np.random.seed(0)
X, y = make_blobs(
    n_samples=2000, n_features=2, centers=3, random_state=42, cluster_std=5.0
)
X_train, y_train = X[:600], y[:600]
X_valid, y_valid = X[600:1000], y[600:1000]
# train+valid combined (first 1000) is kept for refitting after calibration.
X_train_valid, y_train_valid = X[:1000], y[:1000]
X_test, y_test = X[1000:], y[1000:]
# Scatter plot showing the three classes in different colors.
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.7)
# Fit an (uncalibrated) logistic-regression baseline on the training split.
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression()
lr.fit(X_train, y_train)
LogisticRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LogisticRegression()
# Per-class predicted probabilities on the validation split; columns are
# labeled with the class each probability column corresponds to.
prob_df = pd.DataFrame(lr.predict_proba(X_valid))
prob_df.columns = lr.classes_
prob_df.head()
0 | 1 | 2 | |
---|---|---|---|
0 | 0.014323 | 0.959135 | 0.026542 |
1 | 0.000326 | 0.004617 | 0.995057 |
2 | 0.667887 | 0.322486 | 0.009627 |
3 | 0.953779 | 0.043703 | 0.002518 |
4 | 0.000029 | 0.000130 | 0.999841 |
400), y_valid]).quantile(0.1) pd.Series(prob_df.values[np.arange(
0.3934260593598625
# Inspect the validation labels that index the correct-class probabilities.
# (The original comment described the next step, not this display of y_valid.)
y_valid
array([1, 2, 0, 0, 2, 2, 2, 1, 1, 2, 1, 1, 0, 1, 2, 0, 0, 1, 2, 1, 1, 2,
1, 0, 0, 2, 0, 0, 1, 2, 0, 1, 2, 0, 0, 2, 1, 2, 0, 1, 1, 0, 0, 1,
0, 0, 2, 2, 1, 1, 0, 0, 0, 1, 2, 2, 2, 1, 0, 1, 1, 1, 2, 0, 1, 1,
0, 1, 1, 2, 2, 1, 0, 1, 1, 0, 2, 1, 2, 2, 2, 1, 2, 1, 1, 2, 2, 1,
0, 1, 0, 1, 2, 2, 0, 0, 0, 1, 0, 1, 2, 2, 0, 2, 0, 2, 1, 0, 0, 1,
2, 2, 2, 1, 0, 2, 2, 0, 0, 2, 0, 1, 2, 0, 1, 1, 2, 2, 1, 1, 2, 2,
0, 0, 0, 0, 0, 0, 2, 0, 1, 1, 1, 2, 0, 2, 0, 1, 1, 0, 2, 0, 1, 0,
1, 0, 2, 2, 0, 0, 2, 1, 0, 2, 0, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 2,
1, 0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 1, 2, 0, 2, 0, 1, 2, 1, 1, 0, 0,
2, 0, 1, 1, 1, 1, 0, 2, 2, 1, 1, 1, 0, 2, 1, 2, 2, 2, 1, 0, 0, 2,
0, 0, 2, 2, 0, 2, 2, 2, 0, 1, 2, 0, 2, 0, 1, 0, 2, 2, 2, 1, 0, 1,
1, 2, 2, 0, 2, 2, 2, 2, 0, 2, 1, 0, 1, 0, 1, 1, 1, 0, 2, 0, 2, 1,
0, 1, 0, 1, 2, 0, 1, 2, 2, 2, 0, 1, 0, 1, 0, 1, 1, 2, 1, 1, 2, 0,
1, 0, 1, 2, 0, 1, 0, 1, 1, 2, 0, 1, 1, 0, 2, 2, 1, 2, 0, 1, 1, 2,
1, 2, 2, 0, 2, 2, 2, 0, 1, 2, 1, 0, 2, 1, 2, 0, 2, 1, 0, 1, 1, 2,
0, 1, 2, 0, 2, 1, 2, 0, 0, 0, 2, 1, 0, 1, 0, 2, 1, 0, 1, 2, 0, 1,
0, 1, 0, 2, 1, 1, 1, 2, 2, 0, 2, 2, 2, 1, 2, 1, 2, 2, 0, 0, 2, 2,
0, 1, 1, 0, 1, 2, 0, 1, 1, 2, 1, 0, 1, 0, 0, 2, 2, 0, 0, 1, 0, 0,
2, 2, 2, 2])
# Get the predicted probability for the correct class for each sample
1 | 2 | 0 | 0 | 2 | 2 | 2 | 1 | 1 | 2 | ... | 2 | 0 | 0 | 1 | 0 | 0 | 2 | 2 | 2 | 2 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.959135 | 0.026542 | 0.014323 | 0.014323 | 0.026542 | 0.026542 | 0.026542 | 0.959135 | 0.959135 | 0.026542 | ... | 0.026542 | 0.014323 | 0.014323 | 0.959135 | 0.014323 | 0.014323 | 0.026542 | 0.026542 | 0.026542 | 0.026542 |
1 | 0.004617 | 0.995057 | 0.000326 | 0.000326 | 0.995057 | 0.995057 | 0.995057 | 0.004617 | 0.004617 | 0.995057 | ... | 0.995057 | 0.000326 | 0.000326 | 0.004617 | 0.000326 | 0.000326 | 0.995057 | 0.995057 | 0.995057 | 0.995057 |
2 | 0.322486 | 0.009627 | 0.667887 | 0.667887 | 0.009627 | 0.009627 | 0.009627 | 0.322486 | 0.322486 | 0.009627 | ... | 0.009627 | 0.667887 | 0.667887 | 0.322486 | 0.667887 | 0.667887 | 0.009627 | 0.009627 | 0.009627 | 0.009627 |
3 | 0.043703 | 0.002518 | 0.953779 | 0.953779 | 0.002518 | 0.002518 | 0.002518 | 0.043703 | 0.043703 | 0.002518 | ... | 0.002518 | 0.953779 | 0.953779 | 0.043703 | 0.953779 | 0.953779 | 0.002518 | 0.002518 | 0.002518 | 0.002518 |
4 | 0.000130 | 0.999841 | 0.000029 | 0.000029 | 0.999841 | 0.999841 | 0.999841 | 0.000130 | 0.000130 | 0.999841 | ... | 0.999841 | 0.000029 | 0.000029 | 0.000130 | 0.000029 | 0.000029 | 0.999841 | 0.999841 | 0.999841 | 0.999841 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
395 | 0.253215 | 0.038669 | 0.708116 | 0.708116 | 0.038669 | 0.038669 | 0.038669 | 0.253215 | 0.253215 | 0.038669 | ... | 0.038669 | 0.708116 | 0.708116 | 0.253215 | 0.708116 | 0.708116 | 0.038669 | 0.038669 | 0.038669 | 0.038669 |
396 | 0.000339 | 0.999576 | 0.000086 | 0.000086 | 0.999576 | 0.999576 | 0.999576 | 0.000339 | 0.000339 | 0.999576 | ... | 0.999576 | 0.000086 | 0.000086 | 0.000339 | 0.000086 | 0.000086 | 0.999576 | 0.999576 | 0.999576 | 0.999576 |
397 | 0.019843 | 0.980018 | 0.000139 | 0.000139 | 0.980018 | 0.980018 | 0.980018 | 0.019843 | 0.019843 | 0.980018 | ... | 0.980018 | 0.000139 | 0.000139 | 0.019843 | 0.000139 | 0.000139 | 0.980018 | 0.980018 | 0.980018 | 0.980018 |
398 | 0.000094 | 0.999780 | 0.000126 | 0.000126 | 0.999780 | 0.999780 | 0.999780 | 0.000094 | 0.000094 | 0.999780 | ... | 0.999780 | 0.000126 | 0.000126 | 0.000094 | 0.000126 | 0.000126 | 0.999780 | 0.999780 | 0.999780 | 0.999780 |
399 | 0.000133 | 0.999776 | 0.000092 | 0.000092 | 0.999776 | 0.999776 | 0.999776 | 0.000133 | 0.000133 | 0.999776 | ... | 0.999776 | 0.000092 | 0.000092 | 0.000133 | 0.000092 | 0.000092 | 0.999776 | 0.999776 | 0.999776 | 0.999776 |
400 rows × 400 columns