DAY15:玉山人工智慧挑战赛-中文手写字辨识(Pytorch 自订义资料集)

资料扩增

  • 我们组的资料扩增这部分,因为第一次比赛,这个方法效果没有到非常好,采取的是用mask的方式,让图档多加一些遮蔽物,如下图。详细操作参考组员的分享(传送门)
  • 增加完我们的图片总数量约为19万张。

Pytorch自定义资料集

  • 我们先定义一个alphabet,它代表的是我们的800个字的位置。
img_names = os.listdir(data_path)
source = img_names
alphabet = ''.join(source)

  • 因为我们到时候会将图档用PIL的Image读取出来,所以先将图档和对应的label组成一个list。
# 读取图档,并转换大小为80*80,以及转换成RGB
def img_loader(img_path):
    image = Image.open(img_path)
    img = image.resize((80, 80),Image.ANTIALIAS) #resize image with high-quality
    return img.convert('RGB')
# 将图档与label对应,丢入自定义的资料集内
def make_dataset(data_path, alphabet, num_class):
    samples = []
    for i in os.listdir(data_path):
        for j in os.listdir(data_path + '/' + i):
            img_path = data_path + '/' + i + '/' + j
            target_str = j.split('.')[0][-1]
            vec = [0] * 800
            vec[alphabet.find(target_str)] = 1
            target = vec
            samples.append((img_path, target))
    return samples

例如这个字是"不",由alphabet的位置可以看到alphabet[3]的位置是"不",故在alphabet[3]的位置为1,代表他的label,其余位置皆为0。

  • torch.utils.data.Dataset,是一个自定义资料集的框架。

    • __ init __()

      • 负责做一个初始化的动作,我们先定义我们要的东西:
        1. data_path:我们要读取资料集的路径。
        2. num_class:我们要预测的种类数量(800类)
        3. transform:对於图片是否进行处理,这里设定None,不对读取进来的图片作处理。
        4. target_transform:对标签做处理,这里我们也都处理好了,不对标签做处理,设定为None。
        5. alphabet:我们要预测的所有字,将它变成str。
        6. samples:是我们用make_data弄成的图片与label对应的list。
      def __init__(self, data_path, num_class=800,transform=None,target_transform=None, alphabet=alphabet):
          super(Dataset, self).__init__()
          self.data_path = data_path
          self.num_class = num_class
          self.transform = transform
          self.target_transform = target_transform
          self.alphabet = alphabet
          self.samples = make_data.set(self.data_path, self.alphabet)
      
    • __ len __ ()

      • 返回list中的长度,也就是你的资料的笔数。
      def __len__(self):
          return len(self.samples)
      
    • __ getitem __ ()

      • 使资料集可以节省内存,资料集为dataset,而__ len __ ()返回的数字n,使的dataset[n]的图片能被读取,需要时才将图片读取,所以可以节省内存。返回值一个图片样本及标签。
      def __getitem__(self, index):
          img_path, target = self.samples[index]
          img = img_loader(img_path)
          if self.transform is not None:
              img = self.transform(img)
          if self.target_transform is not None:
              target = self.target_transform(target)
          return img, torch.Tensor(target) # 在torch里面,array都要转成Tensor型式
      
    • 完整程序码

    class CaptchaData(Dataset):
      def __init__(self, data_path, num_class=800,
                   transform=None, target_transform=None, alphabet=alphabet):
          super(Dataset, self).__init__()
          self.data_path = data_path
          self.num_class = num_class
          self.transform = transform
          self.target_transform = target_transform
          self.alphabet = alphabet
          self.samples = make_dataset(self.data_path, self.alphabet)
    
      def __len__(self):
          return len(self.samples)
    
      def __getitem__(self, index):
          img_path, target = self.samples[index]
          img = img_loader(img_path)
          if self.transform is not None:
              img = self.transform(img)
          if self.target_transform is not None:
              target = self.target_transform(target)
          return img, torch.Tensor(target)
    
  • torch.utils.data.DataLoader

    • Dataset设置好後,DataLoader可以依照batch_size让我们取样,非常方便。

      from torchvision.transforms import Compose, ToTensor
      from torch.utils.data import DataLoader
      transforms = Compose([ToTensor()])
      train_dataset = CaptchaData(r'C:\Users\Frank\PycharmProjects\practice\mountain\data_final_20210530\official_in_800',transform=transforms)
      train_data_loader = DataLoader(train_dataset, batch_size=1, num_workers=0,
                                         shuffle=True, drop_last=True)
      
      for (data,label) in train_data_loader:
          print((data,label))
      

      我把batch_size设定为1,他一次就只取出一组图片样本及标签。

      • batch_size:一次要取多少样本。
      • num_workers:数据加载的子进程数,0则为主要进程。(这里有个小坑,我只要条不是0的时候,都会发生error,如果有知道的大神,帮我指点一下,小弟非常感谢。)
      • shuffle:若设定为True,指每个epoch取出样本的顺序都会不一样。
      • drop_last:若设定为True,则全部样本不能被batch_size整除时,最後一批会直接被删除;若为False,则最後一批会较小。
  • 除了自定义资料集以外,还有可以torchvision.datasets.ImageFolder来处理资料集,用法会在於你分好类别,他的资料夹名称就是他的label,而里面图片都属於这个label。


今日小结

  • 深度学习有很多很好玩的地方,但也有很多的坑,debug我都要找很久XDD,重点是东西太多,绝对学不完,而且很吃硬体设备。有时候会觉得自己好笨,都学不会,但看久了发现懂一点了,就又有动力继续往下学了,接触深度学习的朋友们,我们一起继续努力吧!

  • 小弟我是试着用自定义资料集来处理,原因只想练习以及可以更弹性的操作载入资料的动作。

  • 前面加载图片时我们把transforms设置为None,现在我们丢模型训练要对图片做transforms,他可以增加图片的多样性,例如:旋转、平移、变形等等,明天来跟大家分享torchvision很好用的套件transforms。


<<:  Day15:终於要进去新手村了-Javascript-isNaN函式

>>:  Day15-Vue SFC 单一元件档

Day26-不好意思,这里前方是一方通行啊!

嗨,各位早安 相信各位一定有听过"防火墙"这个东西吧? 之前我们前面讲过了ssh...

第14章:设定系统时区与时间

前言 本篇文章,是要设定系统的时区与时间,并维护时区与时间的设定。 设定本地时间与时区 作业系统的时...

Day4:如何使用Parrot_Security的CeWL工具收集特定网站的Wordlist

今天我们来讲一下如何使用Parrot_Security的CeWL工具来收集特定网站的Wordlist...

Day7 CSV档处理

在经历上一部函数与类别的摧残後,这两天就来教一些比较温和的程序吧~ 今天的影片内容为介绍常见的档案格...

[Day 30] 永和美食纪录-向日葵早午餐 国中店

前言 今天文章的标题完完全全打脸了笔者在 Day27 的结语,没想到在最後一天仍然还是介绍了早午餐给...