【第14天】训练模型-Xception

摘要

  1. Xception

    1.1 来源
    1.2 架构
    1.3 特性

  2. 训练过程

    2.1 预训练模型
    2.2 设置Callbacks
    2.3 设置训练集
    2.4 开始训练模型
    2.5 储存模型与纪录学习曲线

  3. 模型训练结果

    3.1 学习曲线
    3.2 Accuracy与Loss

  4. 验证模型准确度

    4.1 程序码
    4.2 验证结果


内容

  1. Xception

    1.1 来源:

    • 简介:改良InceptionV3的Inception module,并引入depthwise separable convolution概念。
    • 时程:於2016年提出论文,并收录於2017年的CVPR。
    • 论文名称:Xception:Deep Learning with Depthwise Separable Convolutions

    1.2 架构

    • 以改良後的Extreme Inception取代InceptionV3的Inception module。(对照图如下)

      • Extreme Inception(Xception)

      • Inception module(InceptionV3)

    • Extreme Inception引进Depthwise separable convolution概念降低网路的复杂度,同时拓宽网路,维持接近Inception module的参数量。

    • Standard convolution、Depthwise separable convolution与Extreme Inception。

      • Standard convolution(InceptionV3)

      • Depthwise separable convolution(MobileNets):将传统卷积拆分成两个步骤,在维持准确度的前提下,降低参数量与模型训练时间。


        Depthwise → (b)channel卷积运算; Pointwise → (c)1x1卷积运算combining

      ※ 详细请参考 论文中3.1Depthwise Separable Convolution

      • Extreme Inception(Xception):类似Depthwise separable convolution,只是两者卷积运算的顺序相反。

    1.3 特性

    • 观察InceptionV3、Xception模型参数量,Xception仅略少,训练时间稍微缩短。

    • Xception将空间相关性与通道相关性分离,更有效率的利用参数,模型准确度较高。

  2. 训练过程:

    2.1 预训练模型

    # IMPORT MODULES
    from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
    from keras.layers import Input, Dense, GlobalAveragePooling2D
    from keras.preprocessing.image import ImageDataGenerator
    from keras.optimizers import Adam
    import matplotlib.pyplot as plt
    from keras.models import Model
    from keras.applications import Xception
    
    # -----------------------------1.客制化模型--------------------------------
    # 载入keras模型(更换输出图片尺寸)
    model = Xception(include_top=False,
                     weights='imagenet',
                     input_tensor=Input(shape=(80, 80, 3))
                     )
    
    # 定义输出层
    x = model.output
    x = GlobalAveragePooling2D()(x)
    predictions = Dense(800, activation='softmax')(x)
    model = Model(inputs=model.input, outputs=predictions)
    
    # 编译模型
    model.compile(optimizer=Adam(lr=0.001),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    

    2.2 设置Callbacks

    # -----------------------------2.设置callbacks-----------------------------
    # 设定earlystop条件
    estop = EarlyStopping(monitor='val_loss', patience=10, mode='min', verbose=1)
    
    # 设定模型储存条件
    checkpoint = ModelCheckpoint('Xception_checkpoint_v2.h5', verbose=1,
                              monitor='val_loss', save_best_only=True,
                              mode='min')
    
    # 设定lr降低条件(0.001 → 0.0005 → 0.000125 → 0.0001)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                               patience=5, mode='min', verbose=1,
                               min_lr=1e-4)
    

    2.3 设置训练集

    # -----------------------------3.设置资料集--------------------------------
    # 设定ImageDataGenerator参数(路径、批量、图片尺寸)
    train_dir = './workout/train/'
    valid_dir = './workout/val/'
    test_dir = './workout/test/'
    batch_size = 32
    target_size = (80, 80)
    
    # 设定批量生成器
    train_datagen = ImageDataGenerator(rescale=1./255, 
                                       rotation_range=20,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       shear_range=0.2, 
                                       zoom_range=0.5,
                                       fill_mode="nearest")
    
    val_datagen = ImageDataGenerator(rescale=1./255)
    
    test_datagen = ImageDataGenerator(rescale=1./255)
    
    # 读取资料集+批量生成器,产生每epoch训练样本
    train_generator = train_datagen.flow_from_directory(train_dir,
                                          target_size=target_size,
                                          batch_size=batch_size)
    
    valid_generator = val_datagen.flow_from_directory(valid_dir,
                                          target_size=target_size,
                                          batch_size=batch_size)
    
    test_generator = test_datagen.flow_from_directory(test_dir,
                                          target_size=target_size,
                                          batch_size=batch_size,
                                          shuffle=False)
    

    2.4 重新训练模型权重

    # -----------------------------4.开始训练模型------------------------------
    # 重新训练权重
    history = model.fit_generator(train_generator,
                       epochs=50, verbose=1,
                       steps_per_epoch=train_generator.samples//batch_size,
                       validation_data=valid_generator,
                       validation_steps=valid_generator.samples//batch_size,
                       callbacks=[checkpoint, estop, reduce_lr])
    

    2.5 储存模型与纪录学习曲线

    # -----------------------5.储存模型、纪录学习曲线------------------------
    # 储存模型
    model.save('./Xception_retrained_v2.h5')
    print('已储存Xception_retrained_v2.h5')
    
    # 画出acc学习曲线
    acc = history.history['accuracy']
    epochs = range(1, len(acc) + 1)
    val_acc = history.history['val_accuracy']
    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'r', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend(loc='lower right')
    plt.grid()
    # 储存acc学习曲线
    plt.savefig('./acc.png')
    plt.show()
    
    # 画出loss学习曲线
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    plt.plot(epochs, loss, 'bo', label='Training loss')
    plt.plot(epochs, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend(loc='upper right')
    plt.grid()
    # 储存loss学习曲线
    plt.savefig('loss.png')
    plt.show()
    
  3. 模型训练结果

    3.1 训练epochs:26 epochs。

    3.2 耗费时间:3小时29分24秒(12564秒)。

    3.3 学习曲线

    3.4 Accuary与Loss

  4. 验证准确度

    4.1 程序码

    # -------------------------6.验证模型准确度--------------------------
    # 以vali资料夹验证模型准确度
    test_loss, test_acc = model.evaluate_generator(test_generator,
                                steps=test_generator.samples//batch_size,
                                verbose=1)
    print('test acc:', test_acc)
    print('test loss:', test_loss)
    

    4.2 验证结果


小结

下一章目标是:介绍第二个预训练模型ResNet152V2,与分享训练成果」。

让我们继续看下去...


参考资料

  1. Xception: Deep Learning with Depthwise Separable Convolutions
  2. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications

<<:  Day 14. UX/UI 设计流程之四: Wireflow,并以 Axure RP 实作 (上)

>>:  Day 29-Unit Test 应用於使用重构与测试手法优化 C# Code-3 (情境及应用-9)

Day12【Web】网路攻击:DoS 与 DDoS

Dos 攻击 全称为 Denial-of-Service Attack 即「阻断服务攻击」, 亦被称...

Day 07 : Squares of a Sorted Array

题目简述: 一个由小到大排列的整数阵列,写一个函式回传每个元素的平方,并且也是由小到大排列 Inpu...

Day 18-更改 state 有其风险,State manipulation 有赚有赔(?),更改前应详阅官方文件说明书

更改 state 有其风险,State manipulation 有赚有赔(?),更改前应详阅官方文...

16. 从Code review体现公司文化

前言 这篇有两个主题:公司文化与code review,而讲者特别强调的是要如何将这两件事情中间做...

[day1]永丰Vue一下-从生活寻找灵感

day1来了 终於开始写第一天文章了,现在回头看,从被推坑到下定决心也是蛮曲折的!!! 这次友人推坑...