Modular filtering tested and working with varying filter strengths
This commit is contained in:
parent
ab4460aeeb
commit
7430ddac94
BIN
Filter_Analysis/__pycache__/defense_filters.cpython-311.pyc
Normal file
BIN
Filter_Analysis/__pycache__/defense_filters.cpython-311.pyc
Normal file
Binary file not shown.
@ -1,6 +1,7 @@
|
||||
import cv2
|
||||
from pykuwahara import kuwahara
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
@ -11,6 +12,8 @@ def pttensor_to_images(data):
|
||||
images = data.numpy().transpose(0,2,3,1)
|
||||
except RuntimeError:
|
||||
images = data.detach().numpy().transpose(0,2,3,1)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def gaussian_kuwahara(data, batch_size=64, radius=5):
|
||||
@ -97,7 +100,7 @@ def threshold_filter(data, batch_size=64, threshold=0.5):
|
||||
return torch.tensor(filtered_images).float()
|
||||
|
||||
|
||||
def snap_colors(data, batch_size=64, quantizations=4)
|
||||
def snap_colors(data, batch_size=64, quantizations=4):
|
||||
images = pttensor_to_images(data)
|
||||
filtered_images = np.ndarray((batch_size,28,28,1))
|
||||
|
||||
|
@ -6,9 +6,13 @@ from torchvision import datasets, transforms
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
from mnist import Net
|
||||
from pykuwahara import kuwahara
|
||||
|
||||
import defense_filters
|
||||
|
||||
|
||||
|
||||
TESTED_STRENGTH_COUNT = 5
|
||||
|
||||
MAX_EPSILON = 0.3
|
||||
EPSILON_STEP = 0.025
|
||||
@ -30,7 +34,6 @@ print("CUDA Available: ", torch.cuda.is_available())
|
||||
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
|
||||
|
||||
model = Net().to(device)
|
||||
print(type(model))
|
||||
|
||||
model.load_state_dict(torch.load(pretrained_model, map_location=device))
|
||||
|
||||
@ -71,19 +74,18 @@ def denorm(batch, mean=[0.1307], std=[0.3081]):
|
||||
def test(model, device, test_loader, epsilon):
|
||||
# Original dataset correct classifications
|
||||
orig_correct = 0
|
||||
|
||||
# Attacked dataset correct classifications
|
||||
unfiltered_correct = 0
|
||||
kuwahara_correct = 0
|
||||
bilateral_correct = 0
|
||||
gaussian_blur_correct = 0
|
||||
random_noise_correct = 0
|
||||
snap_color_correct = 0
|
||||
one_bit_correct = 0
|
||||
plurality_correct = 0
|
||||
|
||||
adv_examples = []
|
||||
|
||||
# Attacked, filtered dataset correct classifications
|
||||
filtered_correct_counts = []
|
||||
|
||||
test_step = 0
|
||||
for data, target in test_loader:
|
||||
print("[" + "="*int(20*test_step/len(test_loader)) + " "*(20 - int(20*test_step/len(test_loader))) + "]", f"{100*test_step/len(test_loader)}%", end='\r')
|
||||
test_step += 1
|
||||
|
||||
data, target = data.to(device), target.to(device)
|
||||
data.requires_grad = True
|
||||
|
||||
@ -111,34 +113,31 @@ def test(model, device, test_loader, epsilon):
|
||||
# Reapply normalization
|
||||
perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)
|
||||
|
||||
# Filter the attacked image
|
||||
kuwahara_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="kuwahara")
|
||||
bilateral_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="bilateral")
|
||||
gaussian_blur_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="gaussian_blur")
|
||||
random_noise_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="noise")
|
||||
snap_color_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="snap_color")
|
||||
one_bit_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="1-bit")
|
||||
|
||||
# evaluate the model on the attacked and filtered images
|
||||
# Evaluate the model on the attacked image
|
||||
output_unfiltered = model(perturbed_data_normalized)
|
||||
output_kuwahara = model(kuwahara_data)
|
||||
output_bilateral = model(bilateral_data)
|
||||
output_gaussian_blur = model(gaussian_blur_data)
|
||||
output_random_noise = model(random_noise_data)
|
||||
output_snap_color = model(snap_color_data)
|
||||
output_one_bit = model(one_bit_data)
|
||||
|
||||
# Get the predicted class from the model for each case
|
||||
# Evaluate performance for
|
||||
for i in range(TESTED_STRENGTH_COUNT):
|
||||
strength = 2*i + 1
|
||||
|
||||
# Apply the filter with the specified strength
|
||||
filtered_input = defense_filters.gaussian_kuwahara(perturbed_data_normalized, batch_size=len(perturbed_data_normalized), radius=strength)
|
||||
# Evaluate the model on the filtered images
|
||||
filtered_output = model(filtered_input)
|
||||
# Get the predicted classification
|
||||
filtered_pred = filtered_output.max(1, keepdim=True)[1]
|
||||
|
||||
# Count up correct classifications
|
||||
if filtered_pred.item() == target.item():
|
||||
if len(filtered_correct_counts) == i:
|
||||
filtered_correct_counts.append(1)
|
||||
else:
|
||||
filtered_correct_counts[i] += 1
|
||||
|
||||
|
||||
# Get the predicted classification
|
||||
unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]
|
||||
kuwahara_pred = output_kuwahara.max(1, keepdim=True)[1]
|
||||
bilateral_pred = output_bilateral.max(1, keepdim=True)[1]
|
||||
gaussian_blur_pred = output_gaussian_blur.max(1, keepdim=True)[1]
|
||||
random_noise_pred = output_random_noise.max(1, keepdim=True)[1]
|
||||
snap_color_pred = output_snap_color.max(1, keepdim=True)[1]
|
||||
one_bit_pred = output_one_bit.max(1, keepdim=True)[1]
|
||||
predictions = [unfiltered_pred.item(), kuwahara_pred.item(), bilateral_pred.item(), gaussian_blur_pred.item(), random_noise_pred.item(), snap_color_pred.item(), one_bit_pred.item()]
|
||||
plurality_pred = stats.mode(predictions, keepdims=True)[0]
|
||||
|
||||
|
||||
# Count up correct classifications for each case
|
||||
if orig_pred.item() == target.item():
|
||||
orig_correct += 1
|
||||
@ -146,85 +145,39 @@ def test(model, device, test_loader, epsilon):
|
||||
if unfiltered_pred.item() == target.item():
|
||||
unfiltered_correct += 1
|
||||
|
||||
if kuwahara_pred.item() == target.item():
|
||||
kuwahara_correct += 1
|
||||
|
||||
if bilateral_pred.item() == target.item():
|
||||
bilateral_correct += 1
|
||||
|
||||
if gaussian_blur_pred.item() == target.item():
|
||||
gaussian_blur_correct += 1
|
||||
|
||||
if random_noise_pred.item() == target.item():
|
||||
random_noise_correct += 1
|
||||
|
||||
if snap_color_pred.item() == target.item():
|
||||
snap_color_correct += 1
|
||||
|
||||
if one_bit_pred.item() == target.item():
|
||||
one_bit_correct += 1
|
||||
|
||||
if plurality_pred == target.item():
|
||||
plurality_correct += 1
|
||||
|
||||
# Calculate the overall accuracy of each case
|
||||
orig_acc = orig_correct/float(len(test_loader))
|
||||
unfiltered_acc = unfiltered_correct/float(len(test_loader))
|
||||
kuwahara_acc = kuwahara_correct/float(len(test_loader))
|
||||
bilateral_acc = bilateral_correct/float(len(test_loader))
|
||||
gaussian_blur_acc = gaussian_blur_correct/float(len(test_loader))
|
||||
random_noise_acc = random_noise_correct/float(len(test_loader))
|
||||
snap_color_acc = snap_color_correct/float(len(test_loader))
|
||||
one_bit_acc = one_bit_correct/float(len(test_loader))
|
||||
plurality_acc = plurality_correct/float(len(test_loader))
|
||||
|
||||
filtered_accuracies = []
|
||||
for correct_count in filtered_correct_counts:
|
||||
filtered_accuracies.append(correct_count/float(len(test_loader)))
|
||||
|
||||
|
||||
print(f"====== EPSILON: {epsilon} ======")
|
||||
print(f"Clean (No Filter) Accuracy = {orig_correct} / {len(test_loader)} = {orig_acc}")
|
||||
print(f"Unfiltered Accuracy = {unfiltered_correct} / {len(test_loader)} = {unfiltered_acc}")
|
||||
print(f"Kuwahara Filter Accuracy = {kuwahara_correct} / {len(test_loader)} = {kuwahara_acc}")
|
||||
print(f"Bilateral Filter Accuracy = {bilateral_correct} / {len(test_loader)} = {bilateral_acc}")
|
||||
print(f"Gaussian Blur Accuracy = {gaussian_blur_correct} / {len(test_loader)} = {gaussian_blur_acc}")
|
||||
print(f"Random Noise Accuracy = {random_noise_correct} / {len(test_loader)} = {random_noise_acc}")
|
||||
print(f"Snapped Color Accuracy = {snap_color_correct} / {len(test_loader)} = {snap_color_acc}")
|
||||
print(f"1 Bit Accuracy = {one_bit_correct} / {len(test_loader)} = {one_bit_acc}")
|
||||
print(f"Plurality Vote Accuracy = {plurality_correct} / {len(test_loader)} = {plurality_acc}")
|
||||
|
||||
return unfiltered_acc, kuwahara_acc, bilateral_acc, gaussian_blur_acc, random_noise_acc, snap_color_acc, one_bit_acc, plurality_acc
|
||||
for i in range(TESTED_STRENGTH_COUNT):
|
||||
strength = 2*i + 1
|
||||
print(f"Gaussian Kuwahara (strength = {strength}) = {filtered_correct_counts[i]} / {len(test_loader)} = {filtered_accuracies[i]}")
|
||||
|
||||
return unfiltered_acc, filtered_accuracies
|
||||
|
||||
|
||||
unfiltered_accuracies = []
|
||||
kuwahara_accuracies = []
|
||||
bilateral_accuracies = []
|
||||
gaussian_blur_accuracies = []
|
||||
random_noise_accuracies = []
|
||||
snap_color_accuracies = []
|
||||
one_bit_accuracies = []
|
||||
pluality_vote_accuracies = []
|
||||
|
||||
print(f"Model: {pretrained_model}")
|
||||
filtered_accuracies = []
|
||||
|
||||
for eps in epsilons:
|
||||
unfiltered_acc, kuwahara_acc, bilateral_acc, gaussian_blur_acc, random_noise_acc, snap_color_acc, one_bit_acc, plurality_acc = test(model, device, test_loader, eps)
|
||||
unfiltered_accuracies.append(unfiltered_acc)
|
||||
kuwahara_accuracies.append(kuwahara_acc)
|
||||
bilateral_accuracies.append(bilateral_acc)
|
||||
gaussian_blur_accuracies.append(gaussian_blur_acc)
|
||||
random_noise_accuracies.append(random_noise_acc)
|
||||
snap_color_accuracies.append(snap_color_acc)
|
||||
one_bit_accuracies.append(one_bit_acc)
|
||||
pluality_vote_accuracies.append(plurality_acc)
|
||||
unfiltered_accuracy, filtered_accuracy = test(model, device, test_loader, eps)
|
||||
unfiltered_accuracies.append(unfiltered_accuracy)
|
||||
filtered_accuracies.append(filtered_accuracy)
|
||||
|
||||
# Plot the results
|
||||
plt.plot(epsilons, unfiltered_accuracies, label="Attacked Accuracy")
|
||||
plt.plot(epsilons, kuwahara_accuracies, label="Kuwahara Accuracy")
|
||||
plt.plot(epsilons, bilateral_accuracies, label="Bilateral Accuracy")
|
||||
plt.plot(epsilons, gaussian_blur_accuracies, label="Gaussian Blur Accuracy")
|
||||
plt.plot(epsilons, random_noise_accuracies, label="Random Noise Accuracy")
|
||||
plt.plot(epsilons, snap_color_accuracies, label="Snapped Color Accuracy")
|
||||
plt.plot(epsilons, one_bit_accuracies, label="1-Bit Accuracy")
|
||||
plt.plot(epsilons, pluality_vote_accuracies, label="Plurality Vote Accuracy")
|
||||
for i in range(TESTED_STRENGTH_COUNT):
|
||||
plt.plot(epsilons, filtered_accuracies[i], label=f"Gaussian Kuwahara (strength = {2*i + 1})")
|
||||
|
||||
plt.legend()
|
||||
|
||||
plt.show()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user