# Active Learning


    from astra.torch.models import ResNetClassifier
    %pip install git+
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'

# Confusion matrix
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

import torchsummary
from tqdm import tqdm

import umap

from import load_cifar_10
from astra.torch.utils import train_fn
from astra.torch.models import ResNetClassifier

# Netron, ONNX for model visualization
    import netron
except ModuleNotFoundError:
    %pip install netron
    import netron

    import onnx
except ModuleNotFoundError:
    %pip install onnx
    import onnx

import copy
# Run on a GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


dataset = load_cifar_10()
# Printed by load_cifar_10():
# CIFAR-10 Dataset
# length of dataset: 60000
# shape of images: torch.Size([3, 32, 32])
# len of classes: 10
# classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# dtype of images: torch.float32
# dtype of labels: torch.int64
# Plot some images in a 5x5 grid.
plt.figure(figsize=(6, 6))
for i in range(25):
    plt.subplot(5, 5, i+1)
    # NOTE(review): the loop only creates the 25 empty subplot axes; the call
    # that actually draws image i (and, presumably, its class title) appears
    # to be missing from this cell — restore it.

# Data splitting

# Split sizes: n_train labelled points to start from, n_test held-out points
# for evaluation; everything in between becomes the unlabelled pool.
n_train = 1000
n_test = 20000

# NOTE(review): the right-hand side of this assignment is missing (a syntax
# error as written); presumably the image tensor held by `dataset` — confirm
# and restore.
X =
y = dataset.targets

# One random permutation sliced into three disjoint index sets.
# NOTE(review): no seed is set, so the split differs from run to run.
idx = torch.randperm(len(X))
train_idx = idx[:n_train]
pool_idx = idx[n_train:-n_test]
test_idx = idx[-n_test:]
# ResNet-18 classifier with a 10-way output, built with torchvision's DEFAULT
# resnet18 weights (presumably loaded by ResNetClassifier — confirm).
resnet = ResNetClassifier(models.resnet18, models.ResNet18_Weights.DEFAULT, n_classes=10).to(device)
torchsummary.summary(resnet, (3, 32, 32))
def get_accuracy(net, X, y):
    """Return the predicted labels of ``net`` on ``X`` and its accuracy against ``y``.

    Args:
        net: classifier whose forward pass returns one row of logits per
            input; the predicted class is ``argmax(dim=1)``.
        X: inputs, already on the same device as ``net``.
        y: integer class labels aligned with ``X`` (same device).

    Returns:
        ``(y_pred, acc)``: the predicted class index per sample and a 0-dim
        float tensor holding the fraction of correct predictions.

    Side effect: ``net`` is left in evaluation mode.
    """
    # Set the net to evaluation mode: dropout becomes the identity and
    # BatchNorm uses (and no longer updates) its running statistics, so the
    # predictions are deterministic. torch.no_grad() alone does neither.
    net.eval()
    with torch.no_grad():
        logits_pred = net(X)
        y_pred = logits_pred.argmax(dim=1)
        acc = (y_pred == y).float().mean()
    return y_pred, acc

def predict(net, classes, plot_confusion_matrix=False):
    """Print the accuracy of ``net`` on the train, pool and test splits.

    Reads the module-level ``X``, ``y``, ``train_idx``, ``pool_idx``,
    ``test_idx`` and ``device``.

    Args:
        net: classifier, evaluated through ``get_accuracy``.
        classes: class names in label-index order (used as axis labels).
        plot_confusion_matrix: if True, also draw one confusion matrix per split.
    """
    splits = (("train", train_idx), ("pool", pool_idx), ("test", test_idx))
    for name, idx in splits:
        X_dataset = X[idx].to(device)
        y_dataset = y[idx].to(device)
        y_pred, acc = get_accuracy(net, X_dataset, y_dataset)
        print(f'{name} set accuracy: {acc*100:.2f}%')
        if plot_confusion_matrix:
            # Pass the full label set so the matrix is always
            # len(classes) x len(classes); otherwise a class missing from both
            # labels and predictions shrinks the matrix and it no longer
            # matches ``display_labels``.
            cm = confusion_matrix(y_dataset.cpu(), y_pred.cpu(),
                                  labels=list(range(len(classes))))
            ConfusionMatrixDisplay(cm, display_labels=classes).plot(values_format='d',
                                                                    cmap='Blues')
            # Title each matrix so the three splits can be told apart.
            plt.title(f"{name} set")
            # Rotate the labels on x-axis to make them readable
            _ = plt.xticks(rotation=90)
