[DAY 13] CNN的实作以及 Classification 的应用例子

前言


我们已经知道了组成一个简单的 CNN 做 classification 所需要的部件了,那麽接着我们就需要使用 Pytorch 的 Tool 来建构一个真实能够动的 Model来看看效果了~

CNN 实作

我们拿 Pytorch 官网的 tutorial 来进行说明,一般来说我们要建构一个 CNN 的 model 并进行 training 以及 Testing , 我们需要完成以下几件事:

  1. 载入 Dataset 并进行预处理
    (Load and normalizing the CIFAR10 training and test datasets using torchvision)
  2. 定义我们的神经网路 (Model)
    (Define a Convolution Neural Network)
  3. 定义 Loss Function
    (Define a loss function)
  4. 使用 Training dataset 来训练我们的 Model
    (Train the network on the training data)
  5. 使用 Testing dataset 来测试我们训练好的 Model
    (Test the network on the test data)

为了方便给大家练习,我们一样使用之前介绍过的 Colab 来练习,相关文章可以参考这个~

载入 Dataset 并进行预处理

我们这一个 Part 其实之前在 [DAY 05] 从头训练大Model?想多了 : Torchvision 简介 这一章中的 "使用 Dataset" 这一 Part 介绍过了~所以这边我们就只贴上 Code 吧 :

import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

如果你想看看现在 load 进来的 Dataset长啥样,可使用下方程序:

import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

定义我们的神经网路 (Model)

话先不多说,我们先贴上完整的 Code ,再逐步解释每个 Part 跟 Function 是啥意思:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

定义 Loss Function

使用 Training dataset 来训练我们的 Model

使用 Testing dataset 来测试我们训练好的 Model


<<:  LineBot - 申请

>>:  DAY28 MongoDB Atlas 付费监控内容

Day17:比大小

记得初学Java的时候,若要对List进行排序,可以使用Collections的静态方法sort()...

D27 第十四周 (回忆篇)

支线任务:共笔部落格切版 礼拜一的时候终於把留言版做完了,接着是弄共笔部落格的文章列表样板,花了一两...

Day11|【Git】档案管理 - 重新命名档案 git mv

延续上篇的说明,在 Git 的世界,任何动作对 Git 来说都可以视为一个「修改」的动作。因此这篇要...

[D07] OpenCV 基本的影像调整

我们已经掌握了基本的影像读取、显示以及显示,但不是每张照片都刚刚好是我们想要的样子,所以接下来,来看...

[DAY13]影片及音档

#影片及音档下载过程极慢所以建议少用 VideoSendMessage video_message ...