Day-23 Model 可以重复使用吗? 储存和读取 Model

  • 总算,我们已经会建立自己独一无二的神经网路了~但,你有没有发现一个问题,我们的该不会每次要使用模型之前,都要全部重头来一遍吧?今天小小的资料跟小小的模型我都能够接受这样操作,但是...如果今天是百万个神经元级别的,我总不能还是每次要使用之前都从头来一次吧?
  • 当然不用,我们回头思考一下,对於我们训练模型来说,最重要的东西是什麽?就是我们模型中的变数对吧?只要我们记得最後训练完取得的参数,其实就等於是我们训练完的模型结果了阿~那我们在使用上也就是使用这些参数在操作,因此我们有没有个办法去记录这些东西呢?只要我们能记录一个模型的状况,这样有几个好处,
    • 训练完的模型可以直接在需要时做读取使用
    • 训练过程中如果持续有做纪录,如果训练不小心中断,可以直接从中断的地方开始训练
    • 要从中间做调整也可以利用中间的资料开始训练
  • 简单来说,当我们成功取得储存读取心法,我们可以说我们就掌控了整个训练,对於整个流程都更加的灵活强大了,所以就让我们来聊聊如何对模型做储存和读取吧~

Save & Load Model

  • 模型的读写有分成两种方式,一种方式我们称为懒人法,另一种则是比较推荐正规的方式,我们会分别聊到,也会解释差异

Lazy Way

  • Pytorch 提供了一个偷懒的方式,就是把整个 Model 储存起来,那我们直接拿一个例子做举例
import torch
import torch.nn as nn

class ExampleModel(nn.Module):
    
    def __init__(self, input_size):
        super(ExampleModel, self).__init__()
        self.linear = nn.Linear(input_size, 1)
    
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        
        return y_pred
        

model = ExampleModel(input_size=6)
  • 我们在初始宣告一个 Model 的时候,其实就会有初始的参数了,因此我们在这个时候去输出我们的参数的话会变成这样
print('Before saveing: ')
for parm in model.parameters():
    print(parm)
    
# before save
# Parameter containing:
# tensor([[-0.2966, -0.2289, -0.3195, -0.2210, -0.2217,  0.1012]],
#        requires_grad=True)
# Parameter containing:
# tensor([0.1014], requires_grad=True)
  • 我们可以看到这个时候我们已经有初始的 weights 跟 bias 了,那依照一般的训练过程就是会拿这组参数去验证资料,然後看 loss 的状况等等一路往下做训练
  • 所以让我们把现在的模型状况储存起来,会用到的工具叫做 torch.save(arg, PATH),会需要两个参数,方别是我们的 model 跟要储存的位置 PATH,因此范例会长下面这样
# save whole model
FILE = 'model_all.pt'
torch.save(model, FILE)
  • 那我们如果要使用这个储存起来的模型,我们要怎麽去读取呢?这时会利用到另一个函式 torch.load(PATH),只要给储存的位置,就会自动处理读取,我们看范例
# load model
model = torch.load(FILE)
  • 那这边要注意一件事情,模型在读取进来时,我们如果要使用评估模式(确保固定的推理状况),或是训练模式(确保可以有完整的训练过程),需要宣告不同的 model 状态,也就 model.eval()(评估模式)、model.train()(训练模式)
  • 那我们今天要验证资料,因此用 model.eval() 的评估模式来做资料检查
model.eval()

print('whole model load')
for parm in model.parameters():
    print(parm)
    
# whole model load
# Parameter containing:
# tensor([[-0.2966, -0.2289, -0.3195, -0.2210, -0.2217,  0.1012]],
#        requires_grad=True)
# Parameter containing:
# tensor([0.1014], requires_grad=True)
  • 我们可以发现读取进来的资料跟保存时的状态一毛毛一样样,这就是我们希望的效果
  • 但是,将整个模型保存下来的方式是不被推荐的,详细原因可以参考官方文件Pytroch saveing and loading model,简单来说就是这样的做法是比较不稳定的,容易造成模型的损毁
  • 那让我们来看看被推荐的做法

