Calibration

Probability Calibration
ml
Author

Nipun Batra

Published

October 27, 2022

import numpy as np
import sklearn
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

Reference: "Introduction to reliability diagrams for probability calibration" — https://towardsdatascience.com/introduction-to-reliability-diagrams-for-probability-calibration-ed785b3f5d44

# Toy example: ten values in [0, 1], used below as predicted probabilities.
p = np.array([0.9, 0.2, 0.7, 0.4, 0.8, 0.1, 0.2, 0.8, 0.5, 0.9])
# Start with every label set to 1, then set samples 1, 6, 7 and 8 to 0.
true_labels = np.ones_like(p)
true_labels[[1, 6, 7, 8]] = 0
true_labels
array([1., 0., 1., 1., 1., 1., 0., 0., 0., 1.])
num_splits = 3
# num_splits equal-width bins on [0, 1] need num_splits + 1 edges.
splits_arr = np.linspace(0, 1, num_splits + 1)
# Pair every edge with its successor to get the (lower, upper) bounds of each bin.
splits = list(zip(splits_arr[:-1], splits_arr[1:]))
splits
[(0.0, 0.3333333333333333),
 (0.3333333333333333, 0.6666666666666666),
 (0.6666666666666666, 1.0)]
pd.cut(pd.Series(p), bins=splits_arr)
0      (0.667, 1.0]
1      (0.0, 0.333]
2      (0.667, 1.0]
3    (0.333, 0.667]
4      (0.667, 1.0]
5      (0.0, 0.333]
6      (0.0, 0.333]
7      (0.667, 1.0]
8    (0.333, 0.667]
9      (0.667, 1.0]
dtype: category
Categories (3, interval[float64, right]): [(0.0, 0.333] < (0.333, 0.667] < (0.667, 1.0]]
# np.digitize gives, for each probability, the 1-based index i of its bin,
# where splits_arr[i-1] <= x < splits_arr[i].  This rebinds `splits`,
# replacing the list of (lower, upper) pairs built above.
splits = np.digitize(p, splits_arr)
splits
array([3, 1, 3, 2, 3, 1, 1, 3, 2, 3])
# Group the probabilities and the true labels by bin index.
p_group = {}
labels_pos = {}
for group in np.unique(splits):
    # The boolean mask `splits==group` selects the samples assigned to this bin.
    p_group[group] = p[splits==group]
    labels_pos[group] = true_labels[splits==group]
p_group
{1: array([0.2, 0.1, 0.2]),
 2: array([0.4, 0.5]),
 3: array([0.9, 0.7, 0.8, 0.8, 0.9])}
labels_pos
{1: array([0., 1., 0.]), 2: array([1., 0.]), 3: array([1., 1., 1., 0., 1.])}
p_group_mean = {k:np.mean(v) for k, v in p_group.items()}
p_group_mean
{1: 0.16666666666666666, 2: 0.45, 3: 0.8200000000000001}
fracs = {k:np.sum(v)*1.0/len(v) for k, v in labels_pos.items()}
fracs
{1: 0.3333333333333333, 2: 0.5, 3: 0.8}
# Reliability diagram: mean probability per bin (x) against the fraction of
# positive labels observed in that bin (y).
plt.plot(p_group_mean.values(), fracs.values(), marker='*')
plt.xlim((-0.05, 1.05))
plt.ylim((-0.05, 1.05))
plt.gca().set_aspect("equal")
plt.xlabel("p")
plt.ylabel("Relative frequency")

# Dashed diagonal y = x, labelled 'Ideal', for comparison.
plt.plot([0, 1], [0, 1], color='k', ls='--', label='Ideal')
plt.legend()
<matplotlib.legend.Legend at 0x14ae4f1c0>

Let us wrap the steps above into a function.

