Adversarial-Machine-Learnin.../Filter_Analysis/fgsm.py


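"""Evaluate FGSM attacks of increasing strength against a pretrained MNIST CNN
and measure how well each defense filter in defense_filters recovers
classification accuracy. Per-epsilon results are written to
results/mnist_fgsm.json.

Expects the pretrained weights in mnist_cnn_unfiltered.pt (the Net architecture
comes from mnist.py).
"""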
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from mnist import Net
import json
import sys
import defense_filters
# Number of filter strengths evaluated for each defense filter
TESTED_STRENGTH_COUNT = 5
MAX_EPSILON = 0.3
EPSILON_STEP = 0.025
# Attack strengths swept from 0.0 to MAX_EPSILON inclusive
epsilons = np.arange(0.0, MAX_EPSILON + EPSILON_STEP, EPSILON_STEP)
pretrained_model = "mnist_cnn_unfiltered.pt"
use_cuda = False
torch.manual_seed(69)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data/', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=1, shuffle=True)
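# Note: the per-sample accounting in test() below (e.g. target.item()) assumes
# batch_size=1.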
print("CUDA Available: ", torch.cuda.is_available())
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
# Initialize the network and load the pretrained weights
model = Net().to(device)
model.load_state_dict(torch.load(pretrained_model, map_location=device))
# Set the model to evaluation mode (disables dropout / batch-norm updates)
model.eval()
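# FGSM (Goodfellow et al., 2014) takes a single step in the direction that
# maximizes the loss: x_adv = x + epsilon * sign(grad_x L(theta, x, y)).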
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Add clipping to maintain the [0, 1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image
def denorm(batch, mean=[0.1307], std=[0.3081]):
    """
    Convert a batch of tensors to their original scale.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        mean (torch.Tensor or list): Mean used for normalization.
        std (torch.Tensor or list): Standard deviation used for normalization.

    Returns:
        torch.Tensor: Batch of tensors without normalization applied to them.
    """
    if isinstance(mean, list):
        mean = torch.tensor(mean).to(device)
    if isinstance(std, list):
        std = torch.tensor(std).to(device)
    # Invert transforms.Normalize: x = x_normalized * std + mean
    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
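# For each test image: classify the clean input, backpropagate to get the input
# gradient, denormalize, apply FGSM, renormalize, then classify the attacked
# image both unfiltered and after each defense filter at every tested strength.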
def test(model, device, test_loader, epsilon, filter):
    # Original dataset correct classifications
    orig_correct = 0
    # Attacked dataset correct classifications
    unfiltered_correct = 0
    # Attacked, filtered dataset correct classifications, one count per strength
    filtered_correct_counts = [0] * TESTED_STRENGTH_COUNT
    test_step = 0
    for data, target in test_loader:
        # Redraw the progress bar for this filter/epsilon combination
        sys.stdout.write("\033[K")
        progress = test_step / len(test_loader)
        print(filter, f"Epsilon: {epsilon}",
              "[" + "="*int(1 + 20*progress) + " "*(20 - int(20*progress)) + "]",
              f"{100*progress:.1f}%", end='\r')
        test_step += 1
        data, target = data.to(device), target.to(device)
        data.requires_grad = True
        output_orig = model(data)
        orig_pred = output_orig.max(1, keepdim=True)[1]
        # Calculate the loss
        loss = F.nll_loss(output_orig, target)
        # Zero all existing gradients
        model.zero_grad()
        # Calculate gradients of model in backward pass
        loss.backward()
        # Collect the gradient of the loss w.r.t. the input image
        data_grad = data.grad.data
        # Restore the data to its original scale
        data_denorm = denorm(data)
        # Apply the FGSM attack
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)
        # Reapply normalization
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)
        # Evaluate the model on the attacked image
        output_unfiltered = model(perturbed_data_normalized)
        # Evaluate each filter strength on the attacked image
        for i in range(TESTED_STRENGTH_COUNT):
            # Apply the filter with the specified strength
            filtered_input = defense_filters.filtered(perturbed_data_normalized,
                                                      batch_size=len(perturbed_data_normalized),
                                                      strength=i, filter=filter)
            # Evaluate the model on the filtered images
            filtered_output = model(filtered_input)
            # Get the predicted classification
            filtered_pred = filtered_output.max(1, keepdim=True)[1]
            # Count up correct classifications
            if filtered_pred.item() == target.item():
                filtered_correct_counts[i] += 1
        # Get the predicted classification
        unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]
        # Count up correct classifications for each case
        if orig_pred.item() == target.item():
            orig_correct += 1
        if unfiltered_pred.item() == target.item():
            unfiltered_correct += 1
    # Calculate the overall accuracy of each case
    orig_acc = orig_correct / float(len(test_loader))
    unfiltered_acc = unfiltered_correct / float(len(test_loader))
    filtered_accuracies = [count / float(len(test_loader)) for count in filtered_correct_counts]
    # Optional per-epsilon diagnostics:
    #print(f"====== EPSILON: {epsilon} ======")
    #print(f"Clean (No Filter) Accuracy = {orig_correct} / {len(test_loader)} = {orig_acc}")
    #print(f"Unfiltered Accuracy = {unfiltered_correct} / {len(test_loader)} = {unfiltered_acc}")
    #for i in range(TESTED_STRENGTH_COUNT):
    #    strength = i + 1
    #    print(f"{filter}({strength}) = {filtered_correct_counts[i]} / {len(test_loader)} = {filtered_accuracies[i]}")
    return unfiltered_acc, filtered_accuracies
accuracies = {}
filters = ("gaussian_blur", "gaussian_kuwahara", "mean_kuwahara", "random_noise",
           "bilateral_filter", "bit_depth", "threshold_filter")
for filter in filters:
    for eps in epsilons:
        unfiltered_accuracy, filtered_accuracy = test(model, device, test_loader, eps, filter)
        # The unfiltered baseline is the same for every filter, so only record
        # it during the first pass through the epsilon sweep
        if "unfiltered" not in accuracies:
            accuracies["unfiltered"] = []
        if len(accuracies["unfiltered"]) < len(epsilons):
            accuracies["unfiltered"].append(unfiltered_accuracy)
        if filter not in accuracies:
            accuracies[filter] = []
        accuracies[filter].append(filtered_accuracy)
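# Results layout: accuracies["unfiltered"] holds one accuracy per epsilon;
# accuracies[<filter>] holds, for each epsilon, a list with one accuracy per
# tested filter strength.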
accuracies_json = json.dumps(accuracies, indent=4)
with open("results/mnist_fgsm.json", "w") as outfile:
    outfile.write(accuracies_json)
# Plot the results
#plt.figure(figsize=(16,9))
#plt.plot(epsilons, unfiltered_accuracies, label="Attacked Accuracy")
#for i in range(TESTED_STRENGTH_COUNT):
#    filtered_accuracy = [filter_eps[i] for filter_eps in filtered_accuracies]
#    plt.plot(epsilons, filtered_accuracy, label=f"Bit Depth = {i + 1}")
#
#plt.legend(loc="upper right")
#plt.title("Bit-Depth Reduction Performance")
#plt.xlabel("Attack Strength ($\\epsilon$)")
#plt.ylabel("Accuracy")
#plt.savefig("Images/BitDepthReducePerformance.png")