"""Train AlexNet on CIFAR-10 and report overall and per-class test accuracy.

Run as a script. Pass ``--half`` to cast the model and its inputs to FP16;
in that case the weights are saved to a separate checkpoint file.
"""

import argparse

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.cuda
import torch.cuda.amp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

from alexnet import AlexNet, CIFAR10_NUM_CLASSES

# The script always trains on "cuda:0", so fail fast when no GPU is present.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")

# Where the trained weights are written; overridden below when --half is given.
NET_SAVE_PATH = "./cifar10_alexnet.pth"

# Assigned in main().
device: torch.device

# Convert images to tensors and normalise each channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

batch_size = 4

# Populated by load_data().
trainset: torchvision.datasets.CIFAR10
testset: torchvision.datasets.CIFAR10
trainloader: torch.utils.data.DataLoader
testloader: torch.utils.data.DataLoader

# Human-readable names, indexed by the integer class label.
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)


def load_data():
    """Download CIFAR-10 if needed and build the module-level datasets/loaders.

    Sets the globals ``trainset``, ``trainloader``, ``testset`` and
    ``testloader``. Only the training loader shuffles its samples.
    """
    global trainset, trainloader, testset, testloader

    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2
    )

    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2
    )


def imshow(img):
    """Display an image tensor with matplotlib.

    Args:
        img: Image tensor in channel-first layout, normalised the way
            ``transform`` does it. May live on the GPU.
    """
    if img.is_cuda:
        img = img.cpu()
    img = img / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    npimg = img.numpy()
    # matplotlib wants channel-last, so move axis 0 to the end.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def _to_device(images, labels, half_precision):
    """Move one batch to ``device``, casting the images to FP16 if requested."""
    if half_precision:
        images = images.half()
    return images.to(device), labels.to(device)


def _train(net, half_precision, epochs=2, log_every=2000):
    """Train ``net`` on ``trainloader`` with SGD, logging the running loss."""
    scaler = torch.cuda.amp.grad_scaler.GradScaler()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    net.train()
    for epoch in range(epochs):
        running_loss = 0.0
        progress = tqdm(
            trainloader,
            desc=f"Epoch {epoch+1}",
            total=len(trainloader),
            unit="batch",
        )
        for i, (inputs, labels) in enumerate(progress):
            inputs, labels = _to_device(inputs, labels, half_precision)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            if half_precision:
                # NOTE(review): the model itself is FP16 here, and GradScaler
                # is documented to reject FP16 gradients when unscaling, which
                # is presumably why this path bypasses it - confirm.
                loss.backward()
                optimizer.step()
            else:
                scaler.scale(loss).backward()  # type: ignore
                scaler.step(optimizer)
                scaler.update()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                # tqdm.write keeps the message from corrupting the progress bar.
                tqdm.write(
                    f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}"
                )
                running_loss = 0.0


def _show_sample_predictions(net, half_precision):
    """Print ground-truth and predicted class names for the first test batch."""
    images, labels = next(iter(testloader))
    images, labels = _to_device(images, labels, half_precision)

    with torch.no_grad():  # inference only, no autograd graph needed
        outputs = net(images)
    _, predicted = torch.max(outputs, 1)

    # Use the real batch length instead of assuming batch_size == 4.
    shown = range(len(labels))
    print("GroundTruth: ", " ".join(f"{classes[labels[j]]:5s}" for j in shown))
    print("Predicted: ", " ".join(f"{classes[predicted[j]]:5s}" for j in shown))


def _report_overall_accuracy(net, half_precision):
    """Print the network's accuracy over the whole test set."""
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(
            testloader,
            desc="Measuring overall accuracy",
            unit="batch",
            total=len(testloader),
        ):
            images, labels = _to_device(images, labels, half_precision)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(
        f"Accuracy of the network on the 10000 test images: {100 * correct // total} %"
    )


def _report_class_accuracy(net, half_precision):
    """Print the network's test accuracy separately for every class."""
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    with torch.no_grad():
        for images, labels in tqdm(
            testloader,
            desc="Measuring class accuracy",
            unit="batch",
            total=len(testloader),
        ):
            images, labels = _to_device(images, labels, half_precision)
            outputs = net(images)
            _, predictions = torch.max(outputs, 1)
            # Convert once per batch rather than comparing and indexing with
            # one GPU scalar tensor per sample.
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                classname = classes[label]
                if label == prediction:
                    correct_pred[classname] += 1
                total_pred[classname] += 1

    for classname, correct_count in correct_pred.items():
        seen = total_pred[classname]
        # Guard against a class that never occurred in the test data.
        accuracy = 100 * float(correct_count) / seen if seen else 0.0
        print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")


def main(half_precision: bool = False):
    """Train AlexNet on CIFAR-10, save its weights and print test accuracy.

    Args:
        half_precision: When true, the model and every input batch are cast
            to FP16 and the optimizer is stepped without the gradient scaler.

    Side effects:
        Sets the global ``device``, populates the data globals through
        ``load_data()``, and writes the state dict to ``NET_SAVE_PATH``.
    """
    global device
    device = torch.device("cuda:0")

    load_data()

    net = AlexNet(num_classes=CIFAR10_NUM_CLASSES)
    if half_precision:
        net = net.half()
    net = net.to(device)

    print("Training")
    _train(net, half_precision)
    print("Finished Training")

    torch.save(net.state_dict(), NET_SAVE_PATH)

    # Switch to inference mode so that any Dropout/BatchNorm layers the model
    # may contain behave deterministically during evaluation.
    net.eval()

    _show_sample_predictions(net, half_precision)
    _report_overall_accuracy(net, half_precision)
    _report_class_accuracy(net, half_precision)


if __name__ == "__main__":
    # --half trains and evaluates the network in FP16.
    parser = argparse.ArgumentParser()
    parser.add_argument("--half", action="store_true", help="use half precision")
    args = parser.parse_args()

    if args.half:
        print("Using half precision")
        # Keep the FP16 checkpoint separate from the full-precision one.
        NET_SAVE_PATH = "./cifar10_alexnet_half.pth"

    main(half_precision=args.half)