Xception
1.1 来源
1.2 架构
1.3 特性
训练过程
2.1 预训练模型
2.2 设置Callbacks
2.3 设置训练集
2.4 开始训练模型
2.5 储存模型与纪录学习曲线
模型训练结果
3.1 学习曲线
3.2 Accuracy与Loss
验证模型准确度
4.1 程序码
4.2 验证结果
Xception
1.1 来源:
1.2 架构
以改良後的Extreme Inception取代InceptionV3的Inception module。(对照图如下)
Extreme Inception(Xception)
Inception module(InceptionV3)
Extreme Inception引进Depthwise separable convolution概念降低网路的复杂度,同时拓宽网路,维持接近Inception module的参数量。
Standard convolution、Depthwise separable convolution与Extreme Inception。
Depthwise separable convolution(MobileNets):将传统卷积拆分成两个步骤,在维持准确度的前提下,降低参数量与模型训练时间。
Depthwise → (b)channel卷积运算; Pointwise → (c)1x1卷积运算combining
※ 详细请参考 论文中3.1Depthwise Separable Convolution
1.3 特性
观察InceptionV3、Xception模型参数量,Xception仅略少,训练时间稍微缩短。
Xception将空间相关性与通道相关性分离,更有效率的利用参数,模型准确度较高。
训练过程:
2.1 预训练模型
# IMPORT MODULES
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.layers import Input, Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
import matplotlib.pyplot as plt
from keras.models import Model
from keras.applications import Xception
# -----------------------------1.客制化模型--------------------------------
# 载入keras模型(更换输出图片尺寸)
model = Xception(include_top=False,
weights='imagenet',
input_tensor=Input(shape=(80, 80, 3))
)
# 定义输出层
x = model.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(800, activation='softmax')(x)
model = Model(inputs=model.input, outputs=predictions)
# 编译模型
model.compile(optimizer=Adam(lr=0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
2.2 设置Callbacks
# -----------------------------2.设置callbacks-----------------------------
# 设定earlystop条件
estop = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=1)
# 设定模型储存条件
checkpoint = ModelCheckpoint('Xception_checkpoint_v2.h5', verbose=1,
monitor='val_loss', save_best_only=True,
mode='min')
# 设定lr降低条件(0.001 → 0.0005 → 0.000125 → 0.0001)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
patience=5, mode='min', verbose=1,
min_lr=1e-4)
2.3 设置训练集
# -----------------------------3.设置资料集--------------------------------
# 设定ImageDataGenerator参数(路径、批量、图片尺寸)
train_dir = './workout/train/'
valid_dir = './workout/val/'
test_dir = './workout/test/'
batch_size = 32
target_size = (80, 80)
# 设定批量生成器
train_datagen = ImageDataGenerator(rescale=1./255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.5,
fill_mode="nearest")
val_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
# 读取资料集+批量生成器,产生每epoch训练样本
train_generator = train_datagen.flow_from_directory(train_dir,
target_size=target_size,
batch_size=batch_size)
valid_generator = val_datagen.flow_from_directory(valid_dir,
target_size=target_size,
batch_size=batch_size)
test_generator = test_datagen.flow_from_directory(test_dir,
target_size=target_size,
batch_size=batch_size,
shuffle=False)
2.4 重新训练模型权重
# -----------------------------4.开始训练模型------------------------------
# 重新训练权重
history = model.fit_generator(train_generator,
epochs=50, verbose=1,
steps_per_epoch=train_generator.samples//batch_size,
validation_data=valid_generator,
validation_steps=valid_generator.samples//batch_size,
callbacks=[checkpoint, estop, reduce_lr])
2.5 储存模型与纪录学习曲线
# -----------------------5.储存模型、纪录学习曲线------------------------
# 储存模型
model.save('./Xception_retrained_v2.h5')
print('已储存Xception_retrained_v2.h5')
# 画出acc学习曲线
acc = history.history['accuracy']
epochs = range(1, len(acc) + 1)
val_acc = history.history['val_accuracy']
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend(loc='lower right')
plt.grid()
# 储存acc学习曲线
plt.savefig('./acc.png')
plt.show()
# 画出loss学习曲线
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend(loc='upper right')
plt.grid()
# 储存loss学习曲线
plt.savefig('loss.png')
plt.show()
模型训练结果
3.1 训练epochs:26 epochs。
3.2 耗费时间:3小时29分24秒(12564秒)。
3.3 学习曲线
3.4 Accuary与Loss
验证准确度
4.1 程序码
# -------------------------6.验证模型准确度--------------------------
# 以vali资料夹验证模型准确度
test_loss, test_acc = model.evaluate_generator(test_generator,
steps=test_generator.samples//batch_size,
verbose=1)
print('test acc:', test_acc)
print('test loss:', test_loss)
4.2 验证结果
下一章目标是:介绍第二个预训练模型ResNet152V2,与分享训练成果」。
让我们继续看下去...
<<: Day 14. UX/UI 设计流程之四: Wireflow,并以 Axure RP 实作 (上)
>>: Day 29-Unit Test 应用於使用重构与测试手法优化 C# Code-3 (情境及应用-9)
Dos 攻击 全称为 Denial-of-Service Attack 即「阻断服务攻击」, 亦被称...
题目简述: 一个由小到大排列的整数阵列,写一个函式回传每个元素的平方,并且也是由小到大排列 Inpu...
更改 state 有其风险,State manipulation 有赚有赔(?),更改前应详阅官方文...
前言 这篇有两个主题:公司文化与code review,而讲者特别强调的是要如何将这两件事情中间做...
day1来了 终於开始写第一天文章了,现在回头看,从被推坑到下定决心也是蛮曲折的!!! 这次友人推坑...