import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from mnist import Net
import json
import sys

import defense_filters

TESTED_STRENGTH_COUNT = 5
MAX_EPSILON = 0.3
EPSILON_STEP = 0.025
epsilons = np.arange(0.0, MAX_EPSILON + EPSILON_STEP, EPSILON_STEP)
pretrained_model = "mnist_cnn_unfiltered.pt"
use_cuda = False
torch.manual_seed(69)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data/', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=1, shuffle=True)

print("CUDA Available: ", torch.cuda.is_available())
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

model = Net().to(device)
model.load_state_dict(torch.load(pretrained_model, map_location=device))
model.eval()


def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon * sign_data_grad
    # Add clipping to maintain the [0, 1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image


def denorm(batch, mean=[0.1307], std=[0.3081]):
    """
    Convert a batch of tensors to their original scale.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        mean (torch.Tensor or list): Mean used for normalization.
        std (torch.Tensor or list): Standard deviation used for normalization.

    Returns:
        torch.Tensor: Batch of tensors without normalization applied to them.
    """
    if isinstance(mean, list):
        mean = torch.tensor(mean).to(device)
    if isinstance(std, list):
        std = torch.tensor(std).to(device)
    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)


def test(model, device, test_loader, epsilon, filter):
    # Original dataset correct classifications
    orig_correct = 0
    # Attacked dataset correct classifications
    unfiltered_correct = 0
    # Attacked, filtered dataset correct classifications, one count per strength
    filtered_correct_counts = [0] * TESTED_STRENGTH_COUNT

    test_step = 0
    for data, target in test_loader:
        # Render a single-line console progress bar
        sys.stdout.write("\033[K")
        print(filter, f"Epsilon: {epsilon}",
              "[" + "=" * int(1 + 20 * test_step / len(test_loader))
              + " " * (20 - int(20 * test_step / len(test_loader))) + "]",
              f"{100 * test_step / len(test_loader):.1f}%", end='\r')
        test_step += 1

        data, target = data.to(device), target.to(device)
        data.requires_grad = True

        output_orig = model(data)
        orig_pred = output_orig.max(1, keepdim=True)[1]

        # Calculate the loss
        loss = F.nll_loss(output_orig, target)
        # Zero all existing gradients
        model.zero_grad()
        # Calculate gradients of model in backward pass
        loss.backward()
        # Collect the gradient of the input data
        data_grad = data.grad.data
        # Restore the data to its original scale
        data_denorm = denorm(data)
        # Apply the FGSM attack
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)
        # Reapply normalization
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)

        # Evaluate the model on the attacked image
        output_unfiltered = model(perturbed_data_normalized)

        # Evaluate performance at each tested filter strength
        for i in range(TESTED_STRENGTH_COUNT):
            # Apply the filter with the specified strength
            filtered_input = defense_filters.filtered(perturbed_data_normalized,
                                                      batch_size=len(perturbed_data_normalized),
                                                      strength=i, filter=filter)
            # Evaluate the model on the filtered images
            filtered_output = model(filtered_input)
            # Get the predicted classification
            filtered_pred = filtered_output.max(1, keepdim=True)[1]
            # Count up correct classifications
            if filtered_pred.item() == target.item():
                filtered_correct_counts[i] += 1

        # Get the predicted classification
        unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]
        # Count up correct classifications for each case
        if orig_pred.item() == target.item():
            orig_correct += 1
        if unfiltered_pred.item() == target.item():
            unfiltered_correct += 1

    # Calculate the overall accuracy of each case
    orig_acc = orig_correct / float(len(test_loader))
    unfiltered_acc = unfiltered_correct / float(len(test_loader))
    filtered_accuracies = []
    for correct_count in filtered_correct_counts:
        filtered_accuracies.append(correct_count / float(len(test_loader)))

    #print(f"====== EPSILON: {epsilon} ======")
    #print(f"Clean (No Filter) Accuracy = {orig_correct} / {len(test_loader)} = {orig_acc}")
    #print(f"Unfiltered Accuracy = {unfiltered_correct} / {len(test_loader)} = {unfiltered_acc}")
    #for i in range(TESTED_STRENGTH_COUNT):
    #    strength = i + 1
    #    print(f"{filter}({strength}) = {filtered_correct_counts[i]} / {len(test_loader)} = {filtered_accuracies[i]}")

    return unfiltered_acc, filtered_accuracies


accuracies = {}
filters = ("gaussian_blur", "gaussian_kuwahara", "mean_kuwahara", "random_noise",
           "bilateral_filter", "bit_depth", "threshold_filter")

for filter in filters:
    for eps in epsilons:
        unfiltered_accuracy, filtered_accuracy = test(model, device, test_loader, eps, filter)
        # The unfiltered baseline is the same for every filter, so record it
        # only on the first pass through the epsilons
        if "unfiltered" not in accuracies:
            accuracies["unfiltered"] = []
        if len(accuracies["unfiltered"]) < len(epsilons):
            accuracies["unfiltered"].append(unfiltered_accuracy)
        if filter not in accuracies:
            accuracies[filter] = []
        accuracies[filter].append(filtered_accuracy)

accuracies_json = json.dumps(accuracies, indent=4)
with open("results/mnist_fgsm.json", "w") as outfile:
    outfile.write(accuracies_json)

# Plot the results
#plt.figure(figsize=(16,9))
#plt.plot(epsilons, accuracies["unfiltered"], label="Attacked Accuracy")
#for i in range(TESTED_STRENGTH_COUNT):
#    strength_accuracy = [per_eps[i] for per_eps in accuracies["bit_depth"]]
#    plt.plot(epsilons, strength_accuracy, label=f"Bit Depth = {i + 1}")
#
#plt.legend(loc="upper right")
#plt.title("Bit-Depth Reduction Performance")
#plt.xlabel("Attack Strength ($\\epsilon$)")
#plt.ylabel("Accuracy")
#plt.savefig("Images/BitDepthReducePerformance.png")
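
# A minimal follow-up plotting sketch (an addition, not part of the original
# runs): compares every filter against the unfiltered baseline at a single
# strength index, using the `accuracies` dict built above. The strength index,
# plot title, and output path are illustrative choices, so this is left
# commented out like the plotting draft above.
#COMPARE_STRENGTH = 2  # illustrative index in [0, TESTED_STRENGTH_COUNT)
#plt.figure(figsize=(16,9))
#plt.plot(epsilons, accuracies["unfiltered"], label="No Filter")
#for f in filters:
#    plt.plot(epsilons, [per_eps[COMPARE_STRENGTH] for per_eps in accuracies[f]], label=f)
#plt.legend(loc="upper right")
#plt.title("Defense Filter Comparison")
#plt.xlabel("Attack Strength ($\\epsilon$)")
#plt.ylabel("Accuracy")
#plt.savefig("Images/FilterComparison.png")  # illustrative output path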