Fixed error with counting correct predictions when using filters
This commit is contained in:
parent
7430ddac94
commit
90915efb7e
@ -129,12 +129,10 @@ def test(model, device, test_loader, epsilon):
|
|||||||
|
|
||||||
# Count up correct classifications
|
# Count up correct classifications
|
||||||
if filtered_pred.item() == target.item():
|
if filtered_pred.item() == target.item():
|
||||||
if len(filtered_correct_counts) == i:
|
while i >= len(filtered_correct_counts):
|
||||||
filtered_correct_counts.append(1)
|
filtered_correct_counts.append(0)
|
||||||
else:
|
|
||||||
filtered_correct_counts[i] += 1
|
filtered_correct_counts[i] += 1
|
||||||
|
|
||||||
|
|
||||||
# Get the predicted classification
|
# Get the predicted classification
|
||||||
unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]
|
unfiltered_pred = output_unfiltered.max(1, keepdim=True)[1]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user