Lab 1 CIFAR10 initial done

This commit is contained in:
Andrew 2023-03-01 16:06:46 +07:00
parent 671e8d40e6
commit c75e27a36e
6 changed files with 414 additions and 0 deletions

View file

@ -0,0 +1,135 @@
import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from alexnet import AlexNet, CIFAR10_NUM_CLASSES
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
NET_SAVE_PATH = "./cifar10_alexnet.pth"
device: torch.device
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
batch_size = 4
trainset: torchvision.datasets.CIFAR10
testset: torchvision.datasets.CIFAR10
trainloader: torch.utils.data.DataLoader
testloader: torch.utils.data.DataLoader
classes = (
"plane",
"car",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
)
def load_data():
global trainset, trainloader, testset, testloader
trainset = torchvision.datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=batch_size, shuffle=True, num_workers=2
)
testset = torchvision.datasets.CIFAR10(
root="./data", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=batch_size, shuffle=False, num_workers=2
)
def imshow(img):
if img.is_cuda:
img = img.cpu()
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
def main():
global device
device = torch.device("cuda:0")
print("Available device:", device)
load_data()
net = AlexNet(CIFAR10_NUM_CLASSES).to(device)
net.load_state_dict(torch.load(NET_SAVE_PATH))
dataiter = iter(testloader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)
imshow(torchvision.utils.make_grid(images))
print("GroundTruth: ", " ".join(f"{classes[labels[j]]:5s}" for j in range(4)))
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print("Predicted: ", " ".join(f"{classes[predicted[j]]:5s}" for j in range(4)))
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(
f"Accuracy of the network on the 10000 test images: {100 * correct // total} %"
)
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predictions = torch.max(outputs, 1)
for label, prediction in zip(labels, predictions):
if label == prediction:
correct_pred[classes[label]] += 1
total_pred[classes[label]] += 1
for classname, correct_count in correct_pred.items():
accuracy = 100 * float(correct_count) / total_pred[classname]
print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")
if __name__ == "__main__":
main()