【2】学习率大小的影响与学习率衰减(Learning rate decay)

Colab连结

大家应该听到烂了,学习率(Learning rate)指的是模型每做完一次 back propagation 後产生的 gradient 再乘上该值来对权重更新,而学习率越大,代表模型权重被更新的变化量也会跟着变大,而这个学习率该设定多少也是个麻烦的超参数,因此也有其他学者从其他面向如不同的优化器 (Optimizers) 来着手研究。但是今天我们比较单纯,我们都使用 SGD 作为优化器,但用不同的学习率来观察训练的结果。

这次我们使用我自己修改较为精简版的 alexnet 头开始训练,但因为怕 oxford_flowers102 过多的分类,导致模型可能需要非常多个 epochs 来跑,所以改用 tfds 提供的 cifar10 当参考,此资料集只有10个分类,训练和测试资料集分别是50000和10000张。

https://ithelp.ithome.com.tw/upload/images/20210916/20107299ofqwV5lkSw.png

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Flatten, Dense

def alexnet_modify():
  model = Sequential()
  model.add(Conv2D(32, (11, 11), padding='valid', input_shape=(227,227,3)))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(MaxPooling2D(pool_size=(3, 3)))

  model.add(Conv2D(64, (7, 7), padding='valid'))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(MaxPooling2D(pool_size=(3, 3)))

  model.add(Conv2D(96, (3, 3), padding='valid'))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(MaxPooling2D(pool_size=(3, 3)))

  model.add(Conv2D(64, (3, 3), padding='same'))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(MaxPooling2D(pool_size=(3, 3)))

  model.add(Flatten())
  model.add(Dense(128))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(Dense(64))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(Dense(NUM_OF_CLASS))

  return model

而原先 alexnet 的 input size 是227x227,但 cifar10 这个资料集的解析度都是32x32,所以要做 resize 的动作。

第一个实验,我们将学习率固定为0.1来训练15个 epochs。

LR = 0.1

model = alexnet_modify()

model.compile(
    optimizer=tf.keras.optimizers.SGD(LR),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
    verbose=True)
loss: 0.3053 - sparse_categorical_accuracy: 0.8906 - val_loss: 0.8681 - val_sparse_categorical_accuracy: 0.7557

https://ithelp.ithome.com.tw/upload/images/20210916/201072999D7WySrxJp.png

我们可以看到准确度的呈现为震荡向上

第二个实验将学习率缩小固定为0.001,一样15个epochs。

LR = 0.001

model = alexnet_modify()

model.compile(
    optimizer=tf.keras.optimizers.SGD(LR),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
    verbose=True)
loss: 0.8821 - sparse_categorical_accuracy: 0.6962 - val_loss: 0.9843 - val_sparse_categorical_accuracy: 0.6574

https://ithelp.ithome.com.tw/upload/images/20210916/20107299n48ZrZmWAW.png

得到的准确度有比较平稳的上升,但同样的 epoch 最後准确度却没有实验一来得高。

第三个实验,我们来实验学习率衰减的做法,简单来说,当模型一开始还是混乱状态时,较高的学习率有助於模型快速收敛,但是到了後期过高的学习率会导致模型不对的在各个局部最佳解中跳耀,而很难继续深入学习,所以我们使用 learning rate decay 这个策略来让学习率随着 epoch 数量增加来降低。

LR = 0.1

model = alexnet_modify()

model.compile(
    optimizer=tf.keras.optimizers.SGD(LR),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

def scheduler(epoch):
  step = EPOCHS //3
  power = (epoch//step)+1

  new_lr = LR**(power)
  
  return new_lr

callback = tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
    callbacks=[callback],
    verbose=True)

我们降低的原则分成三个部分,前5个 epochs 我们学习率为0.1,中间5个 epochs 为0.01,最後5个 epochs 学习率降至0.001来实验。

loss: 0.3802 - sparse_categorical_accuracy: 0.8687 - val_loss: 0.6588 - val_sparse_categorical_accuracy: 0.7798

https://ithelp.ithome.com.tw/upload/images/20210916/20107299VekGVmmcwl.png

前期准确度稳定上升,但在第6个 epoch 进步趋缓,最终准确度来到77.9%

以上实验结论来看,使用 learning rate decay 可以让模型的训练稳定一些。


<<:  前言

>>:  Day 1. Pre-Start × 微前言

[day2] 付款流程 & 取得(Nonce)

资料准备 啊以为第二天开始就是程序码喔,NONONO,要接入金融机构的系统,不是任何人都能直接跑进去...

Day18 javascript 阵列

今天咱们先来简单的稍为介绍一下JavaScript Array(阵列) 物件,JavaScript ...

Angular 路由守卫(登入篇)

经过了昨天的介绍,今天就来看看使用登入范例罗 今天的登入资料依然是使用 FakeStoreAPI 登...

Day27 - 铁人付外挂测试验收(三)- 端对端测试

曾经做过一个专案,顾客把商品加入购物车後,可以同时选择要加入几笔商品,然後在结帐页的时候需要根据商品...

[Day4] Google Cloud

今天的内容会跟各位介绍 Google Cloud 相关的基础知识,希望不会不小心的讲成像业配文QQ,...