Modular filtering tested and working with varying filter strengths

This commit is contained in:
Aidan Sharpe 2024-04-11 14:24:48 -04:00
parent ab4460aeeb
commit 7430ddac94
3 changed files with 56 additions and 100 deletions

View File

@ -1,6 +1,7 @@
import cv2 import cv2
from pykuwahara import kuwahara from pykuwahara import kuwahara
import numpy as np import numpy as np
import torch
@ -11,6 +12,8 @@ def pttensor_to_images(data):
images = data.numpy().transpose(0,2,3,1) images = data.numpy().transpose(0,2,3,1)
except RuntimeError: except RuntimeError:
images = data.detach().numpy().transpose(0,2,3,1) images = data.detach().numpy().transpose(0,2,3,1)
return images
def gaussian_kuwahara(data, batch_size=64, radius=5): def gaussian_kuwahara(data, batch_size=64, radius=5):
@ -97,7 +100,7 @@ def threshold_filter(data, batch_size=64, threshold=0.5):
return torch.tensor(filtered_images).float() return torch.tensor(filtered_images).float()
def snap_colors(data, batch_size=64, quantizations=4) def snap_colors(data, batch_size=64, quantizations=4):
images = pttensor_to_images(data) images = pttensor_to_images(data)
filtered_images = np.ndarray((batch_size,28,28,1)) filtered_images = np.ndarray((batch_size,28,28,1))

View File