# Baseline: evaluate `resnet` as constructed above, before any training here.
predict(net=resnet, classes=dataset.classes, plot_confusion_matrix=True)
def viz_embeddings(net, X, y, device):
    """Scatter-plot a 2-D UMAP projection of ``net.featurizer`` embeddings.

    Args:
        net: model exposing a ``featurizer`` sub-module.
        X: inputs to embed; moved to ``device`` for the forward pass.
        y: class labels aligned with ``X``; used to colour the points.
        device: torch device that ``net`` lives on.
    """
    reducer = umap.UMAP()
    # Evaluation mode: deterministic dropout, and BatchNorm running statistics
    # are not overwritten by this batch (torch.no_grad() does not stop that).
    net.eval()
    with torch.no_grad():
        emb = net.featurizer(X.to(device))
    emb = emb.cpu().numpy()
    emb = reducer.fit_transform(emb)
    plt.figure(figsize=(4, 4))
    plt.scatter(emb[:, 0], emb[:, 1], c=y.cpu().numpy(), cmap='tab10')
    # Add a colorbar legend to mark color to class mapping (one bin per class 0..9)
    plt.colorbar(boundaries=np.arange(11)-0.5)
    plt.title("UMAP embeddings")

# Embeddings of the initial labelled points under the untouched `resnet`.
viz_embeddings(net=resnet, X=X[train_idx], y=y[train_idx], device=device)

# Train the model on the train set

# Baseline: ResNet-18 with no pretrained weights passed (weights argument None),
# GELU activation and dropout 0.1, trained only on the n_train labelled points
# for 30 epochs (batch size 128, lr 3e-4).
model_only_train = ResNetClassifier(models.resnet18, None, n_classes=10, activation=nn.GELU(), dropout=0.1).to(device)
iter_losses, epoch_losses = train_fn(model_only_train, X[train_idx], y[train_idx], nn.CrossEntropyLoss(), lr=3e-4, 
                                     batch_size=128, epochs=30, verbose=False)
# NOTE(review): only the y-label is set; the call plotting the loss curve
# (presumably iter_losses or epoch_losses) appears to be missing.
plt.ylabel("Training loss")
# Accuracy on train / pool / test, with confusion matrices.
predict(model_only_train, dataset.classes, plot_confusion_matrix=True)
# UMAP of featurizer embeddings on the labelled training points...
viz_embeddings(model_only_train, X[train_idx], y[train_idx], device)

# ...and on the first 1000 test points.
viz_embeddings(model_only_train, X[test_idx[:1000]], y[test_idx[:1000]], device)

# Train on train + pool

# Reference model: train on the labelled set plus the entire pool, i.e. as if
# every pool point had been labelled.
train_plus_pool_idx = torch.cat([train_idx, pool_idx])

model_train_plus_pool = ResNetClassifier(models.resnet18, None, n_classes=10, activation=nn.GELU(), dropout=0.1).to(device)

iter_losses, epoch_losses = train_fn(model_train_plus_pool, X[train_plus_pool_idx], y[train_plus_pool_idx], loss_fn=nn.CrossEntropyLoss(),
                                        batch_size=1024, epochs=30)
# NOTE(review): only the y-label is set; the loss-curve plotting call appears
# to be missing.
plt.ylabel("Training loss")
viz_embeddings(model_train_plus_pool, X[train_idx], y[train_idx], device)

viz_embeddings(model_train_plus_pool, X[test_idx[:1000]], y[test_idx[:1000]], device)

predict(model_train_plus_pool, dataset.classes, plot_confusion_matrix=True)
# Printed output:
# train set accuracy: 97.80%
# pool set accuracy: 98.03%
# test set accuracy: 61.14%

