[Day 11] 从零开始的 DenseNet 生活

0. 进度条

模型 进度
VGG Net (完成)
ResNet (完成)
DensNet (此篇)
MobileNet (未完成)
EfficientNet (未完成)

0.1 前言

大家都知道模型越深就越有潜力达到更高的准确率。
但是随着深度增加,梯度消失的问题就更严重,
ResNet的设计就是用来减缓这个问题,
而DenseNet就是在ResNet的基础上去做改良。


1. DenseNet 网路架构

01
DenseNet 全名为 Densely Connected Convolutional Network。
由许多个Dense Block组成,每个Block皆采用bottleneck结构。
而Dense Block用了许多「层与层的连结」来达到特徵重用性。

其论文开门见山地说:

Recent work has shown that convolutional networks can be
substantially deeper, more accurate, and efficient to train
if they contain shorter connections
between layers close to the input and those close to the output.

它有最多层与层的连结

这个shorter connections就是指ResNet中的直连通路,
ResNet只是在block的最顶端拉一条线出来,
做identity connection传到block的最底端。
但是DenseNet更狠!直接让每一层的输出传递到之後的每一层,
相当於是做了超多次identity connection。
假如我们有L层卷积神经网路,那就有https://chart.googleapis.com/chart?cht=tx&chl=L个(层与层之间的)连结。
但是DenseNet设计成有https://chart.googleapis.com/chart?cht=tx&chl=L%2F(L2%2B1)个连结。
如下图所示:
02
写成数学式子就长这样:
https://chart.googleapis.com/chart?cht=tx&chl=X_l%20%3D%20H(%5BX_0%2CX_1%2CX_2%2C...%2CX_(l-1)%5D)

要升维思考,不要同流合污

这个所谓「拉一条线出来做连结」其实就是在避免模型太「宽」。
我们可以发现当年的模型基本上都在追求又宽又深(又大又圆),像是GoogLeNet。
但是DenseNet不屑透过加深和加宽网路来增强图像表徵能力,
DenseNet利用"特徵重用性"来做到这件事。

Growth-rate

这个专有名词是指卷积层中卷积核的数量(k),
在DenseNet的实验中发现设k=12也能够获得和其他state-of-the-art一样的效果。
所以DenseNet是可以很「窄」(narrow)的。

1.1 DenseNet优点

  1. 更加减缓梯度消失问题:
    由於Densely connected,反向梯度传播十分容易,模型收敛效果佳。

  2. 特徵重用性:
    从低阶特徵到高阶特徵都会被直连到最後一层卷积层,
    这让下一层接受到更全面的图像资讯。

  3. 减少参数量:
    这很反直觉,这麽多连结却减少参数量?
    那是因为对於旧的特徵图(feature-map)是不需要再去重新学习的,
    而且因为growth-rate不用设很大,所以减少许多参数。

1.2 DenseNet缺点

  1. 对应优点1,反向传播虽然容易,但是计算复杂。
  2. 对应优点2和3,特徵重用这件事的坏处就是训练模型时RAM会爆炸。

2. 结语

看起来DenseNet稍微减缓了人们对於模型宽度的追求,
但不得不说,每个人在调整模型架构的时候都是先往加深加宽的方向。
而没有去考虑用更有效的层与层之间的连接方式。

我想我们应该要去思考:
有没有一种模型架构是可以容许我们用少少的深度和宽度,
就可以获得跟DensNet一样的效果呢?


3. 附录

因为Dense Block的结构很简单,
所以我练习用keras实现它,
如果我写错了...还请大家帮忙debug/images/emoticon/emoticon25.gif

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, BatchNormalization, ReLU
from tensorflow.keras.layers import Concatenate, Dropout
from tensorflow.keras.layers import Conv2D, AveragePooling2D, MaxPooling2D
from tensorflow.keras.layers import GlobalAveragePooling2D


def DenseLayer(x, growthRate, dropRate=0):

    # bottleneck
    x = BatchNormalization(axis=3)(x)
    x = ReLU(x)
    x = Conv2D(4*growthRate, kernel_size=(1, 1), padding='same')(x)

    # composition
    x = BatchNormalization(axis=3)(x)
    x = ReLU(x)
    x = Conv2D(growthRate, kernel_size=(3, 3), padding='same')(x)

    # dropout
    x = Dropout(dropRate)(x)

    return x


def DenseBlock(x, num_layers, growthRate, dropRate=0):

    for i in range(num_layers):
        featureMap = DenseLayer(x, growthRate, dropRate)
        x = Concatenate([x, featureMap], axis=3)

    return x


def TransitionLayer(x, ratio):

    growthRate = int(x.shape[-1]*ratio)
    x = BatchNormalization(axis=3)(x)
    x = ReLU(x)
    x = Conv2D(growthRate, kernel_size=(1, 1),
               strides=(2, 2), padding='same')(x)
    x = AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)

    return x


def DenseNet121(numClass=1000, inputShape=(224, 224, 3), growthRate=12):
    x_in = Input(inputShape)
    x = Conv2D(growthRate*2, (3, 3), padding='same')(x_in)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)

    x = TransitionLayer(DenseBlock(
        x, num_layers=6, growthRate=12, dropRate=0.2))
    x = TransitionLayer(DenseBlock(
        x, num_layers=12, growthRate=12, dropRate=0.2))
    x = TransitionLayer(DenseBlock(
        x, num_layers=24, growthRate=12, dropRate=0.2))
        
    x = DenseBlock(x, num_layers=16, growthRate=12, dropRate=0.2)
    x = GlobalAveragePooling2D()(x)
    x_out = Dense(numClass=1000, activation='softmax')(x)

    model = Model(x_in, x_out)
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.CategoricalCrossentropy(), metrics=['accuracy'])
    model.summary()
    return model

参考资料

  1. https://arxiv.org/abs/1608.06993

<<:  【第十二天 - Flutter NetWork 网路判断】

>>:  不容小觑的数据分析工具 - Excel:基础函数介绍

DAY 13:UML Class diagrams,在抽象世界的具现化宝石

在 DAY 1 ~ DAY 12 已经介绍了我认知常见的 concurrency patterns,...

【30天Lua重拾笔记32】进阶议题: LuaRocks & LuaDist

同步发表於个人网站 LuaRocks LuaRocks是类似npm、pip这样的套件管理工具,你可...

第 2 集:认识 Bootstrap 5 世界

此篇会分享 Bootstrap 5 环境设置,示范 VSCode、CodePen 两种不同环境的设置...

【从零开始的 C 语言笔记】第八篇-printf 介绍与应用

不怎麽重要的前言 上一篇我们介绍了与输入输出格式相关的语法,想必大家应该多少知道要怎麽使用了,如果有...

卡夫卡的藏书阁【Book11】- Kafka Connect 2

Step3. 新增 Source connector 可以查看一下当前的 connector $ c...