DenseNet201
1.1 来源
1.2 架构
1.3 特性
训练过程
2.1 预训练模型
2.2 设置Callbacks
2.3 设置训练集
2.4 开始训练模型
2.5 储存模型与纪录学习曲线
模型训练结果
3.1 学习曲线
3.2 Accuracy与Loss
验证模型准确度
4.1 程序码
4.2 验证结果
DenseNet201
1.1 来源
1.2 架构
ResNet的启发:
Inception的启发:
DenseNet:详细说明请参阅 论文Page3 3.DenseNets
1.3 特性
训练过程:
2.1 预训练模型
# IMPORT MODULES
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.layers import Input, Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
import matplotlib.pyplot as plt
from keras.models import Model
from keras.applications import DenseNet201
# -----------------------------1.客制化模型--------------------------------
# 载入keras模型(更换输出图片尺寸)
model = Densenet201(include_top=False,
weights='imagenet',
input_tensor=Input(shape=(80, 80, 3))
)
# 定义输出层
x = model.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(800, activation='softmax')(x)
model = Model(inputs=model.input, outputs=predictions)
# 编译模型
model.compile(optimizer=Adam(lr=0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])
2.2 设置Callbacks
# ---------------------------2.设置callbacks----------------------------
# 设定earlystop条件
estop = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=1)
# 设定模型储存条件
checkpoint = ModelCheckpoint('Densenet201_checkpoint_v2.h5', verbose=1,
monitor='val_loss', save_best_only=True,
mode='min')
# 设定lr降低条件(0.001 → 0.0002 → 0.00004 → 0.00001)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
patience=5, mode='min', verbose=1,
min_lr=1e-4)
2.3 设置训练集
# -----------------------------3.设置资料集--------------------------------
# 设定ImageDataGenerator参数(路径、批量、图片尺寸)
train_dir = './workout/train/'
valid_dir = './workout/val/'
test_dir = './workout/test/'
batch_size = 64
target_size = (80, 80)
# 设定批量生成器
train_datagen = ImageDataGenerator(rescale=1./255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.5,
fill_mode="nearest")
val_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
# 读取资料集+批量生成器,产生每epoch训练样本
train_generator = train_datagen.flow_from_directory(train_dir,
target_size=target_size,
batch_size=batch_size)
valid_generator = val_datagen.flow_from_directory(valid_dir,
target_size=target_size,
batch_size=batch_size)
test_generator = test_datagen.flow_from_directory(test_dir,
target_size=target_size,
batch_size=batch_size,
shuffle=False)
2.4 重新训练模型权重
# -----------------------------4.开始训练模型------------------------------
# 重新训练权重
history = model.fit_generator(train_generator,
epochs=50, verbose=1,
steps_per_epoch=train_generator.samples//batch_size,
validation_data=valid_generator,
validation_steps=valid_generator.samples//batch_size,
callbacks=[checkpoint, estop, reduce_lr])
2.5 储存模型与纪录学习曲线
# -----------------------5.储存模型、纪录学习曲线------------------------
# 储存模型
model.save('./Densenet201_retrained_v2.h5')
print('已储存Densenet201_retrained_v2.h5')
# 画出acc学习曲线
acc = history.history['accuracy']
epochs = range(1, len(acc) + 1)
val_acc = history.history['val_accuracy']
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend(loc='lower right')
plt.grid()
# 储存acc学习曲线
plt.savefig('./acc.png')
plt.show()
# 画出loss学习曲线
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend(loc='upper right')
plt.grid()
# 储存loss学习曲线
plt.savefig('loss.png')
plt.show()
模型训练结果
3.1 训练epochs:25 epochs。
3.2 耗费时间:2小时22分23秒(8543秒)。
3.3 学习曲线
3.4 Accuary与Loss
验证准确度
4.1 程序码
# -------------------------6.验证模型准确度--------------------------
# 以vali资料夹验证模型准确度
test_loss, test_acc = model.evaluate_generator(test_generator,
steps=test_generator.samples//batch_size,
verbose=1)
print('test acc:', test_acc)
print('test loss:', test_loss)
4.2 验证结果
<<: 如何使用cython来打包程序码成pyd格式 (就是DLL档的意思)
>>: Day 16 - Asynchronous 非同步进化顺序 - Async/Await
前言 Python 是一种直译式语言,近几年在资料科学中 (例如:人工智慧、大数据分析 等) 有着耀...
一、变数 JavaScript 七种型态 Primitive type null undefine ...
这一个章节节我们要来介绍复合查询,当单一的查询子句无法完成需求时,为了应付这种高级查询需求,所以就产...
Coroutine 中如果要执行非同步程序,则需要把耗时任务写在 suspend 函式中,并且在一个...
-侧信道攻击 侧信道攻击(Side-channel attack) 只需在设备或系统附近放置天线、...