# Test accuracy of both baselines (element [1] of get_accuracy's result).
accuracy_only_train = get_accuracy(model_only_train, X[test_idx].to(device), y[test_idx].to(device))[1]
accuracy_train_plus_pool = get_accuracy(model_train_plus_pool, X[test_idx].to(device), y[test_idx].to(device))[1]

# Start a new figure so the reference lines are not drawn onto a previously
# created one (e.g. the last confusion matrix).
plt.figure()
plt.axhline(accuracy_only_train.cpu(), color='r', label='Only train')
plt.axhline(accuracy_train_plus_pool.cpu(), color='g', label='Train + pool')
plt.ylabel("Test accuracy")
# Both lines are labelled; show the legend so they can be told apart.
plt.legend()

# Active learning loop

def setdiff1d(a, b):
    """Return the elements of 1-D tensor ``a`` that do not occur in ``b``.

    The order and any duplicates of ``a`` are preserved (unlike
    ``numpy.setdiff1d``, the result is neither sorted nor de-duplicated).
    """
    # torch.isin avoids materialising the len(a) x len(b) boolean matrix that
    # a broadcast ``a.unsqueeze(1).eq(b)`` comparison would need.
    return a[~torch.isin(a, b)]

# Toy inputs for checking setdiff1d: setdiff1d(a, b) should give tensor([2, 4]).
a = torch.arange(1, 6)
b = torch.arange(1, 6, 2)

def al_loop(model, query_strategy, num_al_iterations, num_epochs_finetune,
            train_idx, pool_idx, test_idx, query_size,
            X, y, device, random_seed=0, verbose=False):
    """Run pool-based active learning, fine-tuning ``model`` after each query.

    At every iteration ``query_strategy`` picks ``query_size`` indices from the
    current pool; they are moved from the pool to the training set and
    ``model`` is fine-tuned on the enlarged training set.

    Args:
        model: PyTorch model trained on train_idx. It is fine-tuned IN PLACE;
            pass a deep copy to keep the original.
        query_strategy: function that takes in model and pool_idx, train_idx
            and returns indices to query; called as ``query_strategy(model,
            pool_idx, train_idx, random_seed, query_size, X, y, device)``.
        num_al_iterations: number of active learning iterations.
        num_epochs_finetune: number of epochs to train on queried data +
            train_idx at every iteration.
        train_idx: indices of data used for training.
        pool_idx: indices of data used for querying.
        test_idx: indices of data used for testing.
        query_size: number of data points to query at each iteration.
        X: data.
        y: labels.
        device: torch device.
        random_seed: random seed, forwarded to ``query_strategy``.
        verbose: print progress after every iteration.

    Returns:
        ``(model, tr_idx, p_idx, test_accuracies)``: the fine-tuned model, the
        final train and pool indices, and a dict mapping iteration number
        (0 = before any query) to test accuracy.
    """
    tr_idx = train_idx.clone()
    p_idx = pool_idx.clone()
    print(f"Initial train size: {train_idx.shape}")
    print(f"Initial pool size: {pool_idx.shape}")

    def test_accuracy():
        # The test set is moved to the device only for the evaluation so it
        # does not stay resident while the model trains.
        return get_accuracy(model, X[test_idx].to(device), y[test_idx].to(device))[1].item()

    # Initial model test accuracy
    init_accuracy = test_accuracy()
    print(f"Test accuracy before AL: {init_accuracy:0.4f}")
    test_accuracies = {0: init_accuracy}
    for iteration in range(1, num_al_iterations + 1):
        # Query
        query_idx = query_strategy(model, p_idx, tr_idx, random_seed, query_size, X, y, device)
        # Move the queried points from the pool to the training set
        tr_idx = torch.cat([tr_idx, query_idx])
        p_idx = setdiff1d(p_idx, query_idx)
        # Evaluation / pool scoring may have left the model in eval (or
        # MC-dropout) mode; fine-tune in regular training mode.
        model.train()
        train_fn(model, X[tr_idx], y[tr_idx], loss_fn=nn.CrossEntropyLoss(),
                 lr=3e-4, batch_size=1024, epochs=num_epochs_finetune, verbose=False)
        test_accuracies[iteration] = test_accuracy()
        if verbose:
            print(f"Active learning iteration {iteration}/{num_al_iterations}")
            print(f"Train set size: {len(tr_idx)}, Pool set size: {len(p_idx)}")
            # Report the accuracy just measured (not the previous iteration's).
            print(f"Test accuracy: {test_accuracies[iteration]:0.4f}")
    return model, tr_idx, p_idx, test_accuracies
