【20】从头自己建一个 keras 内建模型 (以 MobileNetV2 为例)

Colab连结

虽然 Tensorflow 提供了几个预训练模型让我们可以很快的完成训练任务,但是有时候想做一需实验时(比如说微调 mobilenet 的 CNN 层节点数 ),就没有简单易用的 API。因此今天需要把手弄脏,学习如何从0建构一个 mobilenetV2 出来,并且把 Tensorflow 提供的预训练权重也一并转移上去!

首先老样子,我们先产生官方版的 mobilenetV2,并把权重锁住。

base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base.trainable = False
net = tf.keras.layers.GlobalAveragePooling2D()(base.output)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)

model = tf.keras.Model(inputs=[base.input], outputs=[net])
model.summary()

接着会印出很长一串整个模型的 Layer 细节,可以让我们观摩怎麽从头开始建。

_____________________________________
Layer (type)                    Output Shape         Param #     Connected to    ===================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0    _____________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         input_1[0][0]  ____________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]  
(略)
_____________________________________
global_average_pooling2d (Globa (None, 1280)         0           out_relu[0][0]    _______________________________________
dense (Dense)                   (None, 2)            2562        global_average_pooling2d[0][0]   
=====================================
Total params: 2,260,546
Trainable params: 2,562
Non-trainable params: 2,257,984

接着,我们简单训练10个 epochs ,因为锁住权重,所以收敛很快。

model.compile(
    optimizer=tf.keras.optimizers.SGD(LR),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

history = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
    verbose=True)

产出:

loss: 0.1718 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.9364 - val_sparse_categorical_accuracy: 0.7833

接着到了我们要复刻模型的环节,经过刚刚的 summary ,我们发现 mobilenetV2 中其实有很多 block_{num}_* 的结构,这个区块就是所谓的 bottleneck 设计(Conv -> DepthwiseConv -> Conv)。

def get_mobilenetV2(shape):
    input_node = tf.keras.layers.Input(shape=shape)

    net = tf.keras.layers.Conv2D(32, 3, (2, 2), use_bias=False, padding='same')(input_node)
    net = tf.keras.layers.BatchNormalization()(net)
    net = tf.keras.layers.ReLU(max_value=6)(net)

    net = tf.keras.layers.DepthwiseConv2D(3, use_bias=False, padding='same')(net)
    net = tf.keras.layers.BatchNormalization()(net)
    net = tf.keras.layers.ReLU(max_value=6)(net)
    net = tf.keras.layers.Conv2D(16, 1, use_bias=False, padding='same')(net)
    net = tf.keras.layers.BatchNormalization()(net)

    net = bottleneck(net, 16, 24, (2, 2), shortcut=False, zero_pad=True)  # block_1
    net = bottleneck(net, 24, 24, (1, 1), shortcut=True)  # block_2

    net = bottleneck(net, 24, 32, (2, 2), shortcut=False, zero_pad=True)  # block_3
    net = bottleneck(net, 32, 32, (1, 1), shortcut=True)  # block_4
    net = bottleneck(net, 32, 32, (1, 1), shortcut=True)  # block_5

    net = bottleneck(net, 32, 64, (2, 2), shortcut=False, zero_pad=True)  # block_6
    net = bottleneck(net, 64, 64, (1, 1), shortcut=True)  # block_7
    net = bottleneck(net, 64, 64, (1, 1), shortcut=True)  # block_8
    net = bottleneck(net, 64, 64, (1, 1), shortcut=True)  # block_9

    net = bottleneck(net, 64, 96, (1, 1), shortcut=False)  # block_10
    net = bottleneck(net, 96, 96, (1, 1), shortcut=True)  # block_11
    net = bottleneck(net, 96, 96, (1, 1), shortcut=True)  # block_12

    net = bottleneck(net, 96, 160, (2, 2), shortcut=False, zero_pad=True)  # block_13
    net = bottleneck(net, 160, 160, (1, 1), shortcut=True)  # block_14
    net = bottleneck(net, 160, 160, (1, 1), shortcut=True)  # block_15

    net = bottleneck(net, 160, 320, (1, 1), shortcut=False)  # block_16

    net = tf.keras.layers.Conv2D(1280, 1, use_bias=False, padding='same')(net)
    net = tf.keras.layers.BatchNormalization()(net)
    net = tf.keras.layers.ReLU(max_value=6)(net)

    return input_node, net


def bottleneck(net, filters, out_ch, strides, shortcut=True, zero_pad=False):

    padding = 'valid' if zero_pad else 'same'
    shortcut_net = net

    net = tf.keras.layers.Conv2D(filters * 6, 1, use_bias=False, padding='same')(net)
    net = tf.keras.layers.BatchNormalization()(net)
    net = tf.keras.layers.ReLU(max_value=6)(net)
    if zero_pad:
        net = tf.keras.layers.ZeroPadding2D(padding=((0, 1), (0, 1)))(net)

    net = tf.keras.layers.DepthwiseConv2D(3, strides=strides, use_bias=False, padding=padding)(net)
    net = tf.keras.layers.BatchNormalization()(net)
    net = tf.keras.layers.ReLU(max_value=6)(net)

    net = tf.keras.layers.Conv2D(out_ch, 1, use_bias=False, padding='same')(net)
    net = tf.keras.layers.BatchNormalization()(net)

    if shortcut:
        net = tf.keras.layers.Add()([net, shortcut_net])

    return net

完成上述的结构後,我们在建立一个 rework_model 并把权重值从原先的 tf 汇到重建的 model 里。

input_node, net = get_mobilenetV2((224,224,3))
net = tf.keras.layers.GlobalAveragePooling2D()(net)
net = tf.keras.layers.Dense(NUM_OF_CLASS)(net)

rework_model = tf.keras.Model(inputs=[input_node], outputs=[net])

rework_model.compile(
    optimizer=tf.keras.optimizers.SGD(LR),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

for origin_layer, rework_layer in zip(model.layers, rework_model.layers):
  origin_layer.trainable = True
  rework_layer.set_weights(origin_layer.get_weights())

为了确定这个重建的 rework_model 的结构和权重值都和原先的一样,我们用 evaluate() 来比较两者的 loss 和 准确度是否都相同。

model.evaluate(ds_test, verbose=True)
rework_model.evaluate(ds_test, verbose=True)

产出:

32/32 [==============================] - 3s 80ms/step - loss: 0.9364 - sparse_categorical_accuracy: 0.7833
32/32 [==============================] - 4s 79ms/step - loss: 0.9364 - sparse_categorical_accuracy: 0.7833

看起来没问题,两个模型的 loss 和 sparse_categorical_accuracy都相同,我们真的从头复制了一个 mobilenetV2 !


<<:  【Day19】Git 版本控制 - 多人协作 GitHub Flow

>>:  中阶魔法 - 范围链 Scope Chain

Day10-119. Pascal's Triangle II

今日题目:119. Pascal's Triangle II Given an integer ro...

Day 15状态管理

为什麽需要状态管理? 在开发应用程序的初期,只需将状态反映在View上即可,但一旦功能变多,介面上的...

Day-25 ImageView

ImageView为显示图片, 但在图片显示前, 必须先了解如何插入图片: Step1:於资料夹选取...

【rails】新手如何建立 CRUD

在开始一个专案的时候,新手常常不知道从哪边开始 决定整理一下专案制作的流程 本文主要陈述思考的脉络,...

[影片]第28天:英雄指南-5. 新增应用内导航(3)

GitHub:https://github.com/dannypc1628/Angular-Tou...