gpgpu-sem-2/lab1-pytorch-cifar10/train.py

197 lines
5.5 KiB
Python

import torch
import torch.cuda
import torch.cuda.amp
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from alexnet import AlexNet, CIFAR10_NUM_CLASSES
import argparse
# This script only targets the GPU, so abort at import time when CUDA is absent.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")

# File the trained weights are written to by main(); the command-line entry
# point rebinds this when --half is given.
NET_SAVE_PATH = "./cifar10_alexnet.pth"

# Declared here, assigned inside main().
device: torch.device

# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Mini-batch size shared by the train and test loaders.
batch_size = 4

# Declared here, assigned inside load_data().
trainset: torchvision.datasets.CIFAR10
testset: torchvision.datasets.CIFAR10
trainloader: torch.utils.data.DataLoader
testloader: torch.utils.data.DataLoader

# Human-readable class names, indexed by integer label.
# NOTE(review): order presumably matches the CIFAR-10 label ids -- confirm.
classes = (
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)
def load_data():
    """Create the CIFAR-10 datasets and their loaders as module globals.

    Binds ``trainset``/``trainloader`` and ``testset``/``testloader``. Both
    datasets live under ``./data`` (downloaded on demand) and use the
    module-level ``transform`` and ``batch_size``. Only the training loader
    shuffles.
    """
    global trainset, trainloader, testset, testloader

    def build_split(is_train):
        # The training split is the only one that is shuffled, so a single
        # flag drives both the dataset's ``train`` and the loader's ``shuffle``.
        dataset = torchvision.datasets.CIFAR10(
            root="./data", train=is_train, download=True, transform=transform
        )
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train, num_workers=2
        )
        return dataset, loader

    trainset, trainloader = build_split(True)
    testset, testloader = build_split(False)
def imshow(img):
    """Display an image tensor in a matplotlib window.

    The tensor is copied to the host if it lives on a CUDA device, rescaled
    with ``x / 2 + 0.5`` (the inverse of a mean-0.5 / std-0.5 normalization),
    and shown with axis 0 moved to the last position.
    """
    host_img = img.cpu() if img.is_cuda else img
    pixels = (host_img / 2 + 0.5).numpy()
    # matplotlib wants the first tensor axis (presumably channels) last.
    plt.imshow(pixels.transpose((1, 2, 0)))
    plt.show()
def _prepare_batch(batch, target_device, half_precision):
    """Unpack ``(inputs, labels)``, optionally cast inputs to fp16, move both to the device."""
    inputs, labels = batch
    if half_precision:
        inputs = inputs.half()
    return inputs.to(target_device), labels.to(target_device)


def _train(net, loader, target_device, half_precision, epochs=2, log_every=2000):
    """Train *net* with SGD + cross-entropy, logging the mean loss every *log_every* batches."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    # NOTE(review): the original branching is preserved -- the pure-fp16 model
    # steps the optimizer directly, while the fp32 model goes through a
    # GradScaler (without autocast). Confirm whether autocast-based mixed
    # precision was the actual intent.
    scaler = None if half_precision else torch.cuda.amp.grad_scaler.GradScaler()

    net.train()
    for epoch in range(epochs):
        running_loss = 0.0
        progress = tqdm(
            enumerate(loader, 0),
            desc=f"Epoch {epoch+1}",
            total=len(loader),
            unit="batch",
        )
        for i, batch in progress:
            inputs, labels = _prepare_batch(batch, target_device, half_precision)
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            if scaler is None:
                loss.backward()
                optimizer.step()
            else:
                scaler.scale(loss).backward()  # type: ignore
                scaler.step(optimizer)
                scaler.update()
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                tqdm.write(
                    f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}"
                )
                running_loss = 0.0


def _show_sample_predictions(net, loader, target_device, half_precision):
    """Print ground-truth and predicted class names for the first batch of *loader*."""
    images, labels = _prepare_batch(next(iter(loader)), target_device, half_precision)
    # Iterate over the actual batch instead of assuming it holds four samples.
    print("GroundTruth: ", " ".join(f"{classes[j]:5s}" for j in labels.tolist()))
    with torch.no_grad():
        outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    print("Predicted: ", " ".join(f"{classes[j]:5s}" for j in predicted.tolist()))


def _report_overall_accuracy(net, loader, target_device, half_precision):
    """Print the integer percentage of correctly classified samples in *loader*."""
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(
            loader,
            desc="Measuring overall accuracy",
            unit="batch",
            total=len(loader),
        ):
            images, labels = _prepare_batch(batch, target_device, half_precision)
            _, predicted = torch.max(net(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(
        f"Accuracy of the network on the {total} test images: {100 * correct // total} %"
    )


def _report_class_accuracy(net, loader, target_device, half_precision):
    """Print the accuracy of *net* separately for every entry of ``classes``."""
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}
    with torch.no_grad():
        for batch in tqdm(
            loader,
            desc="Measuring class accuracy",
            unit="batch",
            total=len(loader),
        ):
            images, labels = _prepare_batch(batch, target_device, half_precision)
            _, predictions = torch.max(net(images), 1)
            # Copy each batch to the host once rather than reading device
            # tensors element by element.
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                classname = classes[label]
                if label == prediction:
                    correct_pred[classname] += 1
                total_pred[classname] += 1
    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")


def main(half_precision: bool = False):
    """Train AlexNet on CIFAR-10, save its weights, and report test accuracy.

    Args:
        half_precision: When true, the model and its inputs are cast to fp16
            before being moved to the GPU.

    Side effects:
        Binds the module-level ``device``, calls ``load_data()``, writes the
        model ``state_dict`` to ``NET_SAVE_PATH`` and prints progress and
        accuracy figures to stdout.
    """
    global device
    device = torch.device("cuda:0")
    load_data()

    net = AlexNet(num_classes=CIFAR10_NUM_CLASSES)
    if half_precision:
        net = net.half()
    net = net.to(device)

    print("Training")
    _train(net, trainloader, device, half_precision)
    print("Finished Training")
    torch.save(net.state_dict(), NET_SAVE_PATH)

    # Switch to inference mode so layers that behave differently while
    # training (e.g. dropout or batch norm, if AlexNet contains any) do not
    # distort the measurements below.
    net.eval()
    _show_sample_predictions(net, testloader, device, half_precision)
    _report_overall_accuracy(net, testloader, device, half_precision)
    _report_class_accuracy(net, testloader, device, half_precision)
if __name__ == "__main__":
    # Command-line entry point: the optional --half flag selects
    # half-precision training.
    cli = argparse.ArgumentParser()
    cli.add_argument("--half", action="store_true", help="use half precision")
    use_half = cli.parse_args().half

    if use_half:
        print("Using half precision")
        # Keep the half-precision weights in a separate file.
        NET_SAVE_PATH = "./cifar10_alexnet_half.pth"

    main(half_precision=use_half)