def random_sampling(model, pool_idx, train_idx, random_seed, query_size, X, y, device):
    """Query ``query_size`` indices uniformly at random from ``pool_idx``.

    ``model``, ``train_idx``, ``X``, ``y`` and ``device`` are unused; they keep
    the signature shared by all query strategies.

    The draw is reproducible: it uses a private generator seeded from
    ``random_seed`` and the current pool size. The same seed therefore queries
    the same points run after run, while successive AL iterations (whose pool
    has shrunk) still get fresh permutations. The global torch RNG is untouched.

    Returns:
        Tensor of at most ``query_size`` distinct entries of ``pool_idx``.
    """
    generator = torch.Generator()
    # 1_000_003 (a prime larger than any realistic pool size) keeps the seeds
    # of different (random_seed, pool size) pairs from colliding.
    generator.manual_seed(random_seed * 1_000_003 + len(pool_idx))
    order = torch.randperm(len(pool_idx), generator=generator)
    return pool_idx[order[:query_size]]
# Budget for the single-seed random-sampling run.
query_size, num_al_iterations = 20, 50

import copy  # redundant: copy is already imported at the top of the notebook

# al_loop fine-tunes the model it receives in place, so hand it a copy.
model_r = copy.deepcopy(model_only_train)

# The positional 20 is num_epochs_finetune (epochs per AL iteration).
model, t_idx, p_idx, test_acc_random = al_loop(
    model_r, random_sampling, num_al_iterations, 20,
    train_idx, pool_idx, test_idx, query_size, X, y,
    device=device, verbose=True,
)
### Now, running across multiple random seeds

query_size = 20
num_al_iterations = 50
ms = {}
t_idxs = {}
p_idxs = {}
test_acc_random = {}
for rs in range(5):
    print(f"Random seed: {rs}")
    # al_loop fine-tunes its model in place, so each seed starts from a fresh copy.
    model = copy.deepcopy(model_only_train)
    ms[rs], t_idxs[rs], p_idxs[rs], test_acc_random[rs] = al_loop(model, random_sampling, num_al_iterations, 20, train_idx, pool_idx, test_idx, query_size, X, y, device=device, verbose=False, random_seed=rs)
# Columns = seeds, rows = AL iteration (0 = before any query).
test_acc_random_df = pd.DataFrame(test_acc_random)
mean_acc = test_acc_random_df.mean(axis=1)
std_acc = test_acc_random_df.std(axis=1)
# New figure so the curves are not drawn onto whatever figure is current.
plt.figure()
plt.plot(mean_acc, label="Random sampling (mean)")
plt.fill_between(mean_acc.index, mean_acc-std_acc,
                 mean_acc+std_acc, alpha=0.2, label="Random sampling (std)")

# Accuracy of model trained on train_idx
plt.axhline(accuracy_only_train.cpu(), color='r', label='Only train')

# Accuracy of model trained on train_idx + pool_idx
plt.axhline(accuracy_train_plus_pool.cpu(), color='g', label='Train + pool')
plt.xlabel("Active learning iteration")
plt.ylabel("Test accuracy")
# Every artist above is labelled; show them.
plt.legend()
def entropy_sampling(model, pool_idx, train_idx, random_seed, query_size, X, y, device):
    """Query the ``query_size`` pool points with the highest predictive entropy.

    ``train_idx``, ``random_seed`` and ``y`` are unused; they keep the
    signature shared by all query strategies.

    Returns:
        CPU tensor of indices taken from ``pool_idx``, most uncertain first.
    """
    # Eval mode: deterministic dropout, and BatchNorm running statistics are
    # not updated by the pool batch (torch.no_grad() does not prevent that).
    model.eval()
    with torch.no_grad():
        logits = model(X[pool_idx].to(device))  # (len(pool_idx), n_classes)
        log_probs = F.log_softmax(logits, dim=1)
        # H = -sum p log p from log-softmax: a saturated softmax (p == 0)
        # contributes 0 * finite = 0 instead of 0 * log(0) = NaN, which would
        # corrupt the ranking.
        entropy = -(log_probs.exp() * log_probs).sum(dim=1)
        top = entropy.argsort(descending=True)[:query_size]
        query_idx = pool_idx[top.to(pool_idx.device)]
    return query_idx.cpu()
