# Fine-tune a new 10-class linear head on a frozen, ImageNet-pretrained ResNet-18 using CIFAR-10.
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision
from icecream import ic
from torch import nn
from torchvision.transforms import transforms
# Channel statistics the torchvision ImageNet-pretrained weights were trained
# with; inputs to the pretrained backbone must be normalised the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Image pipeline: resize the shorter side to 224, centre-crop to 224x224,
# convert the PIL image to a float tensor in [0, 1] with shape (C, H, W),
# then standardise each channel with the ImageNet statistics above.
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# CIFAR-10 training split. download=False, so the dataset must already be
# present under this root directory.
cifar_10 = torchvision.datasets.CIFAR10(
    '~/data/course_data/', train=True, download=False, transform=preprocess
)

# Shuffled mini-batches of 128 (image, label) pairs for training.
train_loader = torch.utils.data.DataLoader(cifar_10, batch_size=128, shuffle=True)
# ImageNet-pretrained ResNet-18 backbone.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of `weights=...`; left as-is for compatibility with the installed version.
resnet = torchvision.models.resnet18(pretrained=True)

# Freeze every pretrained parameter so the backbone acts as a fixed feature extractor.
for param in resnet.parameters():
    param.requires_grad = False

# Replace the ImageNet classifier with a fresh 10-way linear head. It is created
# after the freeze loop, so its parameters keep the default requires_grad=True
# and are the only ones that will be trained.
feature_num = resnet.fc.in_features
resnet.fc = nn.Linear(feature_num, 10)

# Sanity check: forward the first training image as a batch of one.
# (The expression text is printed by `ic`, so keep it unchanged.)
ic(resnet(cifar_10[0][0].unsqueeze(0)))
# Console output captured from an earlier run (untrained head, 10 logits):
#   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
#   ic| resnet(cifar_10[0][0].unsqueeze(0)): tensor([[-0.0763, -0.4537, 0.8168, 0.2136, -0.0465,
#       0.4844, -0.4026, 0.8763, -0.7048, -0.7375]], grad_fn=<AddmmBackward>)

criterion = nn.CrossEntropyLoss()

# Give the optimizer only the trainable (head) parameters. Frozen parameters
# never receive gradients, so SGD would skip them anyway; passing only the
# trainable ones makes the fine-tuning intent explicit.
trainable_params = [p for p in resnet.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=1e-3, momentum=0.9)

epochs = 2
losses = []  # one training-loss value per epoch, appended by the training loop
# Train the new head for `epochs` passes over the training set, printing the
# running mean loss after every batch and recording each epoch's mean loss.
for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0  # batches processed so far this epoch (== i + 1 inside the loop)
    for i, (images, labels) in enumerate(train_loader):
        ic(epoch, i)
        output = resnet(images)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        # Divide by the number of batches actually seen. The former
        # `epoch_loss / i` divided by one too few (and needed an `i > 0`
        # guard to avoid dividing by zero on the first batch).
        print('Epoch: {} batch:{}, loss ==> {}'.format(epoch, i, epoch_loss / num_batches))

    # Mean loss over the epoch. NaN if the loader yielded no batches, where the
    # old code raised NameError (empty loader) or ZeroDivisionError (one batch).
    losses.append(epoch_loss / num_batches if num_batches else float('nan'))

# Console output captured from an earlier run. It was computed with the old
# `epoch_loss / i` average, so the losses shown are overstated by (i + 1) / i:
#   ic| epoch: 0, i: 0
#   ic| epoch: 0, i: 1
#   ic| epoch: 0, i: 2
#   Epoch: 0 batch:1, loss ==> 5.118020296096802
#   ic| epoch: 0, i: 3
#   Epoch: 0 batch:2, loss ==> 3.8235710859298706
#   ...
#   ic| epoch: 0, i: 203
#   Epoch: 0 batch:202, loss ==> 1.4433288293899875
#   ...
# Plot the per-epoch training loss curve.
# Author's note from the recorded run: training took too long to finish, so
# `losses` was never filled in and there was nothing to plot.
plt.plot(losses)
plt.xlabel('epoch')
plt.ylabel('training loss')
plt.show()