这章节,我们将说明 train()的细部。
程序部分如下:
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if (args['batch_num'] is not None) and batch_idx >= args['batch_num']:
break
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args['log_interval'] == 0:
logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
首先,传进来的函数参数有:args, model, device, train_loader, optimizer, epoch。这些参数,在前面章节已经叙述过了,在此省略。
首先,我们要宣告model要做的事为训练。
model.train()
接下来,for loop 中,train_loader会一笔一笔的将资料汇进来,每笔资料的格式为 [batch_idx, (data, target)]。每批资料,都会有自己的索引,batch_index。If 的部分,判断资料是否还在,不存在则跳离。
接着,我们将一笔train data、target(label)放入device中。(因为,model也在device中。大家要再一起,才能执行。)
data, target = data.to(device), target.to(device)
然後,记得将optimizer归零,否则会一直累加!
optimizer.zero_grad()
再来将data喂入model中,得到 output。
output = model(data)
接着计算output and target的差距,因为资料为 HxWxC是2D资料,适合用 nll_loss来计算 loss。
loss = F.nll_loss(output, target)
而後回去计算神经网路丛里的weights and bias。
loss.backward()
接下来,根据learn rate and momentum,进行优化计算。(就是weights and bias的调整,让loss得到最小值)
优化完毕後,将新的weights and bias取代原有的。
optimizer.step()
再来的if,主要是在用於判断执行几圈,才进行logger的动作。得到的结果,大致类似如此:
[2021-09-18 15:34:21] INFO (mnist_AutoML/MainThread) Train Epoch: 10 [0/60000 (0%)] Loss: 0.003544
下一个章节,我们将谈到 test()
<<: [Day 19] Facial Recognition: 使用孪生网路做辨识
>>: Day20 Analysis of Algorithms(Ⅱ)
Given a string containing just the characters '(' ...
建立资料仓库是一个解决企业资料问题应用的过程,是企业资讯化发展到一定阶段必不可少的一步,也是发展资料...
前言 上一篇玩完HTTP Method後,接着来玩Request的Data Length吧! 正文 ...
接着,请回到[Microsoft Azure]的Home,在[Recent resources]处&...
GitHub Issue 有点像是专案管理系统内管理工作事项的功能,但它能达到功能更多:无论是个人或...