Adversarial-Machine-Learnin.../Filter_Analysis/fgsm.py
2024-04-10 12:23:55 -04:00

294 lines
12 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import cv2
from mnist import Net
from pykuwahara import kuwahara
# Attack strengths to sweep. The arange stop is one step past MAX_EPSILON so
# that the largest epsilon is meant to be part of the sweep.
MAX_EPSILON = 0.3
EPSILON_STEP = 0.025
epsilons = np.arange(0.0, MAX_EPSILON + EPSILON_STEP, EPSILON_STEP)

# Checkpoint to evaluate and whether a GPU may be used when one is present.
pretrained_model = "mnist_cnn_unfiltered.pt"
use_cuda = False

# Fixed seed so repeated runs draw the same random numbers from torch.
torch.manual_seed(69)

# MNIST test split, normalised with the constants that denorm() undoes.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_test_set = datasets.MNIST(
    'data/', train=False, download=True, transform=mnist_transform)
# batch_size=1: test() reads each prediction and label with .item().
test_loader = torch.utils.data.DataLoader(
    mnist_test_set, batch_size=1, shuffle=True)

print("CUDA Available: ", torch.cuda.is_available())
cuda_selected = use_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda_selected else "cpu")

# Build the network, report its type, load the saved weights and switch to
# evaluation mode.
model = Net().to(device)
print(type(model))
state_dict = torch.load(pretrained_model, map_location=device)
model.load_state_dict(state_dict)
model.eval()
def fgsm_attack(image, epsilon, data_grad):
    """
    Build an FGSM adversarial example from an image and its loss gradient.

    Args:
        image (torch.Tensor): Image to perturb.
        epsilon (float): Size of the step applied to every element.
        data_grad (torch.Tensor): Gradient of the loss with respect to the
            image; only its element-wise sign is used.
    Returns:
        torch.Tensor: ``image + epsilon * sign(data_grad)`` clamped to [0, 1].
    """
    # Every element moves by exactly epsilon (or stays put where the
    # gradient is zero); the gradient magnitude is ignored.
    step = epsilon * data_grad.sign()
    # Clamp so the result stays inside the [0, 1] range.
    return torch.clamp(image + step, 0, 1)
def denorm(batch, mean=(0.1307,), std=(0.3081,)):
    """
    Convert a batch of tensors to their original scale.

    Computes ``batch * std + mean`` with the statistics broadcast along
    dimension 1 of ``batch``.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        mean (torch.Tensor, list or tuple): Mean used for normalization.
        std (torch.Tensor, list or tuple): Standard deviation used for
            normalization.
    Returns:
        torch.Tensor: batch of tensors without normalization applied to them.
    """
    def _per_channel(stat):
        # Accept any sequence, not only lists, and follow the batch's own
        # device instead of relying on a module-level ``device`` variable.
        if not isinstance(stat, torch.Tensor):
            stat = torch.tensor(stat)
        return stat.to(batch.device).view(1, -1, 1, 1)

    return batch * _per_channel(std) + _per_channel(mean)
