import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms import numpy as np from scipy import stats import matplotlib.pyplot as plt import cv2 from mnist import Net from pykuwahara import kuwahara MAX_EPSILON = 0.3 EPSILON_STEP = 0.025 epsilons = np.arange(0.0, MAX_EPSILON+EPSILON_STEP, EPSILON_STEP) pretrained_model = "mnist_cnn_unfiltered.pt" use_cuda=False torch.manual_seed(69) test_loader = torch.utils.data.DataLoader( datasets.MNIST('data/', train=False, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ])), batch_size=1, shuffle=True) print("CUDA Available: ", torch.cuda.is_available()) device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu") model = Net().to(device) print(type(model)) model.load_state_dict(torch.load(pretrained_model, map_location=device)) model.eval() def fgsm_attack(image, epsilon, data_grad): # Collect the element-wise sign of the data gradient sign_data_grad = data_grad.sign() # Create the perturbed image by adjusting each pixel of the input image perturbed_image = image + epsilon*sign_data_grad # Adding clipping to maintain [0, 1] range perturbed_image = torch.clamp(perturbed_image, 0, 1) return perturbed_image def denorm(batch, mean=[0.1307], std=[0.3081]): """ Convert a batch of tensors to their original scale. Args: batch (torch.Tensor): Batch of normalized tensors. mean (torch.Tensor or list): Man used for normalization. std (torch.Tensor or list): Standard deviation used for normalization. Returns: torch.Tensor: batch of tensors without normalization applied to them. """ if isinstance(mean, list): mean = torch.tensor(mean).to(device) if isinstance(std, list): std = torch.tensor(std).to(device) return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1) def test(model, device, test_loader, epsilon): # Original dataset correct classifications orig_correct = 0 # Attacked dataset correct classifications unfiltered_correct = 0 kuwahara_correct = 0 bilateral_correct = 0 gaussian_blur_correct = 0 random_noise_correct = 0 snap_color_correct = 0 one_bit_correct = 0 plurality_correct = 0 adv_examples = [] for data, target in test_loader: data, target = data.to(device), target.to(device) data.requires_grad = True output_orig = model(data) orig_pred = output_orig.max(1, keepdim=True)[1] # Calculate the loss loss = F.nll_loss(output_orig, target) # Zero all existing gradients model.zero_grad() # Calculate gradients of model in backward pass loss.backward() # Collect ''datagrad'' data_grad = data.grad.data # Restore the data to its original scale data_denorm = denorm(data) # Apply the FGSM attack perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad) # Reapply normalization perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data) # Filter the attacked image kuwahara_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="kuwahara") bilateral_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="bilateral") gaussian_blur_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="gaussian_blur") random_noise_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="noise") snap_color_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="snap_color") one_bit_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="1-bit") # evaluate the model on the attacked and filtered images output_unfiltered = model(perturbed_data_normalized) output_kuwahara = model(kuwahara_data) output_bilateral = model(bilateral_data) output_gaussian_blur = model(gaussian_blur_data) output_random_noise = model(random_noise_data) output_snap_color = model(snap_color_data) output_one_bit = model(one_bit_data) # Get the predicted class from the model for each case unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1] kuwahara_pred = output_kuwahara.max(1, keepdim=True)[1] bilateral_pred = output_bilateral.max(1, keepdim=True)[1] gaussian_blur_pred = output_gaussian_blur.max(1, keepdim=True)[1] random_noise_pred = output_random_noise.max(1, keepdim=True)[1] snap_color_pred = output_snap_color.max(1, keepdim=True)[1] one_bit_pred = output_one_bit.max(1, keepdim=True)[1] predictions = [unfiltered_pred.item(), kuwahara_pred.item(), bilateral_pred.item(), gaussian_blur_pred.item(), random_noise_pred.item(), snap_color_pred.item(), one_bit_pred.item()] plurality_pred = stats.mode(predictions, keepdims=True)[0] # Count up correct classifications for each case if orig_pred.item() == target.item(): orig_correct += 1 if unfiltered_pred.item() == target.item(): unfiltered_correct += 1 if kuwahara_pred.item() == target.item(): kuwahara_correct += 1 if bilateral_pred.item() == target.item(): bilateral_correct += 1 if gaussian_blur_pred.item() == target.item(): gaussian_blur_correct += 1 if random_noise_pred.item() == target.item(): random_noise_correct += 1 if snap_color_pred.item() == target.item(): snap_color_correct += 1 if one_bit_pred.item() == target.item(): one_bit_correct += 1 if plurality_pred == target.item(): plurality_correct += 1 # Calculate the overall accuracy of each case orig_acc = orig_correct/float(len(test_loader)) unfiltered_acc = unfiltered_correct/float(len(test_loader)) kuwahara_acc = kuwahara_correct/float(len(test_loader)) bilateral_acc = bilateral_correct/float(len(test_loader)) gaussian_blur_acc = gaussian_blur_correct/float(len(test_loader)) random_noise_acc = random_noise_correct/float(len(test_loader)) snap_color_acc = snap_color_correct/float(len(test_loader)) one_bit_acc = one_bit_correct/float(len(test_loader)) plurality_acc = plurality_correct/float(len(test_loader)) print(f"====== EPSILON: {epsilon} ======") print(f"Clean (No Filter) Accuracy = {orig_correct} / {len(test_loader)} = {orig_acc}") print(f"Unfiltered Accuracy = {unfiltered_correct} / {len(test_loader)} = {unfiltered_acc}") print(f"Kuwahara Filter Accuracy = {kuwahara_correct} / {len(test_loader)} = {kuwahara_acc}") print(f"Bilateral Filter Accuracy = {bilateral_correct} / {len(test_loader)} = {bilateral_acc}") print(f"Gaussian Blur Accuracy = {gaussian_blur_correct} / {len(test_loader)} = {gaussian_blur_acc}") print(f"Random Noise Accuracy = {random_noise_correct} / {len(test_loader)} = {random_noise_acc}") print(f"Snapped Color Accuracy = {snap_color_correct} / {len(test_loader)} = {snap_color_acc}") print(f"1 Bit Accuracy = {one_bit_correct} / {len(test_loader)} = {one_bit_acc}") print(f"Plurality Vote Accuracy = {plurality_correct} / {len(test_loader)} = {plurality_acc}") return unfiltered_acc, kuwahara_acc, bilateral_acc, gaussian_blur_acc, random_noise_acc, snap_color_acc, one_bit_acc, plurality_acc unfiltered_accuracies = [] kuwahara_accuracies = [] bilateral_accuracies = [] gaussian_blur_accuracies = [] random_noise_accuracies = [] snap_color_accuracies = [] one_bit_accuracies = [] pluality_vote_accuracies = [] print(f"Model: {pretrained_model}") for eps in epsilons: unfiltered_acc, kuwahara_acc, bilateral_acc, gaussian_blur_acc, random_noise_acc, snap_color_acc, one_bit_acc, plurality_acc = test(model, device, test_loader, eps) unfiltered_accuracies.append(unfiltered_acc) kuwahara_accuracies.append(kuwahara_acc) bilateral_accuracies.append(bilateral_acc) gaussian_blur_accuracies.append(gaussian_blur_acc) random_noise_accuracies.append(random_noise_acc) snap_color_accuracies.append(snap_color_acc) one_bit_accuracies.append(one_bit_acc) pluality_vote_accuracies.append(plurality_acc) # Plot the results plt.plot(epsilons, unfiltered_accuracies, label="Attacked Accuracy") plt.plot(epsilons, kuwahara_accuracies, label="Kuwahara Accuracy") plt.plot(epsilons, bilateral_accuracies, label="Bilateral Accuracy") plt.plot(epsilons, gaussian_blur_accuracies, label="Gaussian Blur Accuracy") plt.plot(epsilons, random_noise_accuracies, label="Random Noise Accuracy") plt.plot(epsilons, snap_color_accuracies, label="Snapped Color Accuracy") plt.plot(epsilons, one_bit_accuracies, label="1-Bit Accuracy") plt.plot(epsilons, pluality_vote_accuracies, label="Plurality Vote Accuracy") plt.legend() plt.show()