今天的主题是要探讨优化器(Optimizer)对模型学习的影响,有关优化器该用哪个好,也是一个蛮令人头痛的问题,大部分的时候优化器都可以让你成功收敛,但有小部份时候优化器直接让你训练nan。
我们这次要比较的优化器从古早的SGD、Momentum、Adagrad、RMSProp、Adam,到较新的Range都有,要注意因为比较的优化器很多,很有可能会超出 Colab 使用时间上限,为了降低训练时间,我们会做迁移式学习,锁住模型142层以前的权重值,只专注训後面的几层作为观察。
另外有关各个优化器的介绍,我在之前有写过一篇介绍文可以看看。
实验一: SGD
全称 Stochastic gradient descent,即最基本的 gradient。
base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)
model = tf.keras.Model(inputs=[base.input], outputs=[net])
# Unfreeze weights
for idx, layer in enumerate(model.layers):
layer.trainable = FREEZE_INDEX < idx
model.compile(
optimizer=tf.keras.optimizers.SGD(LR),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
start = timeit.default_timer()
sgd_history = model.fit(
ds_train,
epochs=EPOCHS,
validation_data=ds_test,
verbose=True)
print(f'cost {timeit.default_timer()-start} sec')
产出:
loss: 0.0011 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4991 - val_sparse_categorical_accuracy: 0.8637
实验二:Momentum
在SGD中多加了动量的概念。
base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)
model = tf.keras.Model(inputs=[base.input], outputs=[net])
# Unfreeze weights
for idx, layer in enumerate(model.layers):
layer.trainable = FREEZE_INDEX < idx
model.compile(
optimizer=tf.keras.optimizers.SGD(LR, momentum=0.9),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
start = timeit.default_timer()
mom_history = model.fit(
ds_train,
epochs=EPOCHS,
validation_data=ds_test,
verbose=True)
print(f'cost {timeit.default_timer()-start} sec')
产出:
loss: 9.9336e-05 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.9496 - val_sparse_categorical_accuracy: 0.8206
实验三:Adagrad
在SGD多加了快取的概念
base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)
model = tf.keras.Model(inputs=[base.input], outputs=[net])
# Unfreeze weights
for idx, layer in enumerate(model.layers):
layer.trainable = FREEZE_INDEX < idx
model.compile(
optimizer=tf.keras.optimizers.Adagrad(LR, initial_accumulator_value=0.1),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
start = timeit.default_timer()
ada_history = model.fit(
ds_train,
epochs=EPOCHS,
validation_data=ds_test,
verbose=True)
print(f'cost {timeit.default_timer()-start} sec')
产出:
loss: 2.8722e-04 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.5482 - val_sparse_categorical_accuracy: 0.8686
实验四:RMSProp
在 Adagrad 中多加了 decay 的概念。这边由於我自己测试时,发现LR=0.1时,模型非常不稳定,所以此处LR改成0.001。
base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)
model = tf.keras.Model(inputs=[base.input], outputs=[net])
# Unfreeze weights
for idx, layer in enumerate(model.layers):
layer.trainable = FREEZE_INDEX < idx
model.compile(
optimizer=tf.keras.optimizers.RMSprop(0.001, rho=0.99),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
start = timeit.default_timer()
rms_history = model.fit(
ds_train,
epochs=EPOCHS,
validation_data=ds_test,
verbose=True)
print(f'cost {timeit.default_timer()-start} sec')
产出:
loss: 0.0232 - sparse_categorical_accuracy: 0.9951 - val_loss: 2.6411 - val_sparse_categorical_accuracy: 0.7304
图表上产生了有锯齿状的线,我认为应该是模型仍在多个 local minima 跳跃。
实验五:Adam
带入mean和var两个概念。同样发现LR=0.1时,模型不稳定,LR改成0.001。
base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)
model = tf.keras.Model(inputs=[base.input], outputs=[net])
# Unfreeze weights
for idx, layer in enumerate(model.layers):
layer.trainable = FREEZE_INDEX < idx
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001, beta_1=0.9, beta_2=0.999),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
start = timeit.default_timer()
adam_history = model.fit(
ds_train,
epochs=EPOCHS,
validation_data=ds_test,
verbose=True)
print(f'cost {timeit.default_timer()-start} sec')
产出:
loss: 7.5301e-05 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4853 - val_sparse_categorical_accuracy: 0.8706
第六个实验:Ranger
这个比较特别,这是一个结合RAdam和LookAhead(另外两个新型优化器)的优化器,原作Repo
只是这东西目前要使用的话,用 tensorflow addons 会比较方便。
!pip install -U tensorflow-addons
import tensorflow_addons as tfa
一样测试後,发现LR=0.001比较正常。
base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)
model = tf.keras.Model(inputs=[base.input], outputs=[net])
# Unfreeze weights
for idx, layer in enumerate(model.layers):
layer.trainable = FREEZE_INDEX < idx
radam = tfa.optimizers.RectifiedAdam(0.001)
ranger = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)
model.compile(
optimizer=ranger,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
start = timeit.default_timer()
range_history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
verbose=True)
print(f'cost {timeit.default_timer()-start} sec')
产出:
loss: 4.8191e-04 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4846 - val_sparse_categorical_accuracy: 0.8696
以上就是我们针对六种不同的优化器训练同一个模型的实验,以我自己实务经验,我其实也是个跟风仔,会先尝试使用比较新型的优化器,但如果训练过程中发生 loss 不断增大的状况,我会再切成 SGD 来 debug 模型或调整 learning rate 来检查有没有问题。
延续上一篇谈了产业研究分析, 你若是看完DIGITIMES黄社长的短篇, 你一定会发觉他常常提到...
如需在地端环境操作 那需要去理解 什麽是node JS 什麽是NPM 需要参照 本地安装 使用 np...
宣告变数的资料型别--阵列 1.数值( Number ) 2.字串( String ) 3.布林值(...
创建migration迁移档案 首先先使用artisan指令: make:migration 创建一...
现在越来越多种类的装置出现,包括电脑、平板、手机,我们会在不同大小的萤幕上浏览网页,究竟网页要如何在...