【第9天】训练模型-迁移学习

摘要

  1. 迁移学习说明
  2. 迁移学习类型
  3. 浅谈预训练与微调
  4. 如何进行迁移学习

内容

  1. 说明:基於资料集(ImageNet分类包括蛇、蜥蜴)、任务(皆为图片分类)相似性,将预训练的模型应用在新资料集的学习过程。通常用来解决以下问题:

    1.1 处理大量未标记资料(如:利用预训练的vgg16模型,辨识猫狗照片,并进行标签)

    1.2 降低大量资料的训练成本:大量或性质相似的资料集,以迁移学习提高训练效率。(节省时间、硬体资源)

    1.3 医疗应用需求(如:肾脏病变切片样本或有标记的样本稀少)

  2. 迁移学习类型

    2.1 基於实例:以ImageNet常见的1000种分类为例。

    • 1000种分类中有鸟、蜥蜴、猴子、蛇...等动物(domain)。
    • 若任务是辨识图片中是蛇与蜥蜴(task),可手动调整蛇和蜥蜴的权重。(依照经验调整)

    2.2 基於特徵:

    • 特徵萃取:若任务是分辨10种不同的蛇(task),先以预训练模型(domain)对10种蛇做特徵萃取,再将特徵喂入自己定义的神经网路训练。
    # 特徵萃取
    def feature_extraction_InV3(img_width, img_height,
                         train_data_dir,
                         num_image,
                         epochs):
        base_model = InceptionV3(input_shape=(299, 299, 3),
                              weights='imagenet', include_top=False)
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
    
        model = Model(inputs=base_model.input, outputs=x)
    
        train_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory(train_data_dir,
        target_size=(299, 299),
        # 每次进来训练并更新学习率的图片数 -> if 出现 memory leak -> 调低此参数
        batch_size=18,
        class_mode="categorical",
        shuffle=False)
    
        y_train=train_generator.classes
        # 依据class数量而定, np.zeros -> 宣告全部为0的空阵列
        y_train1 = np.zeros((num_image, 5))
        # np.arrange打标签
        y_train1[np.arange(num_image), y_train] = 1
    
        # 重设generator
        train_generator.reset
        X_train=model.predict_generator(train_generator, verbose=1)
        print(X_train.shape, y_train1.shape)
        return X_train, y_train1, model
    
    # 自定义全连接层
    def train_last_layer(img_width, img_height,
                         train_data_dir,
                         num_image,
                         epochs):
        # 处理train资料夹
        X_train, y_train, model=feature_extraction_InV3(img_width, img_height,
                             train_data_dir,
                             num_image,
                             epochs)
    
        # 处理test资料夹
        X_test,y_test,model=feature_extraction_InV3(img_width,img_height,
                             test_data_dir,
                             num_test_image,
                             epochs)
    
        my_model = Sequential()
        my_model.add(BatchNormalization(input_shape=X_train.shape[1:]))
        my_model.add(Dense(1024, activation="relu"))
        my_model.add(Dense(5, activation='softmax'))
        my_model.compile(optimizer="SGD", loss='categorical_crossentropy',metrics=['accuracy'])
        print(my_model.summary())
    
        history = my_model.fit(X_train, y_train, epochs=20,
                  validation_data=(X_test, y_test),
                  batch_size=30, verbose=1)
        my_model.save('model_CnnModelTrainWorkout_v3_5calsses.h5')
        return history
    

    2.3 基於模型:task与domain参数共享。如:预训练模型仅修改输出层(分2类),载入权重进行参数迁移学习。

    2.4 基於关系(参考许多文章与书籍,这个类别还是无法理解。欢迎有研究的夥伴,留言分享。)

  3. 预训练与微调

    3.1 预训练:训练模型时,从一开始的随机初始化参数,到随着训练调整参数,完成训练并储存参数的过程。

    3.2 微调:将预训练获得的参数,作为新资料集训练模型的初始参数,训练後获得适应新资料集的模型。

  4. 迁移学习过程

    4.1 挑选预训练模型:以tensorflow框架为例,可至Keras Application,挑选准确度高、轻量化的预训练模型,逐一训练、比较。

    4.2 选择训练方式

    • <资料集大,目标域(task)相似於来源域(domain)>

      载入模型结构,将预训练权重当作初始化参数,冻结底层卷积层,仅训练部分卷积层与顶端层。(部分冻结)

    • <资料集大,目标域(task)不同於来源域(domain)>

      载入模型结构,将预训练权重当作初始化参数,以新资料集全部重新训练。(不冻结)

    • <资料集小,目标域(task)相似於来源域(domain)>

      载入模型结构与权重,仅修改输出层(如:分1000类改成3类)或整个全连接层,再进行模型训练。

    • <资料集小,目标域(task)不同於来源域(domain)>

      以资料扩增增加训练样本,後续步骤同<资料集大,目标域(task)不同於来源域(domain)>。

    ※注:通常每个类别的资料数量小於1000笔,视为小资料集。


小结

  1. 资料前处理後,新资料集约有19.3万张。其中,每个中文字约有80-300张图档,且中文字不在1000个类别内。故属於「资料集小,目标域不同於来源域」。
  2. 下一章,目标是:「介绍Tensorflow Keras Application,并挑选预训练模型」。

让我们继续看下去...


参考资料

  1. 深度学习不得不会的迁移学习Transfer Learning
  2. 迁移学习(Transfer),面试看这些就够了!
  3. 第十一章 迁移学习

<<:  Day9-滚动视差(下)_後有图样

>>:  [Day 24] Node Event loop 3

Day 24: 人工智慧在音乐领域的应用 (AI作曲- RNN作曲)

循环神经网路(Recurrent neural network, RNN) 首先我们先来介绍循环神经...

Day3 什麽是Git?

大家好,我是乌木白,今天我们开始讲我们这次铁人赛的第一个技能,就是Git啦!先和大家声明我是把我自...

[Day 13] Sass - Maps

Maps 今天要来介绍的是在Sass内非常重要而且常用的一个功能 - Maps 之前有提到Maps是...

Day13 React- Forms(1)

小实作React Form 和Event的应用,使用useState Hook,让input输入的值...

Day20:安全性和演算法-杂凑函数(hash function)

安全性与演算法 在电脑科学的领域里,每一刻都有数以万计的资料在进行传输,在传输的过程中,是真的安全吗...