def calib_curve(true, pred, n_bins = 10):
    """Compute the points of a reliability (calibration) curve.

    The interval [0, 1] is divided into ``n_bins`` equal-width bins and every
    prediction is assigned to one of them. Bins that receive no predictions
    are left out of the result.

    Parameters
    ----------
    true : array-like
        Binary labels (0/1), one per sample.
    pred : array-like
        Predicted probabilities, one per sample, aligned with ``true``.
    n_bins : int, default 10
        Number of equal-width bins on [0, 1].

    Returns
    -------
    mean_pred : ndarray
        Mean predicted probability in each non-empty bin.
    frac_pos : ndarray
        Fraction of positive labels in each non-empty bin.
    counts : ndarray
        Number of samples in each non-empty bin.

    Notes
    -----
    The first two return values are in the order (mean predicted, fraction of
    positives), which is the reverse of ``sklearn.calibration.calibration_curve``.
    """
    # Accept lists and other array-likes; boolean-mask indexing needs ndarrays.
    true = np.asarray(true)
    pred = np.asarray(pred)

    edges = np.linspace(0, 1, n_bins + 1)
    # np.digitize uses half-open bins [lo, hi), so a prediction of exactly 1.0
    # would get index n_bins + 1 and form a spurious extra bin. Clipping folds
    # it into the last bin (and anything below 0 into the first).
    bin_ids = np.clip(np.digitize(pred, edges), 1, n_bins)

    occupied, counts = np.unique(bin_ids, return_counts=True)
    mean_pred = np.array([np.mean(pred[bin_ids == b]) for b in occupied])
    frac_pos = np.array([np.mean(true[bin_ids == b]) for b in occupied])
    return mean_pred, frac_pos, counts
from sklearn.calibration import calibration_curve, CalibrationDisplay

# scikit-learn's implementation, for comparison: the first returned array is
# the per-bin fraction of positives, the second the per-bin mean prediction.
prob_true, prob_pred = calibration_curve(true_labels, p, n_bins=3)
prob_true
array([0.33333333, 0.5       , 0.8       ])
prob_pred
array([0.16666667, 0.45      , 0.82      ])
# calib_curve returns the per-bin mean prediction first, then the per-bin
# fraction of positives, then the per-bin sample counts.
p_ours, p_hat_ours, count = calib_curve(true_labels, p, 3)
p_ours
array([0.16666667, 0.45      , 0.82      ])

Expected Calibration Error

# ECE weights each bin's |mean prediction - fraction of positives| gap by the
# bin's share of the samples, so divide by the total number of samples rather
# than by the number of bins.
(np.abs(p_ours-p_hat_ours)*count).sum()/count.sum()
0.07
from sklearn.datasets import make_classification
X, y = make_classification(n_features=2, n_informative=2, n_redundant=0, random_state=0)
plt.scatter(X[:, 0], X[:, 1], c = y)
<matplotlib.collections.PathCollection at 0x14bd53160>

from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
lr.fit(X, y)
LogisticRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
display = CalibrationDisplay.from_estimator(
        lr,
        X,
        y,
        n_bins=11,
)

# Column 1 of predict_proba holds the probability of the second class
# (presumably label 1 here — confirm against lr.classes_).
pred_p = lr.predict_proba(X)[:, 1]
# Same reliability curve as above, computed with our own function on 11 bins.
probs, fractions, counts  = calib_curve(y, pred_p, 11)
plt.plot(probs, fractions, marker='^')
plt.xlabel("p")
plt.ylabel("Relative frequency")

# Dashed diagonal y = x, labelled 'Ideal', for comparison.
plt.plot([0, 1], [0, 1], color='k', ls='--', label='Ideal')
plt.legend()
<matplotlib.legend.Legend at 0x14be53ca0>

plt.hist(pred_p);

# ECE weights each bin's gap by the bin's share of the samples, so divide by
# the total number of samples rather than by the number of bins.
# The value shown is re-derived from the earlier bin-averaged result
# (0.9530316314463302 * 11 bins / 100 samples); trailing digits may differ on re-run.
(np.abs(probs-fractions)*counts).sum()/counts.sum()
0.10483347945909632
counts
array([18, 14, 13,  4,  2,  4,  2,  2,  6,  7, 28])