【第20天】训练模型-模型组合与辨识isnull(一)

摘要

  1. 作业流程
  2. 获得各模型800字机率表
  3. 安装R与RStudio

内容

  1. 作业流程(今日进度为1.1~1.2)

    1.1 获得各模型800字机率表。(包括官方800字内、官方800字外与测试赛资料集)

    1.2 安装R与RStudio。

    1.3 设定资料集路径

    1.4 找出每个中文字的阈值。(如何选择isnull的阈值)

    • <方法一>单一中文字机率最小值
    • <方法二>单一中文字机率平均值

    1.5 任意选择奇数个模型组合後,产生模型权重表与利用新模型权重得到的机率表。(如何选择加权依据)

    • <方法一>各模型中,单独一个字的辨识准确度占比(如:官方800字内的「白」有100个样本,正确辨识80个,「白」的辨识准确度就是80%)
    • <方法二>各模型中,单独计算800个中文字机率占比(sofmax输出机率)

    1.6 判断isnull。(如何选择判断isnull的依据)

    • <方法一>多模型投票,取多数者(故1.2选择奇数个模型随机组合)
    • <方法二>模型加权判断

    1.7 交叉验证不同方法组合的模型准确率。(共 2 * 2 * 2 = 8种)

  2. 获得各模型800字机率表

    2.1 预测样本

    • 官方800字内:train资料集(174,808 )+val资料集(18,717张),合计193,525张图档。
    • 官方800字外:723张图档。
    • 测试赛样本:90张图档。

    2.2 程序码

    from tensorflow.keras.preprocessing import image
    import numpy as np
    import os
    from tensorflow.keras.models import load_model
    import time
    import csv
    
    # 读取图档,并进行影像前处理
    def read_image(img_path):
        try:
            img = image.load_img(img_path, target_size=(80, 80))
        except Exception as e:
            print(img_path, e)
    
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        return img/255
    
    # 返回最大机率值的中文字的标签
    def to_word(pred:np.array)->str:
        index = np.argmax(pred)
        return labels[index]
    
    # 定义3个栏位值(预测值、实际值、是否正确预测)
    def append_pre_true(arr,pre,true,predict_true_or_not)->list:
        arr = list(arr)
        arr.append(pre)
        arr.append(true)
        arr.append(predict_true_or_not)
        return arr
    
    # 若预测结果正确,在predict_true_or_not填入1;反之团入0
    def predict_true_or_not(pre:str,true:str)->bool:
        if pre == true:
            return 1
        else:
            return 0
    
    # 定义栏位名称
    def columns(labels:list):
    
        for i in ['pre', 'label', 'true_or_not']:
            labels.append(i)
        return labels
    
    # 预测及赋值
    def model_predict(model,img):
        pred = model.predict(img)[0]
        word = to_word(pred)
        true_or_not = predict_true_or_not(word, subfolder)
        return pred, word, true_or_not
    
    if __name__ == "__main__":
        # 读取800字标签
        labels = os.listdir('./data/train')
        print(labels)
    
        # 载入训练好的模型
        model_path = './model/densenet201_v2/35_最佳/Densenet201_checkpoint_v2.h5'
        model = load_model(model_path)
        #存放档案的资料夹
        folder_name = './data/123/'
        #csv档名称
        csv_name = "Densenet201__retrained_v2_6K.csv"
    
        # 计时起点
        start = time.time()
    
        # 将800个字的预测机率,储存到CSV档
        with open(csv_name, "w", newline="", encoding="utf_8_sig") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(columns(labels))  #先写入栏位名称
            for subfolder in os.listdir(folder_name):
                for jpg in os.listdir(folder_name+subfolder):
                    # 请输入你的图片path
                    img_path = folder_name+subfolder+'/'+jpg
                    img = read_image(img_path)
                    #1*800向量,预测字,是否猜中
                    pred, word, true_or_not = model_predict(model,img)
                    #变成一列 : 1*800向量,预测字,真实字,是否猜中
                    row = append_pre_true(pred, word, subfolder, true_or_not)
                    writer.writerow(row)
                    print(word, subfolder, true_or_not)
    
        #计时终点
        end = time.time()
        spend = end - start
        hour = spend // 3600
        minu = (spend - 3600 * hour) // 60
        sec = spend - 3600 * hour - 60 * minu
        print(f'一共花费了{hour}小时{minu}分钟{sec}秒')
    

    2.3 输出800字机率表

  3. 安装R与RStudio

    3.1 R

    • 请到官方载点下载安装档。
    • 依照作业系统选择,此处以Windows10作业系统为例。(目前R-4.1.1-win)
    • 开启安装档後,连续点选下一步,完成安装。

    3.2 RStudio

    • 请到官方载点下载安装档。

    • 依照作业系统选择,此处以Windows10作业系统为例。(目前是RStudio-2021.09.0-351.exe)

    • 开启安装档後,连续点选下一步,完成安装。


小结

  1. 今天顺利获得每个模型输出800字机率表後,先安装R与RStudio,以备明日之用。
  2. 下一章,将和大家分享如何随机加权组合奇数个模型,再找出每个中文字的阈值,达成优化模型与判定isnull的目的。

让我们继续看下去...


<<:  Day 20 资料库评估 - Database Assessment (sqlmap, SQLite database browser)

>>:  Day_23 WireGuard

网路架构检视 - 网路分段/分区与 IP 发放

打 D2R ,连梗图都懒得找... 在资安法中,有些应办事项即使在技术面定义也很广,不会有明确的实作...

Flutter - Flutter 网路 GIF 图片重复播放

Flutter - Flutter 网路 GIF 图片重复播放 参考资料 Flutter开发实战系列...

Day2 网路是一堆电缆构成的,那网页呢?

大致了解网路是什麽之後,那每天逛的网页又是什麽呢? 什麽是网页? 网页是一份档案,通常会储存在服务器...

第 28 型 - 路由 (Router) - Resolve / 延迟载入 (Lazy Router)

上一篇利用路由机制传递参数来实作待办事项的编辑功能,除了透过网址路径来传递所需要的参数,还可以在路由...