Day-26 手把手的手写辨识模型 0x1:资料集整理

  • 我们今天要使用的资料集就是在机器学习世界鼎鼎大名的 MNIST(Modified National Institute of Standards and Technology),我们都知道在学程序语言的过程中,我们都会先写一只程序叫做 'hello world' 来和这个程序世界打声招呼,那 MNIST 就是机器学习领域的 hello world~
  • 此资料集在 1999 年就发布了,至今历经了几次整理之後,可以看到已经在 kaggle 上面常驻,属於此领域的敲门砖~
  • 那我们今天当然就要把资料好好的使用啦~ 今天我们来介绍两种不同取得这个资料的方式吧~

Kaggle

Download

  • 那要取得 Kaggle 的 MNIST 其实非常简单,只要在 Google 搜寻 Kaggle MNIST 就可以看到相对应的网页了,我们也附在这里
  • 那可以看到这份资料被放在 CSV 里面去了,这时你应该会问,不是手写图片吗?怎麽是 CSV 档案?说好的照片呢?你四八四骗我 OAO
  • 我没有骗你啦~这份 CSV 里面就已经放好了整个手写图片的资讯了欧~让我来解释给你听
  • 图片是由 Pixel 所组成,也就是一个一个小方格子,那这张图片其实是一张 28*28 大小的图片,也就是说他的宽有 28 个 pixel ,高也有 28 个 pixel,所以整张图片就是 784 个 pixel 组成的,因此 Kaggle 就把图片资料分成一个一个 pixel 放到 CSV 中,因此我们可以发现 train.csv 这份档案里面每一行都是一张图片,然後共有 785 栏,分别就是 1~784 个 pixel 的颜色状况 + 1 栏答案
  • 所以基於这样的状况,我们来看看我们的 Dataset 要怎麽撰写吧~

Dataset

  • 我们先解释一下我们怎麽安排资料的状况,大致上目录结构长这样
    ├── mnist.py
    ├── data
    ├ ├── train.csv
    └ └── test.csv
  • 所以我们要读取资料的资料位置会是 path = './data/train.csv'
  • 那我们写成 Dataset 的话会怎麽写?
# 0) data import and set as pytorch dataset
class MnistDataset(Dataset):

    # data loading
    def __init__(self, path):
        xy = np.loadtxt(path, delimiter=',', dtype=np.float32, skiprows=1)
        self.x = torch.from_numpy(xy[:, 1:])
        self.y = torch.from_numpy(xy[:, [0]].flatten().astype(np.longlong))
        self.n_samples = xy.shape[0]

    # working for indexing
    def __getitem__(self, index):
        
        return self.x[index], self.y[index]

    # return the length of our dataset
    def __len__(self):
        
        return self.n_samples

train_path = './data/train.csv'
test_path = './data/test.csv'
train_dataset = MnistDataset(train_path)
test_dataset = MnistDataset(test_path)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
  • 由於我们两份资料的档案不同,因此我们可以传递位置来达到读取不同档案,那由於我们是 CSV file,里面的资料分离是由 , 去做切分的,因此我们设定我们的 delimiter 为 ,
  • 那会看到里面有 skiprows,是因为我们 CSV 中的第一行是名称,那些并不是特徵,因此我们跳过读取
  • 那这边要注意由於资料现在都是摊平的 28 * 28 资料,但是我们使用 CNN 时需要是二维有通道的图片资料,因此我们的资料整理要包含两个部分,一个部分就是将 784 的一维资料调整成 28 * 28 的二维资料,并且增加第三个维度把通道数给加进去,那这边用到 .unsqueeze(0) 去作为度的扩展,利用 .view(28, 28) 去 reshape 资料
# 0) data import and set as pytorch dataset
class MNISTDataset(Dataset):

    # data loading
    def __init__(self, path, transform=None):
        xy = np.loadtxt(path, delimiter=',', dtype=np.float32, skiprows=1)
        self.x = torch.from_numpy(xy[:, 1:])
        self.y = torch.from_numpy(xy[:, [0]].flatten().astype(np.longlong))
        self.n_samples = xy.shape[0]

        self.transform = transform

    # working for indexing
    def __getitem__(self, index):
        sample = self.x[index], self.y[index]

        if self.transform:
            sample = self.transform(sample)
        
        return sample

    # return the length of our dataset
    def __len__(self):
        
        return self.n_samples


class ToImage:
    def __call__(self, sample):
        inputs, targets = sample

        return inputs.view(28, 28).unsqueeze(0), targets


