Fixed error with counting correct predictions when using filters

This commit is contained in:
Aidan Sharpe 2024-04-11 15:11:32 -04:00
parent 7430ddac94
commit 90915efb7e

View File

@ -129,12 +129,10 @@ def test(model, device, test_loader, epsilon):
# Count up correct classifications
if filtered_pred.item() == target.item():
if len(filtered_correct_counts) == i:
filtered_correct_counts.append(1)
else:
while i >= len(filtered_correct_counts):
filtered_correct_counts.append(0)
filtered_correct_counts[i] += 1
# Get the predicted classification
unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]