DAY25:模型训练DenseNet201

DenseNet201

  1. 简介

    • DenseNet继承了ResNet的短路连线机制,并调整为密集连接机制。密集连线比传统的网路有更少的参数,因为不需要重新学习多余的特徵图。另外密集连线甚至有正则化的作用,可以减少过拟合的发生机率。

      ResNet短路连线

      图片来源:https://codingnote.cc/zh-hk/p/153860/

      DenseNet密集连线

      图片来源:https://codingnote.cc/zh-hk/p/153860/

    • 解决容易梯度消失的问题,增强特徵的传播,使特徵重复利用,减少参数的数量。

    • 与ResNet不同之处,DenseNet是将其特徵进行并接(concatenate)方式输入进下一层,而不是用ResNet的特徵相加(summation)。

      图片来源:https://arxiv.org/pdf/1608.06993.pdf

    • 於2016年提出的Dense Block,以前馈方式(feed-forward)将每层连接到每个其他层。而具有L层的传统卷积网络具有L个连接,而每个层与其後一个层之间,又有 L(L + 1)/2 个直接连接。对於每一层,前面层的所有输出,都成为後面层的输入。

     

    图片来源:https://medium.com/%E5%AD%B8%E4%BB%A5%E5%BB%A3%E6%89%8D/dense-cnn-%E5%AD%B8%E7%BF%92%E5%BF%83%E5%BE%97-%E6%8C%81%E7%BA%8C%E6%9B%B4%E6%96%B0-8cd8c65a6f3f

    • DenseNet特性总结:
         
      • 透过Dense Block可以提高特徵图的利用效率,减少参数,降低梯度消失的发生机率。

      • 密集连线不需重新学习新的特徵图,每次的input都含有之前层的资讯。

      • 是个轻量型,准确度又不错的模型。

      • 在ImageNet上,DenseNet在保有准确率的情况下,模型的效能甚至超出VGG NET与ResNet。


训练过程

  1. import 套件

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from dataset import CaptchaData
    from torch.utils.data import DataLoader
    from torchvision.transforms import Compose, ToTensor,ColorJitter,RandomRotation,RandomAffine,Resize,Normalize,CenterCrop,RandomApply,RandomErasing
    import torchvision.models as models
    import time
    import copy
    
  2. dataset载入以及DataLoader

    train_dataset = CaptchaData('./mask_2/train',
                                 transform=transforms)
     train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0,
                                    shuffle=True, drop_last=True,pin_memory=True)
     test_data = CaptchaData('./mask_2/test',
                             transform=transforms_1)
     test_data_loader = DataLoader(test_data, batch_size=batch_size,
                                   num_workers=0, shuffle=True, drop_last=True,pin_memory=True)
    
    
  3. transforms的设置

    • train资料集设置有旋转、图像变换的transforms,而test我们则是设置只有转换成tensor及标准化。
    transform_set = [ RandomRotation(degrees=10,fill=(255, 255, 255)),
    RandomAffine(degrees=(-10,+10), translate=(0.2, 0.2), fillcolor=(255, 255, 255)),
    RandomAffine(degrees=(-10,+10),scale=(0.8, 0.8),fillcolor=(255, 255, 255)),
    RandomAffine(degrees=(-10,+10),shear=(0, 0, 0, 20),fillcolor=(255, 255, 255))]
    
    transforms = Compose([RandomApply(transform_set, p=0.7),
                           ToTensor(),
                            Normalize((0.5,), (0.5,))
                           ])
    
    transforms_1 = Compose([
                             ToTensor(),
                             Normalize((0.5,), (0.5,))
                             ])
    
  4. 计算准确度

    def calculat_acc(output, target):
     output, target = output.view(-1, 800), target.view(-1, 800)
     output = nn.functional.softmax(output, dim=1)
     output = torch.argmax(output, dim=1)
     target = torch.argmax(target, dim=1)
     output, target = output.view(-1, 1), target.view(-1, 1)
     correct_list = []
     for i, j in zip(target, output):
         if torch.equal(i, j):
             correct_list.append(1)
         else:
             correct_list.append(0)
     acc = sum(correct_list) / len(correct_list)
     return acc
    
  5. 预训练模型

    model = models.densenet201(num_classes=800)
    
  6. 储存best_model(test_score最高的模型)

    if epoch > min_epoch and acc_best <= acc:
        acc_best = acc
        best_model = copy.deepcopy(model)
    
  7. 完整的code

import torch
import torch.nn as nn
from torch.autograd import Variable
from dataset import CaptchaData
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor,ColorJitter,RandomRotation,RandomAffine,Resize,Normalize,CenterCrop,RandomApply,RandomErasing
import torchvision.models as models
import time
import copy
import matplotlib.pyplot as plt
batch_size = 32
max_epoch = 40
model_path = './densenet201_mask.pth'
restor = False



def calculat_acc(output, target):
    output, target = output.view(-1, 800), target.view(-1, 800)
    output = nn.functional.softmax(output, dim=1)
    output = torch.argmax(output, dim=1)
    target = torch.argmax(target, dim=1)
    output, target = output.view(-1, 1), target.view(-1, 1)
    correct_list = []
    for i, j in zip(target, output):
        if torch.equal(i, j):
            correct_list.append(1)
        else:
            correct_list.append(0)
    acc = sum(correct_list) / len(correct_list)
    return acc


