上一章节研究完class net(…),这一章节我们继续研究 def main(args)这部分。
def main(args):
# define data directory and device (CPU or GPU)
use_cuda = not args['no_cuda'] and torch.cuda.is_available()
torch.manual_seed(args['seed'])
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = args['data_dir']
# --- data loader
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args['batch_size'], shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=1000, shuffle=True, **kwargs)
# --- end data loader
hidden_size = args['hidden_size']
model = Net(hidden_size=hidden_size).to(device)
optimizer = optim.SGD(model.parameters(), lr=args['lr'],
momentum=args['momentum'])
for epoch in range(1, args['epochs'] + 1):
train(args, model, device, train_loader, optimizer, epoch)
test_acc = test(args, model, device, test_loader)
# report intermediate result
nni.report_intermediate_result(test_acc)
logger.debug('test accuracy %g', test_acc)
logger.debug('Pipe send intermediate result done.')
# report final result
nni.report_final_result(test_acc)
logger.debug('Final result is %g', test_acc)
logger.debug('Send final result done.')
一开始的部分,主要是定义 data directory and device for CPU or GPU。
接下来我们看一下 data loader的部分。 data loader可以帮我们整理转换资料外(ToTensor()),还可以依照我们的需要,一批批的吐出来。例如train_loader里的 batch_size。说明一下,这里的test_loader,是训练模型时,用来做validation用的。个人比较喜欢用valid_loader一词。
资料有时候很分散,此时会影响训练及验证速度(计算速度、判别速度),所以需要正规化。 transforms.Normalize((0.1307,), (0.3081,)),由於资料是黑白的,RGB channel 只有1个,所以只有1个数字。Mean=(0.1307,),Std= (0.3081,)。至於里面的数字为何为此,我也不知。也许可以从data.describe()的统计表中得知一二吧!
另外,shuffle=True,主要是让每批资料,都能很平均的取样,以免资料取样产生偏颇,让模型导致无效!
初始化model为神经网路 class Net时,要设定其隐藏层的大小,hidden_size。即
model = Net(hidden_size=hidden_size).to(device)。
请注意,model 是放在device中执行,所以 training and validation 时,loader过来的资料也得放於device中,否则 training and validation 时,会找不到资料。可自行实验。
下一回,我们继续往下说明 。
<<: 【Day29】从小菜鸟使用React到现在踩到的地雷经验谈 (ᗒᗣᗕ)՞
吃角子老虎机 教学原文参考:吃角子老虎机 这篇文章会介绍如何使用「函式」、「计次回圈」、「随机取数」...
接下来要写自我介绍页, 自我介绍页使用 routes/web.php 里面的 Route::grou...
再来说说我大学到转职前的经历好了。 大学时期,大一大二时其实不算个认真的大学生,就是个标准出社会人士...
iOS工程师面试深入浅出(OC)- @property 使用方法?Copy 什麽时候用? 如果本来是...
在前几期,把AR装置的发展过程大略的描述了一下,从厚重的头戴式装置到可供个人购买的AR眼镜,但这中间...