def test(model, device, test_loader, epsilon):
    """
    Attack each sample with FGSM and score the model with and without filters.

    For every (data, target) pair the clean prediction is recorded, an FGSM
    adversarial example is built from the loss gradient, and the model is
    re-evaluated on the attacked image both as-is and after each filter
    offered by ``filtered``. A plurality vote over those seven attacked
    predictions is scored as well. A per-case report is printed.

    Args:
        model (torch.nn.Module): Classifier under attack.
        device (torch.device): Device the samples are moved to.
        test_loader: Iterable of (data, target) pairs. Predictions and labels
            are read with ``.item()``, so each batch must hold one sample.
        epsilon (float): FGSM step size.
    Returns:
        tuple: Accuracies (correct count / ``len(test_loader)``) in the order
        unfiltered, kuwahara, bilateral, gaussian blur, random noise,
        snapped color, 1-bit, plurality vote. The clean accuracy is printed
        but not returned.
    """
    # (name understood by filtered(), label used in the printed report),
    # listed in report/return order.
    filter_specs = (
        ("kuwahara", "Kuwahara Filter"),
        ("bilateral", "Bilateral Filter"),
        ("gaussian_blur", "Gaussian Blur"),
        ("noise", "Random Noise"),
        ("snap_color", "Snapped Color"),
        ("1-bit", "1 Bit"),
    )
    normalize = transforms.Normalize((0.1307,), (0.3081,))

    # Correct-classification counters.
    orig_correct = 0
    unfiltered_correct = 0
    plurality_correct = 0
    filter_correct = {name: 0 for name, _ in filter_specs}

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # The attack needs the gradient of the loss w.r.t. the input.
        data.requires_grad = True

        output_orig = model(data)
        orig_pred = output_orig.max(1, keepdim=True)[1].item()

        # Backpropagate the loss to obtain the input gradient.
        loss = F.nll_loss(output_orig, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.data

        # The attack clamps to [0, 1], so it is applied on the de-normalized
        # image and the normalization is re-applied afterwards.
        data_denorm = denorm(data)
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)
        # detach(): the evaluation passes below need no autograd history.
        perturbed_data_normalized = normalize(perturbed_data).detach()

        with torch.no_grad():
            unfiltered_pred = model(perturbed_data_normalized).max(
                1, keepdim=True)[1].item()
            filter_preds = {}
            for name, _ in filter_specs:
                filtered_data = filtered(
                    perturbed_data_normalized,
                    len(perturbed_data_normalized),
                    filter=name,
                ).to(device)  # keep the filter output on the model's device
                filter_preds[name] = model(filtered_data).max(
                    1, keepdim=True)[1].item()

        # Plurality vote over the unfiltered and all filtered predictions.
        predictions = [unfiltered_pred]
        predictions += [filter_preds[name] for name, _ in filter_specs]
        plurality_pred = stats.mode(predictions, keepdims=True)[0][0]

        # Count up correct classifications for each case.
        label = target.item()
        orig_correct += int(orig_pred == label)
        unfiltered_correct += int(unfiltered_pred == label)
        for name, _ in filter_specs:
            filter_correct[name] += int(filter_preds[name] == label)
        plurality_correct += int(plurality_pred == label)

    # Calculate the overall accuracy of each case.
    total = len(test_loader)
    report = [
        ("Clean (No Filter)", orig_correct),
        ("Unfiltered", unfiltered_correct),
    ]
    report += [(label, filter_correct[name]) for name, label in filter_specs]
    report.append(("Plurality Vote", plurality_correct))
    accuracies = [correct / float(total) for _, correct in report]

    print(f"====== EPSILON: {epsilon} ======")
    for (label, correct), accuracy in zip(report, accuracies):
        print(f"{label} Accuracy = {correct} / {total} = {accuracy}")

    # Everything except the clean accuracy is handed back to the caller.
    return tuple(accuracies[1:])
def _quantize_levels(image, num_colors=2):
    """Snap one image onto steps of size (range / num_colors) of its own range."""
    values = np.asarray(image, dtype=np.float64)
    # If the image contains negative values, shift it so the lowest one
    # becomes 0 (black); the shift is undone before returning.
    offset = min(float(values.min()), 0.0)
    shifted = values - offset
    # If the shifted range extends beyond [0, 1], steps are spaced over
    # [0, max]; otherwise over [0, 1].
    max_value = float(shifted.max())
    span = max_value if max_value > 1 else 1.0
    # Truncate to a whole step index. A pixel exactly at the top of the span
    # lands on step ``num_colors`` and therefore keeps its value.
    levels = (shifted * (num_colors / span)).astype(int).astype(float)
    return levels * (span / num_colors) + offset


