Fixed error with counting correct predictions when using filters
This commit is contained in:
parent
7430ddac94
commit
90915efb7e
@ -129,11 +129,9 @@ def test(model, device, test_loader, epsilon):
|
||||
|
||||
# Count up correct classifications
|
||||
if filtered_pred.item() == target.item():
|
||||
if len(filtered_correct_counts) == i:
|
||||
filtered_correct_counts.append(1)
|
||||
else:
|
||||
filtered_correct_counts[i] += 1
|
||||
|
||||
while i >= len(filtered_correct_counts):
|
||||
filtered_correct_counts.append(0)
|
||||
filtered_correct_counts[i] += 1
|
||||
|
||||
# Get the predicted classification
|
||||
unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]
|
||||
|
Loading…
Reference in New Issue
Block a user