import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from alexnet import AlexNet, CIFAR10_NUM_CLASSES

# This evaluation script only supports GPU execution: fail fast at import time.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")

# Path of the saved state_dict that main() loads into the network.
NET_SAVE_PATH = "./cifar10_alexnet.pth"

# Assigned in main(); declared at module level so the helpers can use it.
device: torch.device

# ToTensor + per-channel Normalize(mean=0.5, std=0.5): maps [0, 1] to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

batch_size = 4

# Populated by load_data().
trainset: torchvision.datasets.CIFAR10
testset: torchvision.datasets.CIFAR10
trainloader: torch.utils.data.DataLoader
testloader: torch.utils.data.DataLoader

# Display names, indexed by the integer label produced by the CIFAR10 dataset.
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)


def load_data():
    """Download (if needed) CIFAR-10 into ./data and build the DataLoaders.

    Populates the module-level ``trainset``, ``trainloader``, ``testset`` and
    ``testloader`` globals. The train loader shuffles; the test loader does
    not, so the first test batch is the same on every run.
    """
    global trainset, trainloader, testset, testloader
    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2
    )


def imshow(img):
    """Display a normalized, channels-first image tensor with matplotlib.

    Args:
        img: a 3-D image tensor laid out (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``. It is copied to the CPU if it
            lives on a CUDA device.
    """
    if img.is_cuda:
        img = img.cpu()
    img = img / 2 + 0.5  # invert Normalize(0.5, 0.5): [-1, 1] -> [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()


def _format_labels(indices):
    """Return the space-separated, 5-wide class names for a 1-D label tensor."""
    # .tolist() copies to the host once instead of syncing per element.
    return " ".join(f"{classes[i]:5s}" for i in indices.tolist())


def _show_sample_predictions(net):
    """Show the first test batch and print its ground-truth and predicted labels."""
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)

    imshow(torchvision.utils.make_grid(images))
    print("GroundTruth: ", _format_labels(labels))

    # No autograd graph is needed for inference.
    with torch.no_grad():
        outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    print("Predicted: ", _format_labels(predicted))


def _evaluate(net):
    """Run ``net`` once over the whole test set and count hits per class.

    Returns:
        A ``(correct_pred, total_pred)`` pair of dicts keyed by class name:
        the number of correctly classified test images of each class, and
        the total number of test images of each class.
    """
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predictions = torch.max(outputs, 1)
            # One host transfer per batch rather than a GPU sync per sample.
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                classname = classes[label]
                total_pred[classname] += 1
                if label == prediction:
                    correct_pred[classname] += 1
    return correct_pred, total_pred


def main():
    """Load the saved AlexNet weights and report CIFAR-10 test-set accuracy.

    Shows one sample test batch with its predictions, then prints the overall
    accuracy and the accuracy for each class.
    """
    global device
    device = torch.device("cuda:0")
    print("Available device:", device)

    load_data()

    net = AlexNet(CIFAR10_NUM_CLASSES).to(device)
    # map_location places the loaded tensors on the evaluation device no
    # matter which device the checkpoint was saved from.
    net.load_state_dict(torch.load(NET_SAVE_PATH, map_location=device))
    # Evaluation mode: disables dropout and makes batch-norm (if any) use its
    # running statistics, so the measured accuracy is deterministic.
    net.eval()

    _show_sample_predictions(net)

    # Single pass over the test set; the overall accuracy is derived from the
    # per-class counts instead of iterating the loader a second time.
    correct_pred, total_pred = _evaluate(net)
    correct = sum(correct_pred.values())
    total = sum(total_pred.values())
    print(
        f"Accuracy of the network on the {total} test images: {100 * correct // total} %"
    )

    for classname, correct_count in correct_pred.items():
        class_total = total_pred[classname]
        if class_total == 0:
            print(f"Accuracy for class: {classname:5s} is n/a (no samples)")
            continue
        accuracy = 100 * float(correct_count) / class_total
        print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")


if __name__ == "__main__":
    main()