今天的主题比较特殊一些,要来探讨 tensorflow 中的 Dataset api : shuffle, batch 和 repeat 的顺序,在我一开始使用这个API时,我完全没想到他的顺序会完全影响到训练的结果而踩了好大的一坑。
Shuffle:
顾名思义,就是用来打乱资料集的API,只是需要注意的是在使用此 API 时,必须给予 buffer_size ,其用途是执行 shuffle 时,他并不是把全部的资料做 shuffle ,而是只会把前N个资料做 shuffle,这个N的数量就是 buffer_size 。
Batch:
前几天的实验已经用过了,就是将原本分散一笔一笔的资料以批次的方式包起来,每个训练的step就会拿到同样 batch size 的样本来训练。
Repeat:
其功能为要重覆这个 dataset 的元素几次,如果是count=2,那你可以做到在一个 epoch 内对每笔资料扫过两次的效果,当count=None时,即代表无限重覆,此时要注意你在 model.fit() 时,必须指定 steps_per_epoch,不然会永远算不玩一个 epoch 而错误!
介绍完这三支 API ,以下就来实验不同的组合之下,会有什麽样的效果啦!Dataset 很简单,就是1~13
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
实验一:Shuffle -> Batch -> Repeat
BATCH_SIZE=4
SHUFFLE_SIZE=13
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.batch(BATCH_SIZE)
ds = ds.repeat()
for example in ds.take(12):
batch = example.numpy()
print(batch)
产出:
[13 11 9 4]
[12 10 3 5]
[2 1 6 8]
[7]
[ 6 7 4 10]
[11 2 5 13]
[ 9 1 12 3]
[8]
[ 5 4 8 11]
[ 6 13 3 10]
[7 1 2 9]
[12]
数据有成功 shuffle,但因为 batch size=4 无法整除13,所以会有个 batch 只有一笔资料。
实验二:Repeat -> Shuffle -> Batch
BATCH_SIZE=4
SHUFFLE_SIZE=13
ds = dataset.repeat()
ds = ds.shuffle(SHUFFLE_SIZE)
ds = ds.batch(BATCH_SIZE)
for example in ds.take(12):
batch = example.numpy()
print(batch)
产出:
[11 9 6 3]
[1 1 4 5]
[ 2 12 2 3]
[10 13 7 8]
[2 1 4 3]
[ 9 4 11 10]
[ 7 5 10 13]
[ 8 2 13 6]
[11 4 12 1]
[10 6 12 5]
[6 7 8 8]
[3 5 9 7]
我们发现第二个batch 有两个重复的1,这是因为资料集是先被 repeat 後才开始 shuffle ,所以第一次 batch 拿走 [11, 9, 6, 3]後,有两个1都在 shuffle 的 buffer size 范围里,所以就有可能被同时拿出来,而也因为被 repeat 过,变成无限长的资料集,所以 batch 後不会像上面产生只有一笔资料的 batch。
实验三:Shuffle -> Repeat -> Batch
BATCH_SIZE=4
SHUFFLE_SIZE=13
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
for example in ds.take(12):
batch = example.numpy()
print(batch)
产出:
[ 7 11 2 4]
[ 6 5 9 13]
[ 8 1 12 10]
[ 3 6 5 12]
[4 7 8 9]
[11 13 3 1]
[ 2 10 3 4]
[ 5 11 9 2]
[12 6 1 8]
[ 7 10 13 12]
[11 10 2 7]
[ 5 9 13 6]
这个顺序可以看到每个 batch 都是四个,而且每个数字在第二次被拿出来前,都有至少历经完整个资料集,相对实验二来说,每个资料集被拿到的机率平均了一些,这种组合是我自己比较常用的组合。
再来!我们要探讨 shuffle 的 buffer size 设置问题!
x = np.array(range(100))
x = x.repeat(10)
print(f'length: {len(x)}')
dataset = tf.data.Dataset.from_tensor_slices(x)
首先,我先准备好资料集,有数字0~99,每个重复10遍,用 list 还看大致长这样:[ 0,0,0...1,1,1...2,2,2......99,99,99 ] 共1000个元素
现在,我们 buffer size 取 10,也就是故意让它只对前10个元素做 shuffle,看看会发生什麽事?
SHUFFLE_SIZE = 10
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
for idx, example in enumerate(ds.take(100)):
batch = example.numpy()
print(batch)
产出:
[0 0 0 0]
[0 1 0 1]
[1 1 0 0]
[2 1 0 0]
[2 1 2 1]
[1 2 3 2]
[3 1 3 1]
[2 3 3 2]
(略)
没错!前面几个 step 拿到的数值都是很前面的元素,这样的 shuffle 效果不彰...
若我们把 buffer size 增大到100又会发生什麽事呢?
SHUFFLE_SIZE=100
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
for idx, example in enumerate(ds.take(100)):
batch = example.numpy()
print(batch)
产出:
[3 4 5 0]
[4 1 4 9]
[2 6 3 7]
[10 8 4 3]
[2 0 3 7]
[5 0 4 0]
[ 8 0 11 7]
[11 6 0 12]
[11 10 6 2]
[13 11 9 6]
(略)
有比上一个实验有变化些,但是仍然没有拿到比较尾端90,91,92..等元素。因此问题来了,如果今天我的 dataset 数量庞大,我如果把 buffer size 设定的和 dataset 数量一样,结果遇上记忆体 OOM 问题而程序炸掉,那我应该怎麽办?
其实更好的做法还是在包成 tfrecord 之前就先把资料打乱,我这边以 numpy 简单为例,在我把它变成 tf.data.Dataset 前,先自行用 np.random.shuffle 做 shuffle 後再使用 API。
x = np.array(range(100))
x = x.repeat(10)
np.random.shuffle(x) # shuffle
print(f'length: {len(x)}')
dataset = tf.data.Dataset.from_tensor_slices(x)
SHUFFLE_SIZE=100
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
for idx, example in enumerate(ds.take(100)):
batch = example.numpy()
print(batch)
产出:
[19 12 88 32]
[13 43 76 91]
[85 24 63 58]
[48 52 44 82]
[82 58 46 26]
[24 20 85 63]
可以看到有拿到85, 91等较为後面的元素。
以上就是使用 shuffle, repeat, batch 这三支 API 需要注意的地方。
>>: Day 7 - 使用 AES-CBC 机制对 Message 内文进行加密
昨天提到先将本机的档案列为版控,但是光在本机这样操作还是不太够,其他人要一起共同开发的时候,还是一样...
哇,不知不觉就来到了铁人赛的最後一天了,从一开始不知道要写什麽内容,到慢慢想到要写什麽,再到最後终於...
也许你还没听过odoo,但身为开发人员当你认识odoo後,会有相见恨晚的感觉!! odoo,他可以是...
在 Google Search Console 的 KPI 总表中,每一个数字背後是一大堆因子,最後...
使用 Python Flask 架设 API 吧! 今日学习目标 API 观念讲解 什麽是 API?...