Day 22: Model Optimization - Knowledge Distillation

What Is Knowledge Distillation

  • Knowledge distillation is a model compression technique in which a "student" model learns from a larger, more complex "teacher" model. In other words, once you have built a good model with a complex architecture, you can use knowledge distillation to train a much simpler model whose accuracy is not far behind.
  • Knowledge distillation is mainly applied to classification tasks.
  • A Colab notebook is available; this post is adapted from the official Keras example. For the theory, see the original paper.
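
  • Concretely, the student minimizes a weighted sum of a hard-label loss and a soft-target loss. With \alpha = alpha, T = temperature, \sigma = softmax, and z_s, z_t the student/teacher logits, the loss implemented in the Distiller code below is:

    L = \alpha \cdot \mathrm{CE}\big(y, \sigma(z_s)\big) + (1 - \alpha) \cdot \mathrm{KL}\big(\sigma(z_t / T) \,\|\, \sigma(z_s / T)\big)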

Implementing Knowledge Distillation

  • This example is implemented entirely in tf.keras. The process consists of:
    1. Defining a custom Distiller class.
    2. Training a CNN teacher model.
    3. Having the student model learn from the teacher.
    4. Training a student_scratch model that never learns from the teacher, for comparison.

Preparing the Data
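
  • The data-loading code is not shown in this post. Below is a minimal sketch, assuming the MNIST dataset (which matches the 28x28 input shape and 10-class output used later) plus the ACCURACY dict that the later snippets record results into; the original preprocessing may differ:

    import tensorflow as tf
    from tensorflow import keras

    # Load MNIST and scale pixel values to [0, 1].
    (train_images, train_labels), (test_images, test_labels) = \
        keras.datasets.mnist.load_data()
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    # Collect each model's test accuracy for the final comparison.
    ACCURACY = {}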

Building the Distiller Class

  • This post uses the Distiller class defined in the official Keras example.

  • The class subclasses tf.keras.Model and overrides the following methods:

    • compile: the model needs a few extra arguments to compile, namely the student and distillation loss functions, alpha, and temperature.
    • train_step: controls how the model trains; this is where the actual knowledge-distillation logic lives. It is the method invoked when you call model.fit.
    • test_step: controls how the model is evaluated. It is the method invoked when you call model.evaluate.
    class Distiller(keras.Model):
        def __init__(self, student, teacher):
            super(Distiller, self).__init__()
            self.teacher = teacher
            self.student = student
    
        def compile(
            self,
            optimizer,
            metrics,
            student_loss_fn,
            distillation_loss_fn,
            alpha=0.1,
            temperature=3,
            ):
            """ Configure the distiller.
            Args:
                optimizer: Keras optimizer for the student weights.
                metrics: Keras metrics for evaluation.
                student_loss_fn: Loss function of difference between student
                    predictions and ground-truth.
                distillation_loss_fn: Loss function of difference between soft
                    student predictions and soft teacher predictions.
                alpha: weight to student_loss_fn and 1-alpha to 
                    distillation_loss_fn.
                temperature: Temperature for softening probability 
                    distributions.
                    Larger temperature gives softer distributions.
            """
            super(Distiller, self).compile(
                optimizer=optimizer, 
                metrics=metrics
                )
            self.student_loss_fn = student_loss_fn
            self.distillation_loss_fn = distillation_loss_fn
            self.alpha = alpha
            self.temperature = temperature
    
        def train_step(self, data):
            # Unpack data
            x, y = data
    
            # Forward pass of teacher
            teacher_predictions = self.teacher(x, training=False)
    
            with tf.GradientTape() as tape:
                # Forward pass of student
                student_predictions = self.student(x, training=True)
    
                # Compute losses
                student_loss = self.student_loss_fn(y, student_predictions)
                distillation_loss = self.distillation_loss_fn(
                    tf.nn.softmax(
                        teacher_predictions / self.temperature, axis=1
                        ),
                    tf.nn.softmax(
                        student_predictions / self.temperature, axis=1
                        )
                    )
                loss = self.alpha * student_loss + (
                    1 - self.alpha) * distillation_loss
    
            # Compute gradients
            trainable_vars = self.student.trainable_variables
            gradients = tape.gradient(loss, trainable_vars)
    
            # Update weights
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    
            # Update the metrics configured in `compile()`.
            self.compiled_metrics.update_state(y, student_predictions)
    
            # Return a dict of performance
            results = {m.name: m.result() for m in self.metrics}
            results.update(
                {"student_loss": student_loss, 
                 "distillation_loss": distillation_loss}
            )
            return results
    
        def test_step(self, data):
            # Unpack the data
            x, y = data
    
            # Compute predictions
            y_prediction = self.student(x, training=False)
    
            # Calculate the loss
            student_loss = self.student_loss_fn(y, y_prediction)
    
            # Update the metrics.
            self.compiled_metrics.update_state(y, y_prediction)
    
            # Return a dict of performance
            results = {m.name: m.result() for m in self.metrics}
            results.update({"student_loss": student_loss})
            return results
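
  • To build intuition for the temperature argument, here is a quick standalone sketch (with made-up logits, not taken from the models in this post) showing how dividing logits by T softens the softmax distribution:

    import tensorflow as tf

    logits = tf.constant([[5.0, 2.0, 0.5]])
    # Plain softmax (T=1): sharply peaked on the first class.
    print(tf.nn.softmax(logits, axis=1).numpy())        # ~[0.94, 0.05, 0.01]
    # T=3 flattens the distribution, exposing the relative probabilities
    # the teacher assigns to the non-target classes.
    print(tf.nn.softmax(logits / 3.0, axis=1).numpy())  # ~[0.63, 0.23, 0.14]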
    

Building the Teacher and Student Models

  • Two reminders:

    • The final layer has no softmax activation, because knowledge distillation needs the raw logit distribution; remember to leave it out.
    • Regularization via dropout layers is applied to the teacher rather than the student, because the student should pick up this regularization through the distillation process.
  • You can think of the student model as a simplified (or compressed) version of the teacher model.

    def big_model_builder():
      keras = tf.keras
      model = keras.Sequential([
        keras.layers.InputLayer(input_shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(
            filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(
            filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(10)
      ])
      return model
    
    def small_model_builder():
      keras = tf.keras
      model = keras.Sequential([
        keras.layers.InputLayer(input_shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(
            filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(10)
      ])
      return model
    
    teacher = big_model_builder()
    student = small_model_builder()
    student_scratch = small_model_builder()
    

Training the Teacher

  • As always, train the original/teacher model; nothing surprising here.
    # Train teacher as usual
    teacher.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    teacher.summary()
    
    # Train and evaluate teacher on data.
    teacher.fit(train_images, train_labels, epochs=2)
    _, ACCURACY['teacher model'] = teacher.evaluate(test_images, test_labels)
    

Training the Student via Knowledge Distillation

  • Create an instance of the Distiller class, passing in the student and teacher models: distiller = Distiller(student=student, teacher=teacher). Then compile it with suitable arguments and train.
    # Initialize and compile distiller
    distiller = Distiller(student=student, teacher=teacher)
    distiller.compile(
        optimizer=keras.optimizers.Adam(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
        student_loss_fn=keras.losses.SparseCategoricalCrossentropy(
            from_logits=True),
        distillation_loss_fn=keras.losses.KLDivergence(),
        alpha=0.1,
        temperature=10,
    )
    
    # Distill teacher to student
    distiller.fit(
        train_images, 
        train_labels, 
        epochs=2, 
        shuffle=False
        )
    
    # Evaluate student on test dataset
    ACCURACY['distiller student model'], _ = distiller.evaluate(
        test_images, test_labels)
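
  • A design note: Hinton et al. recommend scaling the soft-target term by T^2 so that its gradient magnitude stays comparable to the hard-label term as the temperature grows. The Distiller above omits this factor; a hedged variant of the loss line in train_step would be:

    # Optional T^2 scaling of the distillation term (per Hinton et al.).
    loss = self.alpha * student_loss + (
        1 - self.alpha) * distillation_loss * self.temperature ** 2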
    
    

Comparing Models - Training the Student from Scratch

  • student_scratch is an ordinary model trained entirely on its own, without the knowledge-distillation process; it shares its architecture with student and serves as a baseline for comparison.
    # Train the student as usual
    student_scratch.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
        )
    student_scratch.summary()
    
    # Train and evaluate student trained from scratch.
    student_scratch.fit(
        train_images, 
        train_labels, 
        epochs=2, 
        shuffle=False
        )
    _, ACCURACY['student from scratch model'] = student_scratch.evaluate(
        test_images, 
        test_labels
        )
    

Comparing Model Accuracy

  • The final model accuracies come out to roughly:
    ACCURACY
    {'teacher model': 0.9822999835014343,
     'distiller student model': 0.9729999899864197,
     'student from scratch model': 0.9697999954223633}
    
  • The teacher's accuracy should normally be the highest; it is, after all, the model you invested the most effort in.
  • A student trained via knowledge distillation usually outperforms a student trained from scratch.
  • Even though the student model is simpler, distillation can sometimes let it surpass the teacher, and the model is also more lightweight.

Summary

  • When you run into a giant model (e.g., GPT-3), compute resources may not allow you to deploy it easily. Knowledge distillation, letting a "student" learn from a "teacher", at least makes good results easier to obtain than having the student learn on its own.
  • Since the official Keras example takes quite a while to run on Colab, I also rewrote a version that reaches results faster.
  • Having covered automated model building and model optimization back to back, I hope you feel more confident about putting models into production. How to monitor and observe models matters just as much; see you in the next post.
