# File-viewer metadata (not part of the program): Python, 135 lines, 3.4 KiB
import torch
|
|
import torch.utils.data
|
|
import torch.nn as nn
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
import torch.optim as optim
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from alexnet import AlexNet, CIFAR10_NUM_CLASSES
|
|
|
|
# The evaluation in main() targets "cuda:0", so fail fast at import time on
# machines without a usable CUDA device.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")
|
|
|
|
|
|
# Checkpoint file the model weights are read from.
NET_SAVE_PATH = "./cifar10_alexnet.pth"

# Assigned in main().
device: torch.device

# Convert images to tensors, then normalize each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Number of samples per DataLoader batch.
batch_size = 4

# Datasets and loaders are assigned by load_data().
trainset: torchvision.datasets.CIFAR10
testset: torchvision.datasets.CIFAR10

trainloader: torch.utils.data.DataLoader
testloader: torch.utils.data.DataLoader

# Human-readable class names, indexed by the integer label.
classes = (
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)
|
|
|
|
|
|
def load_data():
    """Build the CIFAR-10 train/test datasets and their data loaders.

    Assigns the module-level ``trainset``, ``trainloader``, ``testset`` and
    ``testloader``. Both datasets use ``./data`` as their root, are created
    with ``download=True`` and apply the module-level ``transform``. Only the
    training loader shuffles its samples.
    """
    global trainset, trainloader, testset, testloader

    # Arguments shared by both splits / both loaders.
    dataset_args = {"root": "./data", "download": True, "transform": transform}
    loader_args = {"batch_size": batch_size, "num_workers": 2}

    trainset = torchvision.datasets.CIFAR10(train=True, **dataset_args)
    trainloader = torch.utils.data.DataLoader(
        trainset, shuffle=True, **loader_args
    )

    testset = torchvision.datasets.CIFAR10(train=False, **dataset_args)
    testloader = torch.utils.data.DataLoader(
        testset, shuffle=False, **loader_args
    )
|
|
|
|
|
|
def imshow(img):
    """Display an image tensor in a matplotlib window.

    The tensor is copied to the CPU when it lives on a CUDA device, rescaled
    with ``img / 2 + 0.5`` (the inverse of a mean-0.5 / std-0.5
    normalization), and its first axis is moved last before plotting.
    """
    host_img = img.cpu() if img.is_cuda else img
    rescaled = host_img / 2 + 0.5
    # Reorder axes from (0, 1, 2) to (1, 2, 0) for plt.imshow.
    pixels = np.transpose(rescaled.numpy(), (1, 2, 0))
    plt.imshow(pixels)
    plt.show()
|
|
|
|
|
|
def main():
    """Evaluate the saved AlexNet checkpoint on the CIFAR-10 test set.

    Loads the weights stored at ``NET_SAVE_PATH``, shows the first test batch
    together with its ground-truth and predicted labels, then prints the
    overall accuracy followed by the accuracy of every class.

    Sets the module-level ``device`` and, through ``load_data()``, the
    module-level datasets and loaders.
    """
    global device
    device = torch.device("cuda:0")
    print("Available device:", device)

    load_data()

    net = AlexNet(CIFAR10_NUM_CLASSES).to(device)
    # map_location keeps loading independent of the device the checkpoint
    # was saved from.
    net.load_state_dict(torch.load(NET_SAVE_PATH, map_location=device))
    # Switch to inference mode: if the model contains dropout or batch-norm
    # layers, leaving it in training mode makes the results vary per run.
    net.eval()

    # --- Preview of the first test batch ---------------------------------
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)

    imshow(torchvision.utils.make_grid(images))
    # Iterate over the actual batch contents rather than a hard-coded count,
    # so the preview stays correct if batch_size changes.
    print(
        "GroundTruth: ",
        " ".join(f"{classes[idx]:5s}" for idx in labels.tolist()),
    )

    with torch.no_grad():  # no gradients are needed for evaluation
        outputs = net(images)
    _, predicted = torch.max(outputs, 1)

    print(
        "Predicted: ",
        " ".join(f"{classes[idx]:5s}" for idx in predicted.tolist()),
    )

    # --- Accuracy over the whole test set --------------------------------
    # A single pass collects per-class counts; the overall figures are their
    # sums, so the test set is only run through the network once.
    correct_pred = dict.fromkeys(classes, 0)
    total_pred = dict.fromkeys(classes, 0)

    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images.to(device))
            _, predictions = torch.max(outputs, 1)

            # One transfer per batch; comparing plain ints avoids a GPU
            # synchronisation for every sample.
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                classname = classes[label]
                total_pred[classname] += 1
                if label == prediction:
                    correct_pred[classname] += 1

    correct = sum(correct_pred.values())
    total = sum(total_pred.values())

    print(
        f"Accuracy of the network on the 10000 test images: {100 * correct // total} %"
    )

    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|