Debugged code to run in background

This commit is contained in:
Adog64 2024-04-23 17:12:30 -04:00
parent 541a3b0850
commit 5c911219b9
3 changed files with 10 additions and 5 deletions

View File

@ -128,7 +128,7 @@ def filtered(data, batch_size=64, strength=0, filter="gaussian_blur"):
bits = 2**strength bits = 2**strength
return bit_depth(data, batch_size, bits) return bit_depth(data, batch_size, bits)
elif filter == "random_noise": elif filter == "random_noise":
intensity == 0.0005*(2*strength + 1) intensity = 0.0005*(2*strength + 1)
return random_noise(data, batch_size, intensity) return random_noise(data, batch_size, intensity)
else: else:
strength = (2*strength + 1) strength = (2*strength + 1)

View File

@ -85,7 +85,7 @@ def test(model, device, test_loader, epsilon, filter):
test_step = 0 test_step = 0
for data, target in test_loader: for data, target in test_loader:
print("[" + "="*int(20*test_step/len(test_loader)) + " "*(20 - int(20*test_step/len(test_loader))) + "]", f"{100*test_step/len(test_loader)}%", end='\r') print(filter, f"Epsilon: {epsilon}", "[" + "="*int(1 + 20*test_step/len(test_loader)) + " "*(20 - int(20*test_step/len(test_loader))) + "]", f"{100*test_step/len(test_loader)}%", end='\r')
test_step += 1 test_step += 1
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
@ -163,18 +163,23 @@ def test(model, device, test_loader, epsilon, filter):
return unfiltered_acc, filtered_accuracies return unfiltered_acc, filtered_accuracies
accurracies = {} accuracies = {}
filters = ("gaussian_blur", "gaussian_kuwahara", "mean_kuwahara", "random_noise", "bilateral_filter", "bit_depth", "threshold_filter") filters = ("gaussian_blur", "gaussian_kuwahara", "mean_kuwahara", "random_noise", "bilateral_filter", "bit_depth", "threshold_filter")
for filter in filters: for filter in filters:
for eps in epsilons: for eps in epsilons:
unfiltered_accuracy, filtered_accuracy = test(model, device, test_loader, eps, filter) unfiltered_accuracy, filtered_accuracy = test(model, device, test_loader, eps, filter)
if len(accuracies["unfiltered"] < len(epsilons): if list(accuracies.keys()).count("unfiltered") == 0:
accuracies["unfiltered"] = []
if len(accuracies["unfiltered"]) < len(epsilons):
accuracies["unfiltered"].append(unfiltered_accuracy) accuracies["unfiltered"].append(unfiltered_accuracy)
if list(accuracies.keys()).count(filter) == 0:
accuracies[filter] = []
accuracies[filter].append(filtered_accuracy) accuracies[filter].append(filtered_accuracy)
json.dump(accuracies, "fgsm_accuracies.json") json.dump(accuracies, "fgsm_mnist_accuracies.json")
# Plot the results # Plot the results
#plt.figure(figsize=(16,9)) #plt.figure(figsize=(16,9))