def filtered(data, batch_size=64, filter="kuwahara"):
    """
    Apply an image filter to the first ``batch_size`` images of a batch.

    Args:
        data (torch.Tensor): Batch indexed (image, channel, row, column).
            The cv2-based filters reshape their result to one image's shape,
            which presumes a single channel.
        batch_size (int): Number of images to process; also the leading
            dimension of the result.
        filter (str): One of "kuwahara", "aniso_diff", "noise",
            "gaussian_blur", "bilateral", "1-bit" or "snap_color".
    Returns:
        torch.Tensor: float32 tensor of filtered images in the same layout
        as ``data``, placed on ``data``'s device.
    Raises:
        ValueError: If ``filter`` is not one of the supported names.
    """
    # Turn the tensor into channel-last images. detach() drops autograd
    # tracking and cpu() makes the NumPy conversion valid for GPU tensors.
    images = data.detach().cpu().numpy().transpose(0, 2, 3, 1)
    # Zero-initialised and shaped like the input images, so no uninitialised
    # memory can reach the result.
    filtered_images = np.zeros((batch_size,) + images.shape[1:])
    if filter == "kuwahara":
        for i in range(batch_size):
            filtered_images[i] = kuwahara(images[i], method='gaussian', radius=5, image_2d=images[i])
    elif filter == "aniso_diff":
        # NOTE(review): this branch could not run before (invalid np.zeros
        # shape, undefined ``img2``). The 8-bit round trip below assumes
        # cv2.ximgproc.anisotropicDiffusion only accepts 8-bit 3-channel
        # input - confirm against the installed opencv-contrib build.
        for i in range(batch_size):
            gray = images[i][:, :, 0]
            low = float(gray.min())
            span = (float(gray.max()) - low) or 1.0
            gray_u8 = np.round((gray - low) / span * 255).astype(np.uint8)
            img_3ch = np.stack([gray_u8] * 3, axis=-1)
            img_3ch_filtered = cv2.ximgproc.anisotropicDiffusion(img_3ch, alpha=0.2, K=0.5, niters=5)
            gray_filtered = cv2.cvtColor(img_3ch_filtered, cv2.COLOR_RGB2GRAY)
            # Map the 8-bit result back onto the original value range.
            filtered_images[i, :, :, 0] = gray_filtered.astype(np.float64) / 255 * span + low
    elif filter == "noise":
        mean = 0
        stddev = 180
        for i in range(batch_size):
            noise = np.zeros(images[i].shape, images[i].dtype)
            cv2.randn(noise, mean, stddev)
            # The noise is blended in with a weight of 0.001.
            filtered_images[i] = cv2.addWeighted(images[i], 1.0, noise, 0.001, 0.0).reshape(filtered_images[i].shape)
    elif filter == "gaussian_blur":
        for i in range(batch_size):
            filtered_images[i] = cv2.GaussianBlur(images[i], ksize=(5,5), sigmaX=0).reshape(filtered_images[i].shape)
    elif filter == "bilateral":
        for i in range(batch_size):
            filtered_images[i] = cv2.bilateralFilter(images[i], 5, 50, 50).reshape(filtered_images[i].shape)
    elif filter == "1-bit":
        num_colors = 2
        for i in range(batch_size):
            filtered_images[i] = _quantize_levels(images[i], num_colors)
    elif filter == "snap_color":
        for i in range(batch_size):
            # Truncate every value toward zero onto a grid of quarters.
            filtered_images[i] = (images[i]*4).astype(int).astype(float)/4
    else:
        raise ValueError(f"Unknown filter: {filter!r}")
    # Back to channel-first layout for the model.
    filtered_images = filtered_images.transpose(0, 3, 1, 2)
    return torch.tensor(filtered_images).float().to(data.device)
# One accuracy series per evaluated case, filled with one value per epsilon.
unfiltered_accuracies = []
kuwahara_accuracies = []
bilateral_accuracies = []
gaussian_blur_accuracies = []
random_noise_accuracies = []
snap_color_accuracies = []
one_bit_accuracies = []
pluality_vote_accuracies = []

# (series, legend label) pairs, in the same order as the tuple test() returns.
accuracy_series = (
    (unfiltered_accuracies, "Attacked Accuracy"),
    (kuwahara_accuracies, "Kuwahara Accuracy"),
    (bilateral_accuracies, "Bilateral Accuracy"),
    (gaussian_blur_accuracies, "Gaussian Blur Accuracy"),
    (random_noise_accuracies, "Random Noise Accuracy"),
    (snap_color_accuracies, "Snapped Color Accuracy"),
    (one_bit_accuracies, "1-Bit Accuracy"),
    (pluality_vote_accuracies, "Plurality Vote Accuracy"),
)

print(f"Model: {pretrained_model}")

# Run the full evaluation once per attack strength.
for eps in epsilons:
    results = test(model, device, test_loader, eps)
    for (series, _), accuracy in zip(accuracy_series, results):
        series.append(accuracy)

# Plot the results
for series, label in accuracy_series:
    plt.plot(epsilons, series, label=label)
plt.legend()
plt.show()