# Sanity check: the 5 pool indices entropy sampling would query for the
# train-only model.
entropy_sampling(model_only_train, pool_idx, train_idx, 0, 5, X, y, device)
model_e = copy.deepcopy(model_only_train)
# AL loop
# al_loop returns (model, train indices, pool indices, test accuracies).
m, t_idx, p_idx, test_acc_entropy = al_loop(model_e, entropy_sampling, num_al_iterations, 20, train_idx, pool_idx, test_idx, query_size,
                                            X, y, device=device, verbose=True, random_seed=0)
fig, ax = plt.subplots(1, 1)
ax.plot(mean_acc, label="Random sampling (mean)")
ax.fill_between(mean_acc.index, mean_acc-std_acc,
                mean_acc+std_acc, alpha=0.2, label="Random sampling (std)")
ax.axhline(accuracy_only_train.cpu(), color='r', label='Only train')
ax.axhline(accuracy_train_plus_pool.cpu(), color='g', label='Train + pool')

pd.Series(test_acc_entropy).plot(ax=ax, label="Entropy sampling", color='C2')

ax.set_xlabel("Active learning iteration")
ax.set_ylabel("Test accuracy")
# All series are labelled; show the legend so they can be told apart.
ax.legend()
def margin_sampling(model, pool_idx, train_idx, random_seed, query_size, X, y, device):
    """Query the ``query_size`` pool points with the SMALLEST top-2 margin.

    The margin is p(best class) - p(second-best class). A small margin means
    the model can barely separate its two best guesses, i.e. it is uncertain.
    ``train_idx``, ``random_seed`` and ``y`` are unused; they keep the
    signature shared by all query strategies.

    Returns:
        CPU tensor of indices taken from ``pool_idx``, smallest margin first.
    """
    # Eval mode: deterministic dropout, no BatchNorm running-stat updates.
    model.eval()
    with torch.no_grad():
        logits = model(X[pool_idx].to(device))
        probs = F.softmax(logits, dim=1)
        top2 = torch.topk(probs, 2, dim=1).values
        margin = top2[:, 0] - top2[:, 1]
        # Ascending order: most ambiguous points first. Sorting descending
        # would query the points the model is MOST confident about.
        top = margin.argsort()[:query_size]
        query_idx = pool_idx[top.to(pool_idx.device)]
    return query_idx.cpu()
# Sanity check: the 5 pool indices margin sampling would query first.
margin_sampling(model_only_train, pool_idx, train_idx, random_seed=0, query_size=5, X=X, y=y, device=device)
def diversity_sampling(model, pool_idx, train_idx, random_seed, query_size, X, y, device):
    """Query the pool points farthest from the labelled set in embedding space.

    Each pool point is scored by the Euclidean distance from its
    ``model.featurizer`` embedding to the nearest training embedding, and the
    ``query_size`` highest-scoring points are returned. Points are scored
    independently, so two near-identical far-away points can both be picked.
    ``random_seed`` and ``y`` are unused; they keep the shared signature.

    Returns:
        CPU tensor of indices taken from ``pool_idx``, farthest first.
    """
    # Eval mode: deterministic dropout, no BatchNorm running-stat updates.
    model.eval()
    with torch.no_grad():
        emb_pool = model.featurizer(X[pool_idx].to(device))
        emb_train = model.featurizer(X[train_idx].to(device))
        # Distance between each pool point and each train point
        dist = torch.cdist(emb_pool, emb_train)
        # Distance to the closest train point, per pool point
        min_dist = dist.min(dim=1).values
        top = min_dist.argsort(descending=True)[:query_size]
        query_idx = pool_idx[top.to(pool_idx.device)]
    return query_idx.cpu()
