DAY26:判断800字外为isnull的方法

组合模型判断非800字内的字为isnull

  1. 获得各模型预测字的机率表

    • 800字内
    • 800字外
    • 测试赛资料集
  2. 资料总笔数

    • 官方800字内:train资料集(174,808 )+test资料集(18,717张),合计193,525张图档。
    • 官方800字外:723张图档。
    • 测试赛样本:90张图档。
  3. 程序码

import torch
import torch.nn as nn
import os
from dataset import CaptchaData
from torchvision.transforms import Compose, ToTensor
import csv
import copy


data_path = r"C:\Users\Frank\PycharmProjects\practice\mountain\清洗标签final\train_all"
img_names = os.listdir(data_path)
source = img_names
title = copy.deepcopy(source)
title.append('predict')
title.append('true')

f = open('./densenet201_in800_official_nomask.csv', 'a',newline='')
w = csv.writer(f)
w.writerow(title)
f.close()
alphabet = ''.join(source)

def predict(img_dir):
    n = 0
    m = 0
    transforms = Compose([ 
                           ToTensor()
                          ])
    dataset = CaptchaData(img_dir, transform=transforms)
    model = torch.load('./best_densenet201_8.pth')

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    for k, (img, target) in enumerate(dataset):
        img = img.view(1, 3 , 80 ,80 ).cuda()
        target = target.view(1, 1 * 800).cuda()
        output = model(img)

        output = output.view(-1, 800)
        target = target.view(-1, 800)
        output_prob = nn.functional.softmax(output, dim=1)
        output_prob_list = output_prob.cpu().detach().numpy().tolist()


        output = torch.argmax(output_prob, dim=1)

        target = torch.argmax(target, dim=1)
        output = output.view(-1, 1)[0]
        target = target.view(-1, 1)[0]


        print('pred: ' + ''.join([alphabet[i] for i in output.cpu().numpy()]))
        print('true: ' + ''.join([alphabet[i] for i in target.cpu().numpy()]))
        pred = ''.join([alphabet[i] for i in output.cpu().numpy()])
        true = ''.join([alphabet[i] for i in target.cpu().numpy()])
        if pred == true:
            n += 1
            output_prob_list[0].append(pred)
            output_prob_list[0].append(true)
            # output_prob_list[0].append(1)
            f = open('./densenet201_in800_official_nomask.csv', 'a',newline='')
            w = csv.writer(f)
            w.writerow(output_prob_list[0])
            f.close()
        else:
            m += 1
            output_prob_list[0].append(pred)
            output_prob_list[0].append(true)
            # output_prob_list[0].append(0)
            f = open('./densenet201_in800_official_nomask.csv', 'a',newline='')
            w = csv.writer(f)
            w.writerow(output_prob_list[0])
            f.close()

    print("pred_acc:", n / (n + m))
    print(m)
  1. 输出的800字机率表

    图片来源:https://ithelp.ithome.com.tw/articles/10277916

  2. 判断方法

    • 选择每个字的阈值

      • <方法一>单一中文字机率最小值
      • <方法二>单一中文字机率平均值
    • 任意选择奇数个模型组合後,产生模型权重表与利用新模型权重得到的机率表。

      • <方法一>各模型中,单独一个字的辨识准确度占比(如:官方800字内的「白」有100个样本,正确辨识80个,「白」的辨识准确度就是80%)
      • <方法二>各模型中,单独计算800个中文字机率占比(sofmax输出机率)
    • 如何判断isnull

      • <方法一>多模型投票,取多数者(故1.2选择奇数个模型随机组合)
      • <方法二>模型加权判断
  3. 之後的作法可参考训练模型-模型组合与辨识isnull(二)以及训练模型-模型组合与辨识isnull(三)


今日小结

  • 因为同组的关系,我的队友写得又比我快比我好,我忍不住要来分享一下他的文,我们做法是这麽做的,在最後会来探讨其他组别的作法。

  • 後面判断完isnull就剩下上GCP架设API,供比赛的时候使用。


<<:  AIS3 Pre-Exam + MFCTF

>>:  Day 26 长短期记忆网路 LSTM

【在 iOS 开发路上的大小事-Day13】Firebase 你好啊!

前情提要 Firebase 是 Google 推出的云端後端服务平台,提供了行动端 (Android...

【Day-28】我们是怎麽开始的?:一间传统软件公司从 0 开始建置的 DevOps 文化(工具篇)- 敏捷看板

前言 昨天我们稍微介绍了头脑风暴的作用与做法,今天我们稍微来介绍敏捷看板的用法! 敏捷看板是一款基础...

Day 02: ML基础第二步 Anaconda开发环境

前言 Python虽然可以直接使用Windows的Console直接执行程序,但是不只对於笔者,对於...

【JavaScript】阵列方法之filter()

【前言】 本系列为个人前端学习之路的学习笔记,在过往的学习过程中累积了很多笔记,如今想藉着IT邦帮忙...

Day9-滚动视差(下)_後有图样

今天继续说滚动视差 球球的部分先在scroll_thing的下方加上球球的div <div c...