import os
import time
import numpy as np
import torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
Adversarial Examples for ML
Data, model and attack utilities
The colab environment already has all the necessary Python packages installed. Specifically, we are using numpy, torch and torchvision.
# Choosing backend
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
Loading data
We load the data using the built-in dataset loaders in PyTorch. torchvision offers many commonly used computer vision datasets, but we will just use MNIST (a dataset of black-and-white handwritten digits) for now.
def load_dataset(dataset, data_dir, training_time):
    if dataset == 'CIFAR-10':
        loader_train, loader_test, data_details = load_cifar_dataset(data_dir, training_time)
    elif 'MNIST' in dataset:
        loader_train, loader_test, data_details = load_mnist_dataset(data_dir, training_time)
    else:
        raise ValueError('No support for dataset %s' % dataset)
    return loader_train, loader_test, data_details
def load_mnist_dataset(data_dir, training_time):
    # MNIST data loaders
    trainset = datasets.MNIST(root=data_dir, train=True,
                              download=True, transform=transforms.ToTensor())
    testset = datasets.MNIST(root=data_dir, train=False,
                             download=True, transform=transforms.ToTensor())

    loader_train = torch.utils.data.DataLoader(trainset,
                                               batch_size=128,
                                               shuffle=True)

    loader_test = torch.utils.data.DataLoader(testset,
                                              batch_size=128,
                                              shuffle=False)
    data_details = {'n_channels': 1, 'h_in': 28, 'w_in': 28, 'scale': 255.0}
    return loader_train, loader_test, data_details
Having defined the loading functions, we now create the data loaders to be used throughout, as well as a dictionary with the details of the dataset, in case we need it.
loader_train, loader_test, data_details = load_dataset('MNIST', 'data', training_time=True)
Common path and definitions for train/test
Since we need the path to the directory where we will store our models (pre-trained or not), and we also need to instantiate a copy of the model defined below, we run the following commands to have everything set up for test/evaluation.
Defining the model
We use a fully connected network with two hidden layers for the experiments in this tutorial. The definition of a three-layer convolutional neural network with batch normalization is also provided. The former is sufficient for MNIST, but may not be large enough for more complex tasks.
model_name = 'fcn'
class cnn_3l_bn(nn.Module):
    def __init__(self, n_classes=10):
        super(cnn_3l_bn, self).__init__()
        # in-channels, no. filters, filter size, stride
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.bn1 = nn.BatchNorm2d(20)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn2 = nn.BatchNorm2d(50)
        # Number of neurons in preceding layer, number in current layer
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, n_classes)

    def forward(self, x):
        # Rectified linear unit activation
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
class fcn(nn.Module):
    def __init__(self, n_classes=10):
        super(fcn, self).__init__()
        self.fc1 = nn.Linear(784, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, n_classes)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
if 'fcn' in model_name:
    model_dir_name = 'models' + '/' + 'MNIST' + '/fcn/'
    if not os.path.exists(model_dir_name):
        os.makedirs(model_dir_name)
    # Basic setup
    net = fcn(10)
elif 'cnn' in model_name:
    model_dir_name = 'models' + '/' + 'MNIST' + '/cnn_3l_bn/'
    if not os.path.exists(model_dir_name):
        os.makedirs(model_dir_name)
    # Basic setup
    net = cnn_3l_bn(10)

net.to(device)
criterion = nn.CrossEntropyLoss(reduction='none')
Training the benign/standard model
This is sample code for training your own model. Since it takes time to run, for the purposes of the tutorial, we will assume we already have trained models.
######################################## Benign/standard training ########################################
def train_one_epoch(model, optimizer, loader_train, verbose=True):
    losses = []
    model.train()
    for t, (x, y) in enumerate(loader_train):
        x = x.to(device)
        y = y.to(device)
        x_var = Variable(x, requires_grad=True).to(device)
        y_var = Variable(y, requires_grad=False).to(device)
        scores = model(x_var)
        # loss = loss_fn(scores, y_var)
        loss_function = nn.CrossEntropyLoss(reduction='none')
        batch_loss = loss_function(scores, y_var)
        loss = torch.mean(batch_loss)
        losses.append(loss.data.cpu().numpy())
        optimizer.zero_grad()
        loss.backward()
        # print(model.conv1.weight.grad)
        optimizer.step()
        if verbose:
            print('loss = %.8f' % (loss.data))
    return np.mean(losses)
Actual training loop
We define the necessary parameters for training (batch size, learning rate etc.), instantiate the optimizer and then train for 10 epochs.
In each epoch, the model is trained using all of the training data, which is split into batches of size 128. Thus, one step of the optimizer uses 128 samples, and there are roughly 10*(60,000/128) steps in the entire process.
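As a quick sanity check on that count (assuming MNIST's 60,000 training images and the batch size of 128 set below):

$$10 \text{ epochs} \times \left\lceil \tfrac{60{,}000}{128} \right\rceil = 10 \times 469 = 4{,}690 \text{ optimizer steps.}$$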
# Training parameters
batch_size = 128
learning_rate = 0.1
weight_decay = 2e-4
save_checkpoint = True

# Torch optimizer
optimizer = torch.optim.SGD(net.parameters(),
                            lr=learning_rate,
                            momentum=0.9,
                            weight_decay=weight_decay)

# if args.lr_schedule == 'cosine':
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
#         T_max=args.train_epochs, eta_min=0, last_epoch=-1)
# elif args.lr_schedule == 'linear0':
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150, 200], gamma=0.1)
for epoch in range(0, 10):
    start_time = time.time()
    # lr = update_hyparam(epoch, args)
    lr = optimizer.param_groups[0]['lr']
    print('Current learning rate: {}'.format(lr))
    # if not args.is_adv:
    ben_loss = train_one_epoch(net, optimizer,
                               loader_train, verbose=False)
    print('time_taken for #{} epoch = {:.3f}'.format(epoch+1, time.time()-start_time))
    if save_checkpoint:
        ckpt_path = 'checkpoint_' + str(0)
        torch.save(net.state_dict(), model_dir_name + ckpt_path)
    print('Train loss - Ben: %s' % (ben_loss))
    scheduler.step()
Generating Adversarial Examples
We will look at how to generate adversarial examples using the Projected Gradient Descent (PGD) method for the model we have trained and visualize the adversarial examples thus generated.
# Attack utils
# Random initialization within the L2 ball
def rand_init_l2(img_variable, eps_max):
    random_vec = torch.FloatTensor(*img_variable.shape).normal_(0, 1).to(device)
    random_vec_norm = torch.max(
        random_vec.view(random_vec.size(0), -1).norm(2, 1), torch.tensor(1e-9).to(device))
    random_dir = random_vec/random_vec_norm.view(random_vec.size(0), 1, 1, 1)
    random_scale = torch.FloatTensor(img_variable.size(0)).uniform_(0, eps_max).to(device)
    random_noise = random_scale.view(random_vec.size(0), 1, 1, 1)*random_dir
    img_variable = Variable(img_variable.data + random_noise, requires_grad=True).to(device)

    return img_variable
# Random initialization within the L_inf ball
def rand_init_linf(img_variable, eps_max):
    random_noise = torch.FloatTensor(*img_variable.shape).uniform_(-eps_max, eps_max).to(device)
    img_variable = Variable(img_variable.data + random_noise, requires_grad=True).to(device)

    return img_variable
# Tracking the best adversarial examples during the generation process
def track_best(blosses, b_adv_x, curr_losses, curr_adv_x):
    if blosses is None:
        b_adv_x = curr_adv_x.clone().detach()
        blosses = curr_losses.clone().detach()
    else:
        replace = curr_losses < blosses
        b_adv_x[replace] = curr_adv_x[replace].clone().detach()
        blosses[replace] = curr_losses[replace]

    return blosses.to(device), b_adv_x.to(device)
# Loss calculation
def cal_loss(y_out, y_true, targeted):
    losses = torch.nn.CrossEntropyLoss(reduction='none')
    losses_cal = losses(y_out, y_true).to(device)
    loss_cal = torch.mean(losses_cal).to(device)
    if targeted:
        return loss_cal, losses_cal
    else:
        return -1*loss_cal, -1*losses_cal
# Generating targets for each adversarial example
def generate_target_label_tensor(true_label, n_classes):
    t = torch.floor(n_classes*torch.rand(true_label.shape)).type(torch.int64)
    m = t == true_label
    t[m] = (t[m] + torch.ceil((n_classes-1)*torch.rand(t[m].shape)).type(torch.int64)) % n_classes
    return t.to(device)
This provides the core loop of the attack algorithm, which goes as follows:
1. The perturbation is initialized to 0 (or randomly within the allowed ball, if rand_init is set).
2. The gradient of the loss with respect to the current state of the adversarial example is computed.
3. The gradient is appropriately normalized and added to the current state of the example.
4. The complete adversarial example is clipped to lie within the input bounds.
5. Steps 2, 3 and 4 are repeated for a fixed number of steps or until some condition is met.
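For an untargeted attack, one common way to write the L_inf PGD update that the code below implements is (with step size $\alpha$ = eps_step and budget $\epsilon$ = eps_max; for a targeted attack the sign of the loss term is flipped and $y$ is replaced by the target label):

$$x^{(t+1)} = \Pi_{\mathcal{B}_\infty(x,\epsilon)\,\cap\,[\text{clip\_min},\,\text{clip\_max}]}\Big(x^{(t)} + \alpha \cdot \mathrm{sign}\big(\nabla_{x}\,\mathcal{L}\big(f(x^{(t)}),\,y\big)\big)\Big)$$

where $\Pi$ denotes the projection carried out by the two torch.clamp calls in the code.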
# Attack code
def pgd_attack(model, image_tensor, img_variable, tar_label_variable,
               n_steps, eps_max, eps_step, clip_min, clip_max, targeted, rand_init):
    """
    image_tensor: tensor which holds the clean images.
    img_variable: corresponding PyTorch variable for image_tensor.
    tar_label_variable: assuming a targeted attack, this variable holds the target labels.
    n_steps: number of attack iterations.
    eps_max: maximum L_inf attack perturbation.
    eps_step: L_inf attack perturbation per step.
    """
    best_losses = None
    best_adv_x = None
    image_tensor = image_tensor.to(device)

    if rand_init:
        img_variable = rand_init_linf(img_variable, eps_max)

    output = model.forward(img_variable)
    for i in range(n_steps):
        if img_variable.grad is not None:
            img_variable.grad.zero_()
        output = model.forward(img_variable)
        loss_cal, losses_cal = cal_loss(output, tar_label_variable, targeted)
        best_losses, best_adv_x = track_best(best_losses, best_adv_x, losses_cal, img_variable)
        # Finding the gradient of the loss
        loss_cal.backward()
        x_grad = -1 * eps_step * torch.sign(img_variable.grad.data)
        # Adding gradient to current state of the example
        adv_temp = img_variable.data + x_grad
        total_grad = adv_temp - image_tensor
        total_grad = torch.clamp(total_grad, -eps_max, eps_max)
        x_adv = image_tensor + total_grad
        # Projecting adversarial example back onto the constraint set
        x_adv = torch.clamp(torch.clamp(
            x_adv - image_tensor, -1*eps_max, eps_max) + image_tensor, clip_min, clip_max)
        img_variable.data = x_adv

    best_losses, best_adv_x = track_best(best_losses, best_adv_x, losses_cal, img_variable)

    return best_adv_x
def pgd_l2_attack(model, image_tensor, img_variable, tar_label_variable,
                  n_steps, eps_max, eps_step, clip_min, clip_max, targeted,
                  rand_init, num_restarts):
    """
    image_tensor: tensor which holds the clean images.
    img_variable: corresponding PyTorch variable for image_tensor.
    tar_label_variable: assuming a targeted attack, this variable holds the target labels.
    n_steps: number of attack iterations.
    eps_max: maximum L2 attack perturbation.
    eps_step: L2 attack perturbation per step.
    """
    best_losses = None
    best_adv_x = None
    image_tensor_orig = image_tensor.clone().detach()
    tar_label_orig = tar_label_variable.clone().detach()

    for j in range(num_restarts):
        if rand_init:
            img_variable = rand_init_l2(img_variable, eps_max)

        output = model.forward(img_variable)
        for i in range(n_steps):
            if img_variable.grad is not None:
                img_variable.grad.zero_()
            output = model.forward(img_variable)
            loss_cal, losses_cal = cal_loss(output, tar_label_variable, targeted)
            best_losses, best_adv_x = track_best(best_losses, best_adv_x, losses_cal, img_variable)
            loss_cal.backward()
            raw_grad = img_variable.grad.data
            grad_norm = torch.max(
                raw_grad.view(raw_grad.size(0), -1).norm(2, 1), torch.tensor(1e-9))
            grad_dir = raw_grad/grad_norm.view(raw_grad.size(0), 1, 1, 1)
            adv_temp = img_variable.data + -1 * eps_step * grad_dir
            # Clipping total perturbation
            total_grad = adv_temp - image_tensor
            total_grad_norm = torch.max(
                total_grad.view(total_grad.size(0), -1).norm(2, 1), torch.tensor(1e-9))
            total_grad_dir = total_grad/total_grad_norm.view(total_grad.size(0), 1, 1, 1)
            total_grad_norm_rescale = torch.min(total_grad_norm, torch.tensor(eps_max))
            clipped_grad = total_grad_norm_rescale.view(total_grad.size(0), 1, 1, 1) * total_grad_dir
            x_adv = image_tensor + clipped_grad
            x_adv = torch.clamp(x_adv, clip_min, clip_max)
            img_variable.data = x_adv

        best_losses, best_adv_x = track_best(best_losses, best_adv_x, losses_cal, img_variable)

        diff_array = np.array(x_adv.cpu()) - np.array(image_tensor.data.cpu())
        diff_array = diff_array.reshape(len(diff_array), -1)

        img_variable.data = image_tensor_orig

    return best_adv_x
Now we can call the core adversarial example generation function over our data and model to determine how robust the model actually is!
def robust_test(model, loss_fn, loader, att_dir, n_batches=0, train_data=False,
                training_time=False):
    """
    n_batches (int): Number of batches for evaluation.
    """
    model.eval()
    num_correct, num_correct_adv, num_samples = 0, 0, 0
    steps = 1
    losses_adv = []
    losses_ben = []
    adv_images = []
    adv_labels = []
    clean_images = []
    correct_labels = []

    for t, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        x_var = Variable(x, requires_grad=True).to(device)
        y_var = Variable(y, requires_grad=False).to(device)
        if att_dir['targeted']:
            y_target = generate_target_label_tensor(
                y_var.cpu(), 10).to(device)
        else:
            y_target = y_var
        if 'PGD_linf' in att_dir['attack']:
            adv_x = pgd_attack(model, x, x_var, y_target, att_dir['attack_iter'],
                               att_dir['epsilon'], att_dir['eps_step'], att_dir['clip_min'],
                               att_dir['clip_max'], att_dir['targeted'], att_dir['rand_init'])
        elif 'PGD_l2' in att_dir['attack']:
            adv_x = pgd_l2_attack(model, x, x_var, y_target, att_dir['attack_iter'],
                                  att_dir['epsilon'], att_dir['eps_step'], att_dir['clip_min'],
                                  att_dir['clip_max'], att_dir['targeted'], att_dir['rand_init'],
                                  att_dir['num_restarts'])
        # Predictions
        # scores = model(x.cuda())
        scores = model(x)
        _, preds = scores.data.max(1)
        scores_adv = model(adv_x)
        _, preds_adv = scores_adv.data.max(1)
        # Losses
        batch_loss_adv = loss_fn(scores_adv, y)
        loss_adv = torch.mean(batch_loss_adv)
        losses_adv.append(loss_adv.data.cpu().numpy())
        batch_loss_ben = loss_fn(scores, y)
        loss_ben = torch.mean(batch_loss_ben)
        losses_ben.append(loss_ben.data.cpu().numpy())
        # Correct count
        num_correct += (preds == y).sum()
        num_correct_adv += (preds_adv == y).sum()
        num_samples += len(preds)
        # Adding images and labels to list
        adv_images.extend(adv_x)
        adv_labels.extend(preds_adv)
        clean_images.extend(x)
        correct_labels.extend(preds)

        if n_batches > 0 and steps == n_batches:
            break
        steps += 1

    acc = float(num_correct) / num_samples
    acc_adv = float(num_correct_adv) / num_samples
    print('Clean accuracy: {:.2f}% ({}/{})'.format(
        100.*acc,
        num_correct,
        num_samples,
    ))
    print('Adversarial accuracy: {:.2f}% ({}/{})'.format(
        100.*acc_adv,
        num_correct_adv,
        num_samples,
    ))

    return 100.*acc, 100.*acc_adv, np.mean(losses_ben), np.mean(losses_adv), adv_images, adv_labels, clean_images, correct_labels
The most important parameters below are epsilon (which controls the magnitude of the perturbation), gamma (which determines how far outside the constraint set the initial search is allowed) and attack_iter (which is just the number of attack iterations).
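With the values set below, the per-step perturbation used by the attack works out to:

$$\text{eps\_step} = \frac{\epsilon \cdot \gamma}{\text{attack\_iter}} = \frac{0.2 \times 2.5}{10} = 0.05$$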
# Attack setup
gamma = 2.5
epsilon = 0.2
attack_iter = 10
delta = epsilon*gamma/attack_iter
attack_params = {'attack': 'PGD_linf', 'epsilon': epsilon,
                 'attack_iter': 10, 'eps_step': delta,
                 'targeted': True, 'clip_min': 0.0,
                 'clip_max': 1.0, 'rand_init': True,
                 'num_restarts': 1}
Now, we load the model (remember to first upload it into the models folder!) and then generate adversarial examples.
ckpt_path = 'checkpoint_' + str(0)

net.to(device)
net.eval()
net.load_state_dict(torch.load(model_dir_name + ckpt_path, map_location=device))
n_batches_eval = 10
print('Test set validation')
# Running validation
acc_test, acc_adv_test, test_loss, test_loss_adv, adv_images, adv_labels, clean_images, correct_labels = robust_test(net,
    criterion, loader_test, attack_params, n_batches=n_batches_eval,
    train_data=False, training_time=True)
# print('Training set validation')
# acc_train, acc_adv_train, train_loss, train_loss_adv, _ = robust_test(net,
#     criterion, loader_train_all, args, attack_params, n_batches=n_batches_eval,
#     train_data=True, training_time=True)
Visualizing the adversarial examples
from matplotlib import pyplot as plt
%matplotlib inline
fig = plt.figure(figsize=(9, 13))
columns = 4
rows = 5

# ax enables access to manipulate each of the subplots
ax = []

for i in range(columns*rows):
    image_count = int(i/2)
    if i % 2 == 1:
        img = adv_images[image_count].reshape(28, 28).cpu()
        # create subplot and append to ax
        ax.append(fig.add_subplot(rows, columns, i+1))
        ax[-1].set_title("output:" + str(adv_labels[image_count].cpu().numpy()))  # set title
        ax[-1].set_xticks([])
        ax[-1].set_yticks([])
    else:
        img = clean_images[image_count].reshape(28, 28).cpu()
        # create subplot and append to ax
        ax.append(fig.add_subplot(rows, columns, i+1))
        ax[-1].set_title("output:" + str(correct_labels[image_count].cpu().numpy()))  # set title
        ax[-1].set_xticks([])
        ax[-1].set_yticks([])
    plt.imshow(img, interpolation='nearest', cmap='gray')

# finally, render the plot
plt.show()
Training robust models
This training loop is very similar to the benign one, except that we now call the adversarial example generation function to generate adversarial examples during the training process.
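Conceptually, this loop approximates the standard (untargeted) saddle-point formulation of adversarial training, with PGD used for the inner maximization:

$$\min_{\theta}\; \mathbb{E}_{(x,y)\sim\mathcal{D}}\Big[\max_{\|\delta\|_\infty \le \epsilon} \mathcal{L}\big(f_{\theta}(x+\delta),\, y\big)\Big]$$

(The code below uses a targeted PGD attack when attack_params['targeted'] is set, but the idea is the same: the optimizer is updated on the perturbed inputs rather than the clean ones.)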
######################################## Adversarial training ########################################
def robust_train_one_epoch(model, optimizer, loader_train, att_dir,
                           epoch):
    # print('Current eps: {}, delta: {}'.format(eps, delta))
    losses_adv = []
    losses_ben = []
    model.train()
    for t, (x, y) in enumerate(loader_train):
        x = x.to(device)
        y = y.to(device)
        x_var = Variable(x, requires_grad=True)
        y_var = Variable(y, requires_grad=False)
        if att_dir['targeted']:
            y_target = generate_target_label_tensor(
                y_var.cpu(), 10).to(device)
        else:
            y_target = y_var
        if 'PGD_linf' in att_dir['attack']:
            adv_x = pgd_attack(model, x, x_var, y_target, att_dir['attack_iter'],
                               att_dir['epsilon'], att_dir['eps_step'], att_dir['clip_min'],
                               att_dir['clip_max'], att_dir['targeted'], att_dir['rand_init'])
        elif 'PGD_l2' in att_dir['attack']:
            adv_x = pgd_l2_attack(model, x, x_var, y_target, att_dir['attack_iter'],
                                  att_dir['epsilon'], att_dir['eps_step'], att_dir['clip_min'],
                                  att_dir['clip_max'], att_dir['targeted'], att_dir['rand_init'],
                                  att_dir['num_restarts'])
        scores = model(adv_x)
        loss_function = nn.CrossEntropyLoss(reduction='none')
        batch_loss_adv = loss_function(scores, y_var)
        batch_loss_ben = loss_function(model(x), y_var)
        loss = torch.mean(batch_loss_adv)
        loss_ben = torch.mean(batch_loss_ben)
        losses_ben.append(loss_ben.data.cpu().numpy())
        losses_adv.append(loss.data.cpu().numpy())
        # GD step
        optimizer.zero_grad()
        loss.backward()
        # print(model.conv1.weight.grad)
        optimizer.step()
    return np.mean(losses_adv), np.mean(losses_ben)
for epoch in range(0, 10):
    start_time = time.time()
    # lr = update_hyparam(epoch, args)
    lr = optimizer.param_groups[0]['lr']
    print('Current learning rate: {}'.format(lr))
    curr_loss, ben_loss = robust_train_one_epoch(net,
                                                 optimizer, loader_train, attack_params,
                                                 epoch)
    print('time_taken for #{} epoch = {:.3f}'.format(epoch+1, time.time()-start_time))
    if save_checkpoint:
        ckpt_path = 'checkpoint_adv' + str(0)
        torch.save(net.state_dict(), model_dir_name + ckpt_path)
    print('Train loss - Ben: %s, Adv: %s' % (ben_loss, curr_loss))
    scheduler.step()
Evaluating the robust model
Evaluating the robust model, we find its accuracy on adversarial examples has increased significantly!
ckpt_path = 'checkpoint_adv' + str(0)

net.eval()
net.load_state_dict(torch.load(model_dir_name + ckpt_path, map_location=device))
n_batches_eval = 10
print('Test set validation')
# Running validation
acc_test_r, acc_adv_test_r, test_loss_r, test_loss_adv_r, adv_images_r, adv_labels_r, clean_images_r, correct_labels_r = robust_test(net,
    criterion, loader_test, attack_params, n_batches=n_batches_eval,
    train_data=False, training_time=True)
# print('Training set validation')
# acc_train, acc_adv_train, train_loss, train_loss_adv, _ = robust_test(net,
#     criterion, loader_train_all, args, attack_params, n_batches=n_batches_eval,
#     train_data=True, training_time=True)
fig = plt.figure(figsize=(9, 13))
columns = 4
rows = 5

# ax enables access to manipulate each of the subplots
ax = []

for i in range(columns*rows):
    image_count = int(i/2)
    if i % 2 == 1:
        img = adv_images_r[image_count].reshape(28, 28).cpu()
        # create subplot and append to ax
        ax.append(fig.add_subplot(rows, columns, i+1))
        ax[-1].set_title("output:" + str(adv_labels_r[image_count].cpu().numpy()))  # set title
        ax[-1].set_xticks([])
        ax[-1].set_yticks([])
    else:
        img = clean_images_r[image_count].reshape(28, 28).cpu()
        # create subplot and append to ax
        ax.append(fig.add_subplot(rows, columns, i+1))
        ax[-1].set_title("output:" + str(correct_labels_r[image_count].cpu().numpy()))  # set title
        ax[-1].set_xticks([])
        ax[-1].set_yticks([])
    plt.imshow(img, interpolation='nearest', cmap='gray')

# finally, render the plot
plt.show()
Discussion questions
- Doesn’t robust training solve the problem of adversarial examples? Why is there still so much research on the topic?
- How would a real-world attacker try to carry out this attack without access to the classifier being used?
- What does the existence of adversarial examples tell us about modern ML models?