DAY9: 验证码辨识(二)

大家好,昨天我们把图片抓下来之後也标记完了(我个人是用了10000张图片),接下来就是丢进模型训练啦!这边小弟采用pytorch的框架,我们首先要写一个读取资料集的程序。

  • 图片整理

我们要将抓下来的10000张图片,分成train、test及validation,我是切7、2、1。
test可以边训练边观察有没有练起来及有没有过拟和,而validation则是让你测试模型准确度及泛化程度好不好。

  • 建立dataset

首先要import我们需要的套件。

  1. 我们在读取图片资料夹的时候会用到os.listdir将图片名称产生成一个list。
  2. PIL则是用在开启图片及转换RGB或者调整图片大小时可用到。
  3. torch.utils.data则是助於我们建立Dataset以供训练模型时读取。
  4. pandas则是我在导入图片的label时会需要用到。
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import pandas as pd

首先我们要先定义我们可能出现的字母及数字,总共有36个字。

alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'

另外用Image来写一个开启图片的程序。

def img_loader(img_path):
    img = Image.open(img_path)
    return img.convert("RGB")

接下来我们要取出我们的图片路径跟图片对应的label,我将注解跟程序码写在一起。

def make_dataset(data_path,ans_path,alphabet):
    img_names = os.listdir(data_path)# 取出图片名称
    img_names.sort(key=lambda x: int(x.split(".")[0]))# 让图片从小到大排序
    df_ans = pd.read_csv(ans_path)# 读取label的CSV档
    ans_list = list(df_ans["code"].values)# 取得label
    samples = []
    
    # 用zip将图片跟对应的答案凑一对
    for ans, img_name in zip(ans_list, img_names):
        if len(str(ans)) == 5  :#num_char:
            
            # 将图片名称及路径合并
            # 以便上述程序img_loader的执行
            img_path = os.path.join(data_path, img_name)
            target = []
            # 这边做5个字的辨识,例如:A5GG2
            # 会转换成target = [0,31,6,6,28]
            for char in str(ans):
                vec = [0] * 36 # num_class
                vec[alphabet.find(char)] = 1
                target += vec
                
            # 用samples把他们全部包起来    
            samples.append((img_path, target))
        else:
            print(img_name)
    return samples

接下来要把他从samples里面一一读取,并将label转换成tensor。

class CaptchaData(Dataset):
    def __init__(self, data_path,ans_path,
                 transform=None, target_transform=None, alphabet=alphabet):
        super(Dataset, self).__init__()
        self.data_path = data_path
        self.ans_path = ans_path
        # self.num_class = num_class
        # self.num_char = num_char
        self.transform = transform
        self.target_transform = target_transform
        self.alphabet = alphabet
        self.samples = make_dataset(self.data_path,self.ans_path,self.alphabet
                                    )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img_path, target = self.samples[index]
        img = img_loader(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, torch.Tensor(target)

下面这个程序是用来计算准确度的。

def calculat_acc(output, target):
    output, target = output.view(-1, 36), target.view(-1, 36)
    output = nn.functional.softmax(output, dim=1)
    output = torch.argmax(output, dim=1)
    target = torch.argmax(target, dim=1)
    output, target = output.view(-1, 5), target.view(-1, 5)
    correct_list = []
    for i, j in zip(target, output):
        if torch.equal(i, j):
            correct_list.append(1)
        else:
            correct_list.append(0)
    acc = sum(correct_list) / len(correct_list)
    return acc

接下来我们丢进model里面train,我这边选择预训练模型densenet201来做训练。

batch_size = 15 # 依照个人设备去设定
base_lr = 0.6   # 设定优化器一开始的学习率,都可以尝试看看
max_epoch = 60  # 看要练几个epoch
model_path = './tset_densenet.pth'
def train():
    transforms = Compose([ToTensor()])
    train_dataset = CaptchaData('./pic_train2',
                                './answer/answer_train_v2.csv',
                                transform=transforms)
    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0,
                                   shuffle=True, drop_last=True)
    test_data = CaptchaData('./pic_test2',
                            './answer/answer_test_v2.csv',
                            transform=transforms)
    test_data_loader = DataLoader(test_data, batch_size=batch_size,
                                  num_workers=0, shuffle=True, drop_last=True)
                                  
    # 使用densenet201来做训练                         
    cnn = models.densenet201(num_classes=180)# 五个字,每个字有36种可能
    
    # 测试有没有装cuda,有没有GPU可以使用
    if torch.cuda.is_available():
        cnn.cuda()
    if restor:
        cnn.load_state_dict(torch.load(model_path))
    # 这边优化器我使用SGD+momentum,搭配CosineAnnealing(余弦退火)的学习率scheduler,可以不固定学习率,以防他停在局部低点,有机会找到全局最佳解。
    optimizer = torch.optim.SGD(cnn.parameters(), lr=base_lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=16, eta_min=0, last_epoch=-1, verbose=False)
    criterion = nn.MultiLabelSoftMarginLoss()

    for epoch in range(max_epoch):
        start_ = time.time()

        loss_history = []
        acc_history = []
        cnn.train()

        for img, target in train_data_loader:
            img = Variable(img)
            target = Variable(target)
            if torch.cuda.is_available():
                img = img.cuda()
                target = target.cuda()
            output = cnn(img)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = calculat_acc(output, target)
            acc_history.append(float(acc))
            loss_history.append(float(loss))
        scheduler.step()
        print('train_loss: {:.4}|train_acc: {:.4}'.format(
            torch.mean(torch.Tensor(loss_history)),
            torch.mean(torch.Tensor(acc_history)),
        ))

        loss_history = []
        acc_history = []
        cnn.eval()
        for img, target in test_data_loader:
            img = Variable(img)
            target = Variable(target)
            if torch.cuda.is_available():
                img = img.cuda()
                target = target.cuda()
            output = cnn(img)

            acc = calculat_acc(output, target)
            acc_history.append(float(acc))
            loss_history.append(float(loss))
        print('test_loss: {:.4}|test_acc: {:.4}'.format(
            torch.mean(torch.Tensor(loss_history)),
            torch.mean(torch.Tensor(acc_history)),
        ))
        print('epoch: {}|time: {:.4f}'.format(epoch, time.time() - start_))
        print("========================================")
        torch.save(cnn.state_dict(), model_path)

练了6个epoch,准确度到快95%了,这里就先用这个模型当作我们破解验证码的模型吧!!

  • 今日小结

今天就把code分享给大家,明天来用今天练的模型测试看看可不可以直接进入网页得到我们要的资讯吧!!!
明天见罗!!!


<<:  自然而然的敏捷导入

>>:  Scanners API-价格篇 && Pandas设定

Flutter基础介绍与实作-Day23 旅游笔记的实作(4)

今天就接续来讲中部地区的制作吧! 资料夹建立 lib/scareens/food_Middle/fo...

The field that fears with

In my career path, there is one type of job that I...

冒险村08 - Preitter output in rails console

08 - Preitter output in rails console Rails 的 defa...

企划实现(16)

去背景功能 常常在做ui的同时会需要用到许多图片,但是网路上找到的图片往往都是有背景的,但是又不想额...

【从零开始的 C 语言笔记】第二篇-大家的开始 - Hello World & 档案创建介绍

不怎麽重要的前言 上一篇我们成功的安装好一个程序码编辑器了,接下来我们要来学习怎麽使用它了! 写程序...