diff --git a/Filter_Analysis/fgsm.py b/Filter_Analysis/fgsm.py index cbae825..4e7254a 100644 --- a/Filter_Analysis/fgsm.py +++ b/Filter_Analysis/fgsm.py @@ -129,11 +129,9 @@ def test(model, device, test_loader, epsilon): # Count up correct classifications if filtered_pred.item() == target.item(): - if len(filtered_correct_counts) == i: - filtered_correct_counts.append(1) - else: - filtered_correct_counts[i] += 1 - + while i >= len(filtered_correct_counts): + filtered_correct_counts.append(0) + filtered_correct_counts[i] += 1 # Get the predicted classification unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]