Recommended way

  • 今天我们主要目标其实是当前模型的参数,模型本身的结构那些其实不是那麽重要,因为我们随时可以自己重新建立这个结构,因此 Pytorch 提供了一个函式叫做 state_dict
  • state_dict 是一个简单的Python字典对象,每个层映射到其参数张量。我们来看看范例
import torch
import torch.nn as nn

class ExampleModel(nn.Module):
    
    def __init__(self, input_size):
        super(ExampleModel, self).__init__()
        self.linear = nn.Linear(input_size, 1)
    
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        
        return y_pred
        

model = ExampleModel(input_size=6)
print('before save')
print(model.state_dict())

# before save
# OrderedDict([('linear.weight', tensor([[-0.0637,  0.2750, -0.3998,  0.2837, -0.2839,  0.3845]])), ('linear.bias', tensor([0.1257]))])
  • 那既然 state_dict 中已经储存了我们足够需要的资料了,那我们是不是可以只储存 state_dict() 的资料?当然可以,所以就让我们这麽做吧~
FILE = 'model_state_dict.pt'
model.save(model.state_dict(), FILE)
  • 那这边要注意,我们已经没有储存整个模型的结构状况了,因此在读取资料时,方式有点不同,我们首先还是要宣告 model,但我们要改用储存的 state_dict 参数代替原本初始的参数,来做剩下的行为,因此程序会变成这样
model = ExampleModel(input_size=6)
model.load_state_dict(torch.load(FILE))
  • 那这样就可以达到储存参数的效果了~

Checkpoint Design

  • 那我们有提过如果我们能适当的储存我们训练的过程作为记录点,会有助於我们在不管是
    • 中断还原
    • 更改训练状况
      等等其他训练的状况的使用
  • 因此如何建立好的 Checkpoint 是一个很好的问题,那这边我们就示范怎麽建立 Checkpoint(检查点)
  • 常见的 Checkpoint 会包含
    • epoch
    • model_state_dict
    • optimizer_state_dict
    • loss
    • ...
  • 还有其他东西,都可以视情况做添加,因此如果我们要储存这些资讯,我们要怎麽去储存和读取?
  • 储存
model = TheModelClass(*args, **kwargs)
loss = LossFunctionClass()
optimizer = TheOptimizerClass(*args, **kwargs)

# traning loop
for epoch in range(num_epochs):
    ...
    
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                ...
                }, PATH)
  • 读取
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

每日小结

  • 模型的状态储存是一个非常重要的议题,这决定了我们训练的模型状况可否被沿用,训练状况的保存挑整等等
  • Pytorch 提供了非常方便储存训练参数的方式,但是要记得里面仍然有些许限制,建议大家去官网好好看看,这边只是入门
  • 我们总算是把所有 PyTorch 的心法都说明清楚了~可喜可贺可喜可贺,明天让我们聊聊在深度学习领域都会遇到的好朋友 CNN 之後,就可以开始我们的手写辨识训练了~

<<:  Vuex实作

>>:  Day 24. Zabbix 通知设定 - Webhook - Telegram

【设计+切版30天实作】|Day3 - 参考Bootstrap画出理想的header(上集)

设计大纲 今天来设计Landing page的header。这次想要做的是一个满版的header,在...

Day19# Leetcode - Palindrome Number

今天是第 19 天,要来写的题目是 Palindrome Number 那麽话不多说,我们就开始吧 ...

元件服务--Windows的系统零件管理师

今天我想介绍最後一个警告事件,顺便谈谈「元件服务」这回事,他是一个Windows系统管理工具,管理C...

企划实现(16)

去背景功能 常常在做ui的同时会需要用到许多图片,但是网路上找到的图片往往都是有背景的,但是又不想额...

[Day23] 运用 VS Code 组合键加快编辑速度 - 操作介面篇

过去在看各个前端大神直播或影片的时候,都会发现他们有许多神奇又迅速的操作,但是又不知道该如何做到,今...