模型的内容04 def main()

上一章节研究完class net(…),这一章节我们继续研究 def main(args)这部分。

def main(args):
    # define data directory and device (CPU or GPU)
    use_cuda = not args['no_cuda'] and torch.cuda.is_available()
    torch.manual_seed(args['seed'])
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    data_dir = args['data_dir']

    # --- data loader
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args['batch_size'], shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True, **kwargs)
    # --- end  data loader

    hidden_size = args['hidden_size']
    model = Net(hidden_size=hidden_size).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args['lr'],
                          momentum=args['momentum'])
    for epoch in range(1, args['epochs'] + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test_acc = test(args, model, device, test_loader)
        # report intermediate result
        nni.report_intermediate_result(test_acc)
        logger.debug('test accuracy %g', test_acc)
        logger.debug('Pipe send intermediate result done.')
    # report final result
    nni.report_final_result(test_acc)
    logger.debug('Final result is %g', test_acc)
    logger.debug('Send final result done.')

一开始的部分,主要是定义 data directory and device for CPU or GPU。

接下来我们看一下 data loader的部分。 data loader可以帮我们整理转换资料外(ToTensor()),还可以依照我们的需要,一批批的吐出来。例如train_loader里的 batch_size。说明一下,这里的test_loader,是训练模型时,用来做validation用的。个人比较喜欢用valid_loader一词。

资料有时候很分散,此时会影响训练及验证速度(计算速度、判别速度),所以需要正规化。 transforms.Normalize((0.1307,), (0.3081,)),由於资料是黑白的,RGB channel 只有1个,所以只有1个数字。Mean=(0.1307,),Std= (0.3081,)。至於里面的数字为何为此,我也不知。也许可以从data.describe()的统计表中得知一二吧!

另外,shuffle=True,主要是让每批资料,都能很平均的取样,以免资料取样产生偏颇,让模型导致无效!

初始化model为神经网路 class Net时,要设定其隐藏层的大小,hidden_size。即

model = Net(hidden_size=hidden_size).to(device)。

请注意,model 是放在device中执行,所以 training and validation 时,loader过来的资料也得放於device中,否则 training and validation 时,会找不到资料。可自行实验。

下一回,我们继续往下说明 。


<<:  【Day29】从小菜鸟使用React到现在踩到的地雷经验谈 (ᗒᗣᗕ)՞

>>:  Day 14 资料表之间的关联栏位

Day28 ( 游戏设计 ) 吃角子老虎机

吃角子老虎机 教学原文参考:吃角子老虎机 这篇文章会介绍如何使用「函式」、「计次回圈」、「随机取数」...

[Day 36] 自我介绍後台及前台(五) - 前台的自我介绍页

接下来要写自我介绍页, 自我介绍页使用 routes/web.php 里面的 Route::grou...

002-新鲜人

再来说说我大学到转职前的经历好了。 大学时期,大一大二时其实不算个认真的大学生,就是个标准出社会人士...

iOS工程师面试深入浅出(OC)- @property 使用方法?Copy 什麽时候用?

iOS工程师面试深入浅出(OC)- @property 使用方法?Copy 什麽时候用? 如果本来是...

Day18 AR头戴式装置 Apple也来凑一咖

在前几期,把AR装置的发展过程大略的描述了一下,从厚重的头戴式装置到可供个人购买的AR眼镜,但这中间...