Adversarial-Machine-Learnin.../Filter_Analysis/cifar10.py

111 lines
3.6 KiB
Python

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import dla
def train(model, trainloader, optimizer, epoch, device=None, log_interval=2000):
    """Run one training epoch of ``model`` over ``trainloader``.

    Each mini-batch is moved to the model's device, passed through the model,
    scored with cross-entropy loss, and used for one optimizer step. Every
    ``log_interval`` mini-batches the mean loss over that window is printed.

    Args:
        model: network to train; it is switched to training mode.
        trainloader: iterable yielding ``(data, target)`` mini-batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index (printed one-based in the log line).
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.
        log_interval: number of mini-batches between loss reports.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    # Enable training-time behaviour (e.g. dropout, batch-norm statistic updates);
    # the model may have been left in eval mode by a previous evaluation.
    model.train()
    # The loss module is stateless, so build it once instead of once per batch.
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    for i, (data, target) in enumerate(trainloader):
        # The model may live on a GPU; inputs and labels must be on the same device.
        data, target = data.to(device), target.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % log_interval == log_interval - 1:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_interval:.3f}')
            running_loss = 0.0
def test(model, testloader, classes):
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
for data, target in testloader:
# calculate outputs by running images through the network
outputs = model(data)
# the class with the highest energy is what we choose as prediction
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}
# again no gradients needed
with torch.no_grad():
for data, target in testloader:
outputs = model(data)
_, predictions = torch.max(outputs, 1)
# collect the correct predictions for each class
for label, prediction in zip(target, predictions):
if label == prediction:
correct_pred[classes[label]] += 1
total_pred[classes[label]] += 1
# print accuracy for each class
for classname, correct_count in correct_pred.items():
accuracy = 100 * float(correct_count) / total_pred[classname]
print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
def main(epochs=14, batch_size=4, lr=0.001, momentum=0.9,
         data_root='./data', save_path='./cifar_net.pth'):
    """Train a ``dla.DLA`` model on CIFAR-10, evaluate after every epoch, then save it.

    All keyword arguments default to the values this script originally
    hard-coded, so ``main()`` behaves as before.

    Args:
        epochs: number of passes over the training set.
        batch_size: mini-batch size for both the train and test loaders.
        lr: SGD learning rate.
        momentum: SGD momentum.
        data_root: directory CIFAR-10 is downloaded to / loaded from.
        save_path: file the trained ``state_dict`` is written to.
    """
    # ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 maps them to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Page-locked host memory speeds up host-to-GPU copies; it is useless on CPU.
    pin_memory = device.type == 'cuda'
    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2,
                                              pin_memory=pin_memory)
    testset = torchvision.datasets.CIFAR10(root=data_root, train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2,
                                             pin_memory=pin_memory)
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    model = dla.DLA().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    for epoch in range(epochs):
        train(model, trainloader, optimizer, epoch)
        test(model, testloader, classes)
    torch.save(model.state_dict(), save_path)
# Launch training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()