InceptionResNetV2
1.1 来源
1.2 架构
1.3 特性
训练过程
2.1 预训练模型
2.2 设置Callbacks
2.3 设置训练集
2.4 开始训练模型
2.5 储存模型与纪录学习曲线
模型训练结果
3.1 学习曲线
3.2 Accuracy与Loss
验证模型准确度
4.1 程序码
4.2 验证结果
InceptionResNetV2
1.1 来源
1.2 架构
完整架构:
Stem:与InceptionV4中Stem相同。
Inception-ResNet:
Reduction:缓慢地降低特徵图的尺寸,避免特徵信息流失。
Activation scaling
1.3 特性
训练过程:
2.1 预训练模型
# IMPORT MODULES
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.layers import Input, Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
import matplotlib.pyplot as plt
from keras.models import Model
from keras.applications import InceptionResNetV2
# -----------------------------1.客制化模型--------------------------------
# 载入keras模型(更换输出图片尺寸)
model = InceptionResNetV2(include_top=False,
weights='imagenet',
input_tensor=Input(shape=(80, 80, 3))
)
# 定义输出层
x = model.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(800, activation='softmax')(x)
model = Model(inputs=model.input, outputs=predictions)
# 编译模型
model.compile(optimizer=Adam(lr=0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
2.2 设置Callbacks
# ---------------------------2.设置callbacks----------------------------
# 设定earlystop条件
estop = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=1)
# 设定模型储存条件
checkpoint = ModelCheckpoint('InceptionResNetV2_checkpoint_v2.h5', verbose=1,
monitor='val_loss', save_best_only=True,
mode='min')
# 设定lr降低条件(0.001 → 0.0005 → 0.00025 → 0.000125 → 0.0001)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
patience=5, mode='min', verbose=1,
min_lr=1e-5)
2.3 设置训练集
# -----------------------------3.设置资料集--------------------------------
# 设定ImageDataGenerator参数(路径、批量、图片尺寸)
train_dir = './workout/train/'
valid_dir = './workout/val/'
test_dir = './workout/test/'
batch_size = 64
target_size = (80, 80)
# 设定批量生成器
train_datagen = ImageDataGenerator(rescale=1./255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.5,
fill_mode="nearest")
val_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
# 读取资料集+批量生成器,产生每epoch训练样本
train_generator = train_datagen.flow_from_directory(train_dir,
target_size=target_size,
batch_size=batch_size)
valid_generator = val_datagen.flow_from_directory(valid_dir,
target_size=target_size,
batch_size=batch_size)
test_generator = test_datagen.flow_from_directory(test_dir,
target_size=target_size,
batch_size=batch_size,
shuffle=False)
2.4 重新训练模型权重
# -----------------------------4.开始训练模型------------------------------
# 重新训练权重
history = model.fit_generator(train_generator,
epochs=50, verbose=1,
steps_per_epoch=train_generator.samples//batch_size,
validation_data=valid_generator,
validation_steps=valid_generator.samples//batch_size,
callbacks=[checkpoint, estop, reduce_lr])
2.5 储存模型与纪录学习曲线
# -----------------------5.储存模型、纪录学习曲线------------------------
# 储存模型
model.save('./InceptionResNetV2_retrained_v2.h5')
print('已储存InceptionResNetV2_retrained_v2.h5')
# 画出acc学习曲线
acc = history.history['accuracy']
epochs = range(1, len(acc) + 1)
val_acc = history.history['val_accuracy']
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend(loc='lower right')
plt.grid()
# 储存acc学习曲线
plt.savefig('./acc.png')
plt.show()
# 画出loss学习曲线
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend(loc='upper right')
plt.grid()
# 储存loss学习曲线
plt.savefig('loss.png')
plt.show()
模型训练结果
3.1 训练epochs:30 epochs。
3.2 耗费时间:2小时48分47秒(10127秒)。
3.3 学习曲线
3.4 Accuary与Loss
验证准确度
4.1 程序码
# -------------------------6.验证模型准确度--------------------------
# 以vali资料夹验证模型准确度
test_loss, test_acc = model.evaluate_generator(test_generator,
steps=test_generator.samples//batch_size,
verbose=1)
print('test acc:', test_acc)
print('test loss:', test_loss)
4.2 验证结果
下一章,预计和大家分享模型验证方法,并比较五个预训练模型迁移学习的成果。
让我们继续看下去...
在 Fluentd Bit 中可以使用 read 或 socket 方式处理日志 read 用於读容...
虽然不能进行人与人的连结 但我们可以进行装置与装置的连结~(^ω^)人(^ω^) 爲什麽要让你连进我...
前言 当使用者输入资料时,若不小心输入跳脱字元 Escape Character,如 \n or \...
什麽是 Lifecycle Hook? 在开始介绍之前,先来了解一下何谓 生命周期 (Lifecyc...
如今,YouTube、TikTok、Dailymotion 等在线网站越来越受欢迎,并吸引了全球大量...