diff --git a/Filter_Analysis/cifar10.py b/Filter_Analysis/cifar10.py index 7cc6c64..6c1dd29 100644 --- a/Filter_Analysis/cifar10.py +++ b/Filter_Analysis/cifar10.py @@ -10,9 +10,37 @@ import torch.optim as optim import torch.nn as nn import torch.nn.functional as F -import dla +#import dla + +EPOCHS = 200 + +class CifarCNN(nn.Module): + def __init__(self): + super(CifarCNN, self).__init__() + self.conv1 = nn.Conv2d(3, 96, 3, 1) + self.conv2 = nn.Conv2d(96, 192, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(37632, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x,2) + x = self.dropout1(x) + x = torch.flatten(x,1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + def train(model, trainloader, optimizer, epoch): running_loss = 0.0 for i, [data, target] in enumerate(trainloader, 0): @@ -21,9 +49,8 @@ def train(model, trainloader, optimizer, epoch): optimizer.zero_grad() # forward + backward + optimize - outputs = model(data) - criterion = nn.CrossEntropyLoss() - loss = criterion(outputs, target) + output = model(data) + loss = F.nll_loss(output, target) loss.backward() optimizer.step() @@ -42,9 +69,9 @@ def test(model, testloader, classes): with torch.no_grad(): for data, target in testloader: # calculate outputs by running images through the network - outputs = model(data) + output = model(data) # the class with the highest energy is what we choose as prediction - _, predicted = torch.max(outputs.data, 1) + _, predicted = torch.max(output.data, 1) total += target.size(0) correct += (predicted == target).sum().item() @@ -58,8 +85,8 @@ def test(model, testloader, classes): # again no gradients needed with torch.no_grad(): for data, target in testloader: - outputs = model(data) - _, predictions = torch.max(outputs, 1) + output = model(data) + _, predictions = torch.max(output, 1) # collect the correct predictions for each class for label, prediction in zip(target, predictions): if label == prediction: @@ -94,15 +121,15 @@ def main(): 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - model = dla.DLA().to(device) + model = CifarCNN().to(device) optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) - for epoch in range(14): + for epoch in range(EPOCHS): train(model, trainloader, optimizer, epoch) test(model, testloader, classes) - PATH = './cifar_net.pth' + PATH = './cifar_cnn.pth' torch.save(model.state_dict(), PATH)