# Sanity check: the 5 pool indices diversity sampling would query first.
diversity_sampling(model_only_train, pool_idx, train_idx, random_seed=0, query_size=5, X=X, y=y, device=device)
# BALD sample dataset to illustrate the idea: toy model A gives the same
# (0.5, 0.5) prediction on all 10 MC passes -> shape (10, 1, 2).

pred_A = torch.full((10, 1, 2), 0.5)
def BALD_score(logits, verbose=None):
    """Compute the BALD score (mutual information) of every sample.

    BALD = H[E_mc p] - E_mc[H[p]]: the entropy of the prediction averaged over
    MC passes minus the average entropy of the individual passes. It is large
    when each pass is confident but the passes disagree with one another.

    Args:
        logits: (n_MC_passes, n_samples, n_classes) logits, one slice per pass.
        verbose: print the intermediate quantities. ``None`` (default) falls
            back to the module-level ``bald_verbose`` flag.

    Returns:
        Tensor of shape (n_samples,) with one BALD score per sample.
    """
    if verbose is None:
        verbose = bald_verbose
    # Work from log-softmax so a saturated softmax (p == 0) gives 0 * finite
    # rather than 0 * log(0) = NaN in the entropy terms.
    log_probs = F.log_softmax(logits, dim=2)
    probs = log_probs.exp()
    expected_probs = probs.mean(dim=0)  # Expectation over MC passes
    if verbose:
        print(f"expected_probs: {expected_probs}")
    # xlogy(p, p) = p * log(p), defined as 0 where p == 0.
    entropy_expected_probs = -torch.xlogy(expected_probs, expected_probs).sum(dim=1)
    if verbose:
        print(f"entropy_expected_probs: {entropy_expected_probs}")
    entropy_probs = -(probs * log_probs).sum(dim=2)
    if verbose:
        print(f"entropy_probs: {entropy_probs}")
    expected_entropy_probs = entropy_probs.mean(dim=0)
    if verbose:
        print(f"expected_entropy_probs: {expected_entropy_probs}")
    bald_score = entropy_expected_probs - expected_entropy_probs
    if verbose:
        print(f"bald_score: {bald_score}")
    return bald_score
# Make BALD_score print its intermediate quantities for the toy example.
bald_verbose = True
# Entropy of each toy model's MC-averaged prediction (mean over dim 0).
entropy_A = torch.sum(-pred_A.mean(dim=0) * torch.log(pred_A.mean(dim=0)), dim=1)
# NOTE(review): pred_B is never defined in the visible code (its definition
# appears to be missing), so this line raises NameError as written.
entropy_B = torch.sum(-pred_B.mean(dim=0) * torch.log(pred_B.mean(dim=0)), dim=1)

def BALD_sampling(model, pool_idx, train_idx, random_seed, query_size, X, y, device):
    """Query the ``query_size`` pool points with the highest BALD score.

    Uses MC dropout: the pool is scored ``n_MC_passes`` times with the dropout
    layers active and every other layer (e.g. BatchNorm) in eval mode, and the
    stacked logits are scored with ``BALD_score``. ``train_idx``,
    ``random_seed`` and ``y`` are unused; they keep the shared signature.

    Args:
        model: MC dropout model. Assumes dropout is implemented with
            ``nn.Dropout`` modules — confirm for the model in use; without any,
            all passes are identical and every BALD score is ~0.

    Returns:
        CPU tensor of indices taken from ``pool_idx``, highest BALD first.
    """
    # Evaluate the logits on the pool set for each MC pass
    n_MC_passes = 8
    X_pool = X[pool_idx].to(device)
    # Set mode of model for MC dropout: eval everywhere (so BatchNorm uses and
    # does not overwrite its running statistics), then re-enable only dropout.
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    logits = []
    with torch.no_grad():
        for _ in range(n_MC_passes):
            logits.append(model(X_pool))
    # Do not leave the model half-switched; return it in plain eval mode.
    model.eval()
    logits = torch.stack(logits)  # (n_MC_passes, len(pool_idx), n_classes)
    bald_score = BALD_score(logits)
    top = bald_score.argsort(descending=True)[:query_size]
    query_idx = pool_idx[top.to(pool_idx.device)]
    return query_idx.cpu()