@ -6,9 +6,13 @@ from torchvision import datasets, transforms
import numpy as np import numpy as np
from scipy import stats from scipy import stats
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import cv2
from mnist import Net from mnist import Net
from pykuwahara import kuwahara
import defense_filters
TESTED_STRENGTH_COUNT = 5
MAX_EPSILON = 0.3 MAX_EPSILON = 0.3
EPSILON_STEP = 0.025 EPSILON_STEP = 0.025
@ -30,7 +34,6 @@ print("CUDA Available: ", torch.cuda.is_available())
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu") device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
model = Net().to(device) model = Net().to(device)
print(type(model))
model.load_state_dict(torch.load(pretrained_model, map_location=device)) model.load_state_dict(torch.load(pretrained_model, map_location=device))
@ -71,19 +74,18 @@ def denorm(batch, mean=[0.1307], std=[0.3081]):
def test(model, device, test_loader, epsilon): def test(model, device, test_loader, epsilon):
# Original dataset correct classifications # Original dataset correct classifications
orig_correct = 0 orig_correct = 0
# Attacked dataset correct classifications # Attacked dataset correct classifications
unfiltered_correct = 0 unfiltered_correct = 0
kuwahara_correct = 0
bilateral_correct = 0
gaussian_blur_correct = 0
random_noise_correct = 0
snap_color_correct = 0
one_bit_correct = 0
plurality_correct = 0
adv_examples = []
# Attacked, filtered dataset correct classifications
filtered_correct_counts = []
test_step = 0
for data, target in test_loader: for data, target in test_loader:
print("[" + "="*int(20*test_step/len(test_loader)) + " "*(20 - int(20*test_step/len(test_loader))) + "]", f"{100*test_step/len(test_loader)}%", end='\r')
test_step += 1
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
data.requires_grad = True data.requires_grad = True
@ -111,34 +113,31 @@ def test(model, device, test_loader, epsilon):
# Reapply normalization # Reapply normalization
perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data) perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)
# Filter the attacked image # Evaluate the model on the attacked image
kuwahara_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="kuwahara")
bilateral_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="bilateral")
gaussian_blur_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="gaussian_blur")
random_noise_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="noise")
snap_color_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="snap_color")
one_bit_data = filtered(perturbed_data_normalized, len(perturbed_data_normalized), filter="1-bit")
# evaluate the model on the attacked and filtered images
output_unfiltered = model(perturbed_data_normalized) output_unfiltered = model(perturbed_data_normalized)
output_kuwahara = model(kuwahara_data)
output_bilateral = model(bilateral_data)
output_gaussian_blur = model(gaussian_blur_data)
output_random_noise = model(random_noise_data)
output_snap_color = model(snap_color_data)
output_one_bit = model(one_bit_data)
# Get the predicted class from the model for each case # Evaluate performance for each tested filter strength
for i in range(TESTED_STRENGTH_COUNT):
strength = 2*i + 1
# Apply the filter with the specified strength
filtered_input = defense_filters.gaussian_kuwahara(perturbed_data_normalized, batch_size=len(perturbed_data_normalized), radius=strength)
# Evaluate the model on the filtered images
filtered_output = model(filtered_input)
# Get the predicted classification
filtered_pred = filtered_output.max(1, keepdim=True)[1]
# Count up correct classifications
if filtered_pred.item() == target.item():
if len(filtered_correct_counts) == i:
filtered_correct_counts.append(1)
else:
filtered_correct_counts[i] += 1
# Get the predicted classification
unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1] unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]
kuwahara_pred = output_kuwahara.max(1, keepdim=True)[1]
bilateral_pred = output_bilateral.max(1, keepdim=True)[1]
gaussian_blur_pred = output_gaussian_blur.max(1, keepdim=True)[1]
random_noise_pred = output_random_noise.max(1, keepdim=True)[1]
snap_color_pred = output_snap_color.max(1, keepdim=True)[1]
one_bit_pred = output_one_bit.max(1, keepdim=True)[1]
predictions = [unfiltered_pred.item(), kuwahara_pred.item(), bilateral_pred.item(), gaussian_blur_pred.item(), random_noise_pred.item(), snap_color_pred.item(), one_bit_pred.item()]
plurality_pred = stats.mode(predictions, keepdims=True)[0]
# Count up correct classifications for each case # Count up correct classifications for each case
if orig_pred.item() == target.item(): if orig_pred.item() == target.item():
orig_correct += 1 orig_correct += 1
@ -146,85 +145,39 @@ def test(model, device, test_loader, epsilon):
if unfiltered_pred.item() == target.item(): if unfiltered_pred.item() == target.item():
unfiltered_correct += 1 unfiltered_correct += 1
if kuwahara_pred.item() == target.item():
kuwahara_correct += 1
if bilateral_pred.item() == target.item():
bilateral_correct += 1
if gaussian_blur_pred.item() == target.item():
gaussian_blur_correct += 1
if random_noise_pred.item() == target.item():
random_noise_correct += 1
if snap_color_pred.item() == target.item():
snap_color_correct += 1
if one_bit_pred.item() == target.item():
one_bit_correct += 1
if plurality_pred == target.item():
plurality_correct += 1
# Calculate the overall accuracy of each case # Calculate the overall accuracy of each case
orig_acc = orig_correct/float(len(test_loader)) orig_acc = orig_correct/float(len(test_loader))
unfiltered_acc = unfiltered_correct/float(len(test_loader)) unfiltered_acc = unfiltered_correct/float(len(test_loader))
kuwahara_acc = kuwahara_correct/float(len(test_loader))
bilateral_acc = bilateral_correct/float(len(test_loader)) filtered_accuracies = []
gaussian_blur_acc = gaussian_blur_correct/float(len(test_loader)) for correct_count in filtered_correct_counts:
random_noise_acc = random_noise_correct/float(len(test_loader)) filtered_accuracies.append(correct_count/float(len(test_loader)))
snap_color_acc = snap_color_correct/float(len(test_loader))
one_bit_acc = one_bit_correct/float(len(test_loader))
plurality_acc = plurality_correct/float(len(test_loader))
print(f"====== EPSILON: {epsilon} ======") print(f"====== EPSILON: {epsilon} ======")
print(f"Clean (No Filter) Accuracy = {orig_correct} / {len(test_loader)} = {orig_acc}") print(f"Clean (No Filter) Accuracy = {orig_correct} / {len(test_loader)} = {orig_acc}")
print(f"Unfiltered Accuracy = {unfiltered_correct} / {len(test_loader)} = {unfiltered_acc}") print(f"Unfiltered Accuracy = {unfiltered_correct} / {len(test_loader)} = {unfiltered_acc}")
print(f"Kuwahara Filter Accuracy = {kuwahara_correct} / {len(test_loader)} = {kuwahara_acc}") for i in range(TESTED_STRENGTH_COUNT):
print(f"Bilateral Filter Accuracy = {bilateral_correct} / {len(test_loader)} = {bilateral_acc}") strength = 2*i + 1
print(f"Gaussian Blur Accuracy = {gaussian_blur_correct} / {len(test_loader)} = {gaussian_blur_acc}") print(f"Gaussian Kuwahara (strength = {strength}) = {filtered_correct_counts[i]} / {len(test_loader)} = {filtered_accuracies[i]}")
print(f"Random Noise Accuracy = {random_noise_correct} / {len(test_loader)} = {random_noise_acc}")
print(f"Snapped Color Accuracy = {snap_color_correct} / {len(test_loader)} = {snap_color_acc}")
print(f"1 Bit Accuracy = {one_bit_correct} / {len(test_loader)} = {one_bit_acc}")
print(f"Plurality Vote Accuracy = {plurality_correct} / {len(test_loader)} = {plurality_acc}")
return unfiltered_acc, kuwahara_acc, bilateral_acc, gaussian_blur_acc, random_noise_acc, snap_color_acc, one_bit_acc, plurality_acc
return unfiltered_acc, filtered_accuracies
unfiltered_accuracies = [] unfiltered_accuracies = []
kuwahara_accuracies = [] filtered_accuracies = []
bilateral_accuracies = []
gaussian_blur_accuracies = []
random_noise_accuracies = []
snap_color_accuracies = []
one_bit_accuracies = []
pluality_vote_accuracies = []
print(f"Model: {pretrained_model}")
for eps in epsilons: for eps in epsilons:
unfiltered_acc, kuwahara_acc, bilateral_acc, gaussian_blur_acc, random_noise_acc, snap_color_acc, one_bit_acc, plurality_acc = test(model, device, test_loader, eps) unfiltered_accuracy, filtered_accuracy = test(model, device, test_loader, eps)
unfiltered_accuracies.append(unfiltered_acc) unfiltered_accuracies.append(unfiltered_accuracy)
kuwahara_accuracies.append(kuwahara_acc) filtered_accuracies.append(filtered_accuracy)
bilateral_accuracies.append(bilateral_acc)
gaussian_blur_accuracies.append(gaussian_blur_acc)
random_noise_accuracies.append(random_noise_acc)
snap_color_accuracies.append(snap_color_acc)
one_bit_accuracies.append(one_bit_acc)
pluality_vote_accuracies.append(plurality_acc)
# Plot the results # Plot the results
plt.plot(epsilons, unfiltered_accuracies, label="Attacked Accuracy") plt.plot(epsilons, unfiltered_accuracies, label="Attacked Accuracy")
plt.plot(epsilons, kuwahara_accuracies, label="Kuwahara Accuracy") for i in range(TESTED_STRENGTH_COUNT):
plt.plot(epsilons, bilateral_accuracies, label="Bilateral Accuracy") plt.plot(epsilons, filtered_accuracies[i], label=f"Gaussian Kuwahara (strength = {2*i + 1})")
plt.plot(epsilons, gaussian_blur_accuracies, label="Gaussian Blur Accuracy")
plt.plot(epsilons, random_noise_accuracies, label="Random Noise Accuracy")
plt.plot(epsilons, snap_color_accuracies, label="Snapped Color Accuracy")
plt.plot(epsilons, one_bit_accuracies, label="1-Bit Accuracy")
plt.plot(epsilons, pluality_vote_accuracies, label="Plurality Vote Accuracy")
plt.legend() plt.legend()
plt.show() plt.show()