def train():
    acc_best = 0
    best_model = None
    min_epoch = 1

    transform_set = [ RandomRotation(degrees=10,fill=(255, 255, 255)),
                      RandomAffine(degrees=(-10,+10), translate=(0.2, 0.2), fillcolor=(255, 255, 255)),
                      RandomAffine(degrees=(-10,+10),scale=(0.8, 0.8),fillcolor=(255, 255, 255)),
                      RandomAffine(degrees=(-10,+10),shear=(0, 0, 0, 20),fillcolor=(255, 255, 255))
]
    transforms = Compose([ ToTensor(),
                           RandomApply(transform_set, p=0.7),
                           Normalize((0.5,), (0.5,))
                          ])

    transforms_1 = Compose([
                            ToTensor(),
                            # Normalize((0.5,), (0.5,))
                            ])

    train_dataset = CaptchaData(r'C:\Users\Frank\PycharmProjects\practice\mountain\清洗标签final\train_nomask',
                                transform=transforms_1)
    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0,
                                   shuffle=True, drop_last=True,pin_memory=True)
    test_data = CaptchaData(r'C:\Users\Frank\PycharmProjects\practice\mountain\清洗标签final\test_nomask',
                            transform=transforms_1)
    test_data_loader = DataLoader(test_data, batch_size=batch_size,
                                  num_workers=0, shuffle=True, drop_last=True,pin_memory=True)
    print('load.........................')

    model = models.densenet201(num_classes=800)

    if torch.cuda.is_available():
        model.cuda()
    if restor:
        model.load_state_dict(torch.load(model_path))
    # optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max =8 , eta_min=0, last_epoch=-1, verbose=False)
    criterion = nn.CrossEntropyLoss()
    acc_history_train = []
    loss_history_train = []
    loss_history_test = []
    acc_history_test = []
    for epoch in range(max_epoch):
        start_ = time.time()

        loss_history = []
        acc_history = []
        model.train()

        for img, target in train_data_loader:
            img = Variable(img)
            target = Variable(target)
            if torch.cuda.is_available():
                img = img.cuda()
                target = target.cuda()
            target = torch.tensor(target, dtype=torch.long)
            output = model(img)

            loss = criterion(output, torch.max(target,1)[1])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = calculat_acc(output, target)
            acc_history.append(float(acc))
            loss_history.append(float(loss))
        scheduler.step()
        print('train_loss: {:.4}|train_acc: {:.4}'.format(
            torch.mean(torch.Tensor(loss_history)),
            torch.mean(torch.Tensor(acc_history)),
        ))
        acc_history_train.append((torch.mean(torch.Tensor(acc_history))).float())
        loss_history_train.append((torch.mean(torch.Tensor(loss_history))).float())
        loss_history = []
        acc_history = []
        model.eval()
        for img, target in test_data_loader:
            img = Variable(img)
            target = Variable(target)
            if torch.cuda.is_available():
                img = img.cuda()
                target = target.cuda()
            output = model(img)



            acc = calculat_acc(output, target)
            if epoch > min_epoch and acc_best <= acc:
                acc_best = acc
                best_model = copy.deepcopy(model)
            acc_history.append(float(acc))
            loss_history.append(float(loss))
        print('test_loss: {:.4}|test_acc: {:.4}'.format(
            torch.mean(torch.Tensor(loss_history)),
            torch.mean(torch.Tensor(acc_history)),
        ))
        acc_history_test.append((torch.mean(torch.Tensor(acc_history))).float())
        loss_history_test.append((torch.mean(torch.Tensor(loss_history))).float())
        print('epoch: {}|time: {:.4f}'.format(epoch, time.time() - start_))
        print("==============================================")
        torch.save(model.state_dict(), model_path)
        modelbest = best_model
        torch.save(modelbest, './densenet201_mask2.pth')
    # 画出acc学习曲线
    acc = acc_history_train
    epoches = range(1, len(acc) + 1)
    val_acc = acc_history_test
    plt.plot(epoches, acc, 'b', label='Training acc')
    plt.plot(epoches, val_acc, 'r', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend(loc='lower right')
    plt.grid()
    # 储存acc学习曲线
    plt.savefig('./acc_densenet201.png')
    plt.show()

    # 画出loss学习曲线
    loss = loss_history_train
    val_loss = loss_history_test
    plt.plot(epoches, loss, 'b', label='Training loss')
    plt.plot(epoches, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend(loc='upper right')
    plt.grid()
    # 储存loss学习曲线
    plt.savefig('./loss_densenet201.png')
    plt.show()
if __name__ == "__main__":
    train()
    pass 

训练结果

  1. 学习曲线

  2. 准确度

  3. 总结

    • 训练epoch:20 epoches
    • 训练总时数:1小时55分钟
    • callback采纪录最高test_score
    • test_score:95.03 %
    • 比ResNet的准确度高,且收敛速度较快,效果较好。

<<:  追求JS小姊姊系列 Day25 -- 工具人、姐妹的存活原理:宣告变数的有效区域

>>:  Day 25 Redux 介绍

Day09 Platform Channel - BasicMessageChannel

如同前面介绍的,Flutter 定义了三种不同型别的Platform Channel 在platfo...

AWS Academy LMS 申请开课 - 教师

AWS 所提供的 AWS Academy 教材可以透过 AWS Academy Learning M...

咱研究出新的类阵列资料结构的说

嗨咪纳桑,咱是immortalmice,今天要来和各位分享自己研究出的几个新资料结构 这个资料结构支...

Day23 X WebAssembly

也许你早就听过 WebAssembly 这个词,传说中它可以让 C, C++, Rust 等系统语...

Day 27:DB也是假的 建立Mock SQLDelight

Keyword: SQLDelight Mock Test 直到27日,完成KMM的测试功能放在 K...