模型的内容07 train()

这章节,我们将说明 train()的细部。
程序部分如下:

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if (args['batch_num'] is not None) and batch_idx >= args['batch_num']:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

首先,传进来的函数参数有:args, model, device, train_loader, optimizer, epoch。这些参数,在前面章节已经叙述过了,在此省略。

首先,我们要宣告model要做的事为训练。

 model.train()

接下来,for loop 中,train_loader会一笔一笔的将资料汇进来,每笔资料的格式为 [batch_idx, (data, target)]。每批资料,都会有自己的索引,batch_index。If 的部分,判断资料是否还在,不存在则跳离。
接着,我们将一笔train data、target(label)放入device中。(因为,model也在device中。大家要再一起,才能执行。)

 data, target = data.to(device), target.to(device)

然後,记得将optimizer归零,否则会一直累加!

 optimizer.zero_grad()

再来将data喂入model中,得到 output。

  output = model(data)

接着计算output and target的差距,因为资料为 HxWxC是2D资料,适合用 nll_loss来计算 loss。

  loss = F.nll_loss(output, target)

而後回去计算神经网路丛里的weights and bias。

 loss.backward()

接下来,根据learn rate and momentum,进行优化计算。(就是weights and bias的调整,让loss得到最小值)
优化完毕後,将新的weights and bias取代原有的。

 optimizer.step()

再来的if,主要是在用於判断执行几圈,才进行logger的动作。得到的结果,大致类似如此:

[2021-09-18 15:34:21] INFO (mnist_AutoML/MainThread) Train Epoch: 10 [0/60000 (0%)]     Loss: 0.003544

下一个章节,我们将谈到 test()


<<:  [Day 19] Facial Recognition: 使用孪生网路做辨识

>>:  Day20 Analysis of Algorithms(Ⅱ)

(Hard) 32. Longest Valid Parentheses

Given a string containing just the characters '(' ...

详解资料仓库的实施步骤,实战全解!(1)

建立资料仓库是一个解决企业资料问题应用的过程,是企业资讯化发展到一定阶段必不可少的一步,也是发展资料...

[Day4] HTTP Request Smuggling - HTTP 请求走私

前言 上一篇玩完HTTP Method後,接着来玩Request的Data Length吧! 正文 ...

MS Azure ML02

接着,请回到[Microsoft Azure]的Home,在[Recent resources]处&...

GitHub 上讨论议题 - 建立第一个 Issue 与自订 Labels

GitHub Issue 有点像是专案管理系统内管理工作事项的功能,但它能达到功能更多:无论是个人或...