【7】Dataset 的三个API : Shuffle Batch Repeat 如果使用顺序不同会产生的影响

Colab连结

今天的主题比较特殊一些,要来探讨 tensorflow 中的 Dataset api : shuffle, batch 和 repeat 的顺序,在我一开始使用这个API时,我完全没想到他的顺序会完全影响到训练的结果而踩了好大的一坑。

Shuffle:
顾名思义,就是用来打乱资料集的API,只是需要注意的是在使用此 API 时,必须给予 buffer_size ,其用途是执行 shuffle 时,他并不是把全部的资料做 shuffle ,而是只会把前N个资料做 shuffle,这个N的数量就是 buffer_size 。

Batch:
前几天的实验已经用过了,就是将原本分散一笔一笔的资料以批次的方式包起来,每个训练的step就会拿到同样 batch size 的样本来训练。

Repeat:
其功能为要重覆这个 dataset 的元素几次,如果是count=2,那你可以做到在一个 epoch 内对每笔资料扫过两次的效果,当count=None时,即代表无限重覆,此时要注意你在 model.fit() 时,必须指定 steps_per_epoch,不然会永远算不玩一个 epoch 而错误!

介绍完这三支 API ,以下就来实验不同的组合之下,会有什麽样的效果啦!Dataset 很简单,就是1~13

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])

实验一:Shuffle -> Batch -> Repeat

BATCH_SIZE=4
SHUFFLE_SIZE=13

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.batch(BATCH_SIZE)
ds = ds.repeat()

for example in ds.take(12):
  batch = example.numpy()
  print(batch)

产出:

[13 11  9  4]
[12 10  3  5]
[2 1 6 8]
[7]
[ 6  7  4 10]
[11  2  5 13]
[ 9  1 12  3]
[8]
[ 5  4  8 11]
[ 6 13  3 10]
[7 1 2 9]
[12]

数据有成功 shuffle,但因为 batch size=4 无法整除13,所以会有个 batch 只有一笔资料。

实验二:Repeat -> Shuffle -> Batch

BATCH_SIZE=4
SHUFFLE_SIZE=13

ds = dataset.repeat()
ds = ds.shuffle(SHUFFLE_SIZE)
ds = ds.batch(BATCH_SIZE)

for example in ds.take(12):
  batch = example.numpy()
  print(batch)
  

产出:

[11  9  6  3]
[1 1 4 5]
[ 2 12  2  3]
[10 13  7  8]
[2 1 4 3]
[ 9  4 11 10]
[ 7  5 10 13]
[ 8  2 13  6]
[11  4 12  1]
[10  6 12  5]
[6 7 8 8]
[3 5 9 7]

我们发现第二个batch 有两个重复的1,这是因为资料集是先被 repeat 後才开始 shuffle ,所以第一次 batch 拿走 [11, 9, 6, 3]後,有两个1都在 shuffle 的 buffer size 范围里,所以就有可能被同时拿出来,而也因为被 repeat 过,变成无限长的资料集,所以 batch 後不会像上面产生只有一笔资料的 batch。

实验三:Shuffle -> Repeat -> Batch

BATCH_SIZE=4
SHUFFLE_SIZE=13

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)

for example in ds.take(12):
  batch = example.numpy()
  print(batch)

产出:

[ 7 11  2  4]
[ 6  5  9 13]
[ 8  1 12 10]
[ 3  6  5 12]
[4 7 8 9]
[11 13  3  1]
[ 2 10  3  4]
[ 5 11  9  2]
[12  6  1  8]
[ 7 10 13 12]
[11 10  2  7]
[ 5  9 13  6]

这个顺序可以看到每个 batch 都是四个,而且每个数字在第二次被拿出来前,都有至少历经完整个资料集,相对实验二来说,每个资料集被拿到的机率平均了一些,这种组合是我自己比较常用的组合。

再来!我们要探讨 shuffle 的 buffer size 设置问题!

x = np.array(range(100))
x = x.repeat(10)
print(f'length: {len(x)}')
dataset = tf.data.Dataset.from_tensor_slices(x)

首先,我先准备好资料集,有数字0~99,每个重复10遍,用 list 还看大致长这样:[ 0,0,0...1,1,1...2,2,2......99,99,99 ] 共1000个元素

现在,我们 buffer size 取 10,也就是故意让它只对前10个元素做 shuffle,看看会发生什麽事?

SHUFFLE_SIZE = 10

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)

for idx, example in enumerate(ds.take(100)):
  batch = example.numpy()
  print(batch)

产出:

[0 0 0 0]
[0 1 0 1]
[1 1 0 0]
[2 1 0 0]
[2 1 2 1]
[1 2 3 2]
[3 1 3 1]
[2 3 3 2]
(略)

没错!前面几个 step 拿到的数值都是很前面的元素,这样的 shuffle 效果不彰...

若我们把 buffer size 增大到100又会发生什麽事呢?

SHUFFLE_SIZE=100

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)

for idx, example in enumerate(ds.take(100)):
  batch = example.numpy()
  print(batch)

产出:

[3 4 5 0]
[4 1 4 9]
[2 6 3 7]
[10  8  4  3]
[2 0 3 7]
[5 0 4 0]
[ 8  0 11  7]
[11  6  0 12]
[11 10  6  2]
[13 11  9  6]
(略)

有比上一个实验有变化些,但是仍然没有拿到比较尾端90,91,92..等元素。因此问题来了,如果今天我的 dataset 数量庞大,我如果把 buffer size 设定的和 dataset 数量一样,结果遇上记忆体 OOM 问题而程序炸掉,那我应该怎麽办?

其实更好的做法还是在包成 tfrecord 之前就先把资料打乱,我这边以 numpy 简单为例,在我把它变成 tf.data.Dataset 前,先自行用 np.random.shuffle 做 shuffle 後再使用 API。

x = np.array(range(100))
x = x.repeat(10)
np.random.shuffle(x) # shuffle
print(f'length: {len(x)}')
dataset = tf.data.Dataset.from_tensor_slices(x)
SHUFFLE_SIZE=100

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)

for idx, example in enumerate(ds.take(100)):
  batch = example.numpy()
  print(batch)

产出:

[19 12 88 32]
[13 43 76 91]
[85 24 63 58]
[48 52 44 82]
[82 58 46 26]
[24 20 85 63]

可以看到有拿到85, 91等较为後面的元素。

以上就是使用 shuffle, repeat, batch 这三支 API 需要注意的地方。


<<:  RISC V::中断与异常处理 -- 异常篇

>>:  Day 7 - 使用 AES-CBC 机制对 Message 内文进行加密

DAY29 欸你Git来Hub一下

昨天提到先将本机的档案列为版控,但是光在本机这样操作还是不太够,其他人要一起共同开发的时候,还是一样...

【在 iOS 开发路上的大小事-Day30】结语

哇,不知不觉就来到了铁人赛的最後一天了,从一开始不知道要写什麽内容,到慢慢想到要写什麽,再到最後终於...

【Day1】odoo的基础知识

也许你还没听过odoo,但身为开发人员当你认识odoo後,会有相见恨晚的感觉!! odoo,他可以是...

从 IT 技术面细说 Search Console 的 27 组数字 KPI (27) :SEO KPI 那个最有价值呢(上)?

在 Google Search Console 的 KPI 总表中,每一个数字背後是一大堆因子,最後...

[Day 29] 使用 Python Flask 架设 API 吧!

使用 Python Flask 架设 API 吧! 今日学习目标 API 观念讲解 什麽是 API?...