train_path = './data/train.csv'
test_path = './data/test.csv'
train_dataset = MNISTDataset(train_path, transform=ToImage())
test_dataset = MNISTDataset(test_path, transform=ToImage())

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
  • 我们特别撰写了一个去改变资料状态的转换器,去转换我们的资料成 1 * 28 * 28 的资料
  • 这是当我们要读取 kaggle 资料集的方式,下面让我们看看利用 Pytorch 内建资料集的使用方式

Pytorch

  • Pytorch 其实也有提供像 sklearn 一样的 Datasets 提供下载,那想要下载这些资料需要用到 Pytorch 的 torchvision 函式库,那下面就让我们示范怎麽去使用 Pytorch 的 Dataset 去取得 MNIST 的资料吧~
import torchvision
import torchvision.transforms as transforms

# MNIST
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
    transform=transforms.ToTensor(), download=True)

test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
    transform=transforms.ToTensor(), download=False)
  • 我们来一一解释上面的 Code 的参数部分,首先我们要取得资料就是利用 torchvision.datasets.资料集名称 去取得,那分成 training data 跟 testing data 的方式非常简单,Pytorch 非常贴心的提供一个参数叫做 train,来区分资料集中的 training data 跟 testing data
  • 另外资料在做训练之前,还是需要做下载,因此我们需要给予 download 参数确定是否需要下载资料,那只要有资料之後,可以把 download 改成 False,也就是说资料下载只需要一次
  • root 参数说明了资料位置,那我们下面画一下资料的结构图
    ├── mnist.py
    ├── data
    ├ ├── MNIST
    ├ ├ ├── train-images-idx3-ubyte
    ├ ├ ├── train-labels-idx1-ubyte
    ├ ├ ├── t10k-images-idx3-ubyte
    └ └ └── t10k-labels-idx1-ubyte
    Pytorch 会在 root 目录里面下载相对应的资料集资料,并以资料集名称作为目录
  • 最後一个参数是 transform,这个参数定义了资料的转换,例如说今天的资料是图片资料,我们可以利用 transform 来做 img to tensor 的转换,因此我们可以看到在这边我们就挂了 transforms.ToTensor() 确保资料型态是 Tensor
  • 那这边可以发现我们资料的挂载变得比较简单轻松,但是这毕竟是因为资料集有在 Pytorch 的资料库中,因此我们还是主要以读取外部资料集的方式做练习
  • 那下面让我们建立我们的 Dataloader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
    shuffle=True)

test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
    shuffle=True)
  • 这边就是将对应的 Dataset 放给对应的 Dataloader 并订定 Batch size,这边比较特殊的参数是 shuffle,shuffle 参数允许我们在使用资料前优先打乱资料,那这个参数很好用,因为大部分现成的资料集内部资料都是有整理过的,也就是有序排列,也就是说,如果不先 shuffle 过资料的话,容易造成模型训练的偏差,因此如果需要的话,shuffle 参数都是可以随手挂载的,对训练会有帮助

每日小结

  • 虽然 Pytorch Dataset 中也有 MNIST 资料集,但是我们还是提到了第一种资料读取方式是有原因的,大部分的现实问题一定不会有现成的资料集,因此如何读取使用这些资料集反而是我们应该学习的,也因此我们要熟悉自己撰写资料读取
  • 但是由於 kaggle 提供的 test data 并没有附赠答案,因此,我们还是必须在这边先使用 Pytorch 提供的资料集来做资料验证 QQ,所以明天的完整版将会使用 Pytorch dataset
  • Pytorch 提供了很多实用且方便的参数让我们在资料使用的过程中更加轻松,如何更好的利用他们都可以去看官方的文件去更加了解参数的使用,会对後续的训练有极大的帮助,在这边笔者都只提到一些常用的参数,还有更多参数可以去理解学习
  • 那我们今天把资料读取进来了,明天就让我们开始训练我们的 CNN 模型来面对这份资料集吧~

<<:  IT 铁人赛 k8s 入门30天 -- day27 Communicate Between Containers in the Same Pod Using a Shared Volume

>>:  予焦啦!结论与展望(二):铁人赛、正体中文科技写作与杂谈

Day17 PHP的常用函数-2:数组

数组 array(): 生成一个数组 range(): 创建并返回一个包含指定范围的元素的数组 co...

Day18 Redis架构实战-持久化RDB

Redis持久化 Redis是一个in-memory的data store,在记忆体中操作与储存让其...

Day10 React Hooks 小实作简单的计数器

今天实作一个很简单的计数器,按下按钮後数字会一直累加1 要先将useState 汇入 import ...

【Day17】物件结构与存取

物件宣告 物件内容为一个属性 (property)对应一个值 (value), 如果要在後方添加新的...

Day07 - Login to Ptt

今天来处理登入的流程。 送出登入的方式很简单,使用WebSocketClient的send方法即可:...