model_e = copy.deepcopy(model_only_train)
# AL loop
# al_loop returns (model, train indices, pool indices, test accuracies).
m, t_idx, p_idx, test_acc_margin = al_loop(model_e, margin_sampling, num_al_iterations, 20, train_idx, pool_idx, test_idx, query_size, X, y, device=device, verbose=True, random_seed=0)
fig, ax = plt.subplots(1, 1)
ax.plot(mean_acc, label="Random sampling (mean)")
ax.fill_between(mean_acc.index, mean_acc-std_acc,
                mean_acc+std_acc, alpha=0.2, label="Random sampling (std)")
ax.axhline(accuracy_only_train.cpu(), color='r', label='Only train')
ax.axhline(accuracy_train_plus_pool.cpu(), color='g', label='Train + pool')

pd.Series(test_acc_entropy).plot(ax=ax, label="Entropy sampling", color='C2')
pd.Series(test_acc_margin).plot(ax=ax, label="Margin sampling", color='C4')

ax.set_xlabel("Active learning iteration")
ax.set_ylabel("Test accuracy")
# All series are labelled; show the legend so they can be told apart.
ax.legend()
model_e = copy.deepcopy(model_only_train)
# AL loop
# al_loop returns (model, train indices, pool indices, test accuracies).
m, t_idx, p_idx, test_acc_diversity = al_loop(model_e, diversity_sampling, num_al_iterations, 20, train_idx, pool_idx, test_idx, query_size, X, y, device=device, verbose=True, random_seed=0)
fig, ax = plt.subplots(1, 1)
ax.plot(mean_acc, label="Random sampling (mean)")
ax.fill_between(mean_acc.index, mean_acc-std_acc,
                mean_acc+std_acc, alpha=0.2, label="Random sampling (std)")
ax.axhline(accuracy_only_train.cpu(), color='r', label='Only train')
ax.axhline(accuracy_train_plus_pool.cpu(), color='g', label='Train + pool')

pd.Series(test_acc_entropy).plot(ax=ax, label="Entropy sampling", color='C2')
pd.Series(test_acc_margin).plot(ax=ax, label="Margin sampling", color='C4')
pd.Series(test_acc_diversity).plot(ax=ax, label="Diversity sampling", color='C5')

ax.set_xlabel("Active learning iteration")
ax.set_ylabel("Test accuracy")
# All series are labelled; show the legend so they can be told apart.
ax.legend()
# Silence BALD_score's intermediate prints during the AL loop.
bald_verbose = False
model_e = copy.deepcopy(model_only_train)
# AL loop
# al_loop returns (model, train indices, pool indices, test accuracies).
m, t_idx, p_idx, test_acc_bald = al_loop(model_e, BALD_sampling, num_al_iterations, 20, train_idx, pool_idx, test_idx, query_size, X, y, device=device, verbose=True, random_seed=0)
fig, ax = plt.subplots(1, 1)
ax.plot(mean_acc, label="Random sampling (mean)")
ax.fill_between(mean_acc.index, mean_acc-std_acc,
                mean_acc+std_acc, alpha=0.2, label="Random sampling (std)")
ax.axhline(accuracy_only_train.cpu(), color='r', label='Only train')
ax.axhline(accuracy_train_plus_pool.cpu(), color='g', label='Train + pool')

pd.Series(test_acc_entropy).plot(ax=ax, label="Entropy sampling", color='C2')
pd.Series(test_acc_margin).plot(ax=ax, label="Margin sampling", color='C4')
pd.Series(test_acc_diversity).plot(ax=ax, label="Diversity sampling", color='C5')
pd.Series(test_acc_bald).plot(ax=ax, label="BALD sampling", color='C6')

ax.set_xlabel("Active learning iteration")
ax.set_ylabel("Test accuracy")
# All series are labelled; show the legend so they can be told apart.
ax.legend()
