import torch
import torch.nn as nn
class ExampleModel(nn.Module):
    """Single-layer logistic-regression model: Linear(input_size -> 1) + sigmoid."""

    def __init__(self, input_size):
        super(ExampleModel, self).__init__()
        self.linear = nn.Linear(input_size, 1)

    def forward(self, x):
        # Squash the linear output into (0, 1) for a probability-like score.
        return torch.sigmoid(self.linear(x))
# Build the model and dump its randomly initialized parameters so we can
# compare them with what we get back after saving and reloading.
model = ExampleModel(input_size=6)
print('Before saving: ')  # fixed typo: original printed 'Before saveing: '
for param in model.parameters():
    print(param)
# before save
# Parameter containing:
# tensor([[-0.2966, -0.2289, -0.3195, -0.2210, -0.2217, 0.1012]],
# requires_grad=True)
# Parameter containing:
# tensor([0.1014], requires_grad=True)
torch.save(arg, PATH) 会需要两个参数，分别是我们的 model 跟要储存的位置 PATH，因此范例会长下面这样：
# save whole model
# Persist the entire model object (architecture + parameters) via pickle.
FILE = 'model_all.pt'
torch.save(model, FILE)
torch.load(PATH) 只要给储存的位置，就会自动处理读取，我们看范例：
# load model
# Reload the whole pickled model object. NOTE(review): torch.load unpickles
# arbitrary objects — only load files from a trusted source.
model = torch.load(FILE)
# Switch to evaluation mode before inspecting/inferring (disables dropout etc.).
model.eval()
读取模型後可以呼叫 model.eval()（评估模式）或 model.train()（训练模式）来切换模式；这里我们用 model.eval() 的评估模式来做资料检查。
# Verify the reloaded model carries exactly the parameters we saved.
print('whole model load')
for p in model.parameters():
    print(p)
# whole model load
# Parameter containing:
# tensor([[-0.2966, -0.2289, -0.3195, -0.2210, -0.2217, 0.1012]],
# requires_grad=True)
# Parameter containing:
# tensor([0.1014], requires_grad=True)
state_dict 是一个简单的 Python 字典对象，将每个层映射到其参数张量。我们来看看范例：
import torch
import torch.nn as nn
class ExampleModel(nn.Module):
    """Logistic-regression model: one Linear layer whose output is passed through a sigmoid."""

    def __init__(self, input_size):
        super(ExampleModel, self).__init__()
        fc = nn.Linear(input_size, 1)
        self.linear = fc

    def forward(self, x):
        logits = self.linear(x)
        return torch.sigmoid(logits)
# Show what state_dict() contains: an OrderedDict mapping each layer's
# parameter name to its tensor.
model = ExampleModel(input_size=6)
print('before save')
print(model.state_dict())
# before save
# OrderedDict([('linear.weight', tensor([[-0.0637, 0.2750, -0.3998, 0.2837, -0.2839, 0.3845]])), ('linear.bias', tensor([0.1257]))])
state_dict 中已经储存了我们足够需要的资料了，那我们是不是可以只储存 state_dict() 的资料？当然可以，所以就让我们这麽做吧~
FILE = 'model_state_dict.pt'
model.save(model.state_dict(), FILE)
载入 state_dict 参数代替原本初始的参数，来做剩下的行为，因此程序会变成这样：
model = ExampleModel(input_size=6)
model.load_state_dict(torch.load(FILE))
# Checkpointing pattern: bundle everything needed to resume training
# (epoch counter, model and optimizer state, last loss) into one dict.
# TheModelClass / LossFunctionClass / TheOptimizerClass are placeholders
# for the reader's own classes; this snippet is illustrative, not runnable.
model = TheModelClass(*args, **kwargs)
loss = LossFunctionClass()
optimizer = TheOptimizerClass(*args, **kwargs)
# training loop
for epoch in range(num_epochs):
...
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
# Resuming from a checkpoint: rebuild the model and optimizer objects first,
# then restore their state dicts plus the bookkeeping values saved with them.
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# Choose eval() for inference, or train() to continue training from here.
model.eval()
# - or -
model.train()
>>: Day 24. Zabbix 通知设定 - Webhook - Telegram
设计大纲 今天来设计Landing page的header。这次想要做的是一个满版的header,在...
今天是第 19 天,要来写的题目是 Palindrome Number 那麽话不多说,我们就开始吧 ...
今天我想介绍最後一个警告事件,顺便谈谈「元件服务」这回事,他是一个Windows系统管理工具,管理C...
去背景功能 常常在做ui的同时会需要用到许多图片,但是网路上找到的图片往往都是有背景的,但是又不想额...
过去在看各个前端大神直播或影片的时候,都会发现他们有许多神奇又迅速的操作,但是又不知道该如何做到,今...