# gpgpu-sem-2/lab1-pytorch-cifar10/test.py
#
# Source listing metadata: 135 lines, 3.4 KiB, Python.

import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from alexnet import AlexNet, CIFAR10_NUM_CLASSES
# This script evaluates on the GPU only; refuse to start without CUDA.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")

# Location of the trained AlexNet weights that main() loads.
NET_SAVE_PATH = "./cifar10_alexnet.pth"

# Declared here, assigned in main().
device: torch.device

# Tensor conversion followed by per-channel normalization with mean 0.5 and
# std 0.5 (imshow() undoes it with ``img / 2 + 0.5``).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

batch_size = 4

# Datasets and loaders are declared here and bound by load_data().
trainset: torchvision.datasets.CIFAR10
testset: torchvision.datasets.CIFAR10
trainloader: torch.utils.data.DataLoader
testloader: torch.utils.data.DataLoader

# Human-readable CIFAR-10 label names, indexed by integer class id.
classes = (
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)
def load_data():
    """Create the CIFAR-10 train/test datasets and their DataLoaders.

    Downloads the data into ``./data`` if needed and binds the module-level
    globals ``trainset``, ``trainloader``, ``testset`` and ``testloader``.
    The training loader shuffles; the test loader keeps a fixed order.
    """
    global trainset, trainloader, testset, testloader

    def _dataset_and_loader(is_train, shuffle):
        # Same root/transform/batch size/worker count for both splits.
        dataset = torchvision.datasets.CIFAR10(
            root="./data", train=is_train, download=True, transform=transform
        )
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2
        )
        return dataset, loader

    trainset, trainloader = _dataset_and_loader(is_train=True, shuffle=True)
    testset, testloader = _dataset_and_loader(is_train=False, shuffle=False)
def _to_hwc_image(img):
    """Return *img* un-normalized to [0, 1] as an HWC NumPy array on the CPU.

    *img* is expected to be a CHW float tensor normalized with mean 0.5 and
    std 0.5 per channel (the inverse is ``img / 2 + 0.5``).
    """
    # detach() lets tensors that require grad be converted to NumPy;
    # cpu() is a no-op for tensors already on the CPU.
    img = img.detach().cpu()
    # matplotlib clips float RGB data to [0, 1] anyway (with a warning);
    # clamping here gives the same picture without the warning.
    img = (img / 2 + 0.5).clamp(0.0, 1.0)
    return np.transpose(img.numpy(), (1, 2, 0))


def imshow(img):
    """Display a normalized CHW image tensor (CPU or GPU) with matplotlib.

    Blocks until the figure window is closed (``plt.show()``).
    """
    plt.imshow(_to_hwc_image(img))
    plt.show()
def _load_model(device):
    """Build AlexNet for CIFAR-10, load weights from NET_SAVE_PATH, set eval mode."""
    net = AlexNet(CIFAR10_NUM_CLASSES).to(device)
    # map_location keeps loading working if the checkpoint was saved on a
    # different device (another GPU index or the CPU).
    net.load_state_dict(torch.load(NET_SAVE_PATH, map_location=device))
    # eval() switches training-only layers (e.g. dropout, batch-norm, if the
    # model has any) to inference behaviour; without it the measured accuracy
    # is computed with training-mode layers.
    net.eval()
    return net


def _show_sample_predictions(net, loader, device):
    """Show the first batch of *loader* and print its true and predicted labels."""
    images, labels = next(iter(loader))
    images, labels = images.to(device), labels.to(device)
    imshow(torchvision.utils.make_grid(images))
    # Iterate over the actual batch length instead of assuming 4 images.
    print("GroundTruth: ", " ".join(f"{classes[j]:5s}" for j in labels.tolist()))
    with torch.no_grad():  # inference only: no autograd graph needed
        predicted = net(images).argmax(dim=1)
    print("Predicted: ", " ".join(f"{classes[j]:5s}" for j in predicted.tolist()))


def _evaluate(net, loader, device, class_names):
    """Count correct predictions of *net* over every batch of *loader*.

    Args:
        net: callable mapping an image batch to per-class scores (batch, classes).
        loader: iterable of ``(images, labels)`` tensor pairs.
        device: device the batches are moved to before the forward pass.
        class_names: sequence of class names indexed by integer label.

    Returns:
        ``(correct, total, correct_pred, total_pred)`` where the last two are
        dicts mapping each class name to its correct / total sample counts.
    """
    correct_pred = {name: 0 for name in class_names}
    total_pred = {name: 0 for name in class_names}
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = net(images).argmax(dim=1)
            # One device->host copy per batch instead of a device sync for
            # every element compared in the Python loop.
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                name = class_names[label]
                total_pred[name] += 1
                if label == prediction:
                    correct_pred[name] += 1
    # Overall counts are the per-class counts summed, so one pass suffices.
    correct = sum(correct_pred.values())
    total = sum(total_pred.values())
    return correct, total, correct_pred, total_pred


def main():
    """Evaluate the saved AlexNet on the CIFAR-10 test set.

    Sets the module-level ``device`` to ``cuda:0``, loads the data and the
    trained weights, shows one sample batch with predictions, then prints
    the overall and per-class test accuracy.
    """
    global device
    device = torch.device("cuda:0")
    print("Available device:", device)

    load_data()
    net = _load_model(device)

    _show_sample_predictions(net, testloader, device)

    correct, total, correct_pred, total_pred = _evaluate(
        net, testloader, device, classes
    )
    print(
        f"Accuracy of the network on the {total} test images: {100 * correct // total} %"
    )
    for classname, correct_count in correct_pred.items():
        class_total = total_pred[classname]
        if class_total == 0:
            # Avoid ZeroDivisionError for a class absent from the test set.
            print(f"Accuracy for class: {classname:5s} is n/a (no samples)")
            continue
        accuracy = 100 * float(correct_count) / class_total
        print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")
# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()