这个章节,我们将谈到 test()的部分。
进入主题之前,我们要注意的是,test_loader是固定的1000笔资料直接使用(没epoch),所以不download参数。因此,它并不带有batch index 这栏位。
首先,先宣告这是validation的作业,而後将一些变数清0。
model.eval()
test_loss = 0
correct = 0
由於我们在这个阶段,目的是测试model的准确性,也就利用model做一种推估,再比较和实际值的差异,进而得到其准确性。所以,我们不需要 backword and optimizer来优化 weights。因此,我们宣告暂时不需要gradient。
with torch.no_grad():
接下来,每次读一笔资料。读入的资料,放入device中。进而喂入model得到预测结果output。
data, target = data.to(device), target.to(device)
output = model(data)
接着,我们计算 output 和 target 之间的 loss,累计之。同时,也计算推论正确的次数 correct,亦累计之。
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
#
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
#
correct += pred.eq(target.view_as(pred)).sum().item()
Pred and correct大致说明一下。假设我们预测手写数字为5,那麽output中机率最高的应为output[5],所以pred=5。若预测的答案和target (5)相同,correct就累加1。
以上章节,讲完本机的部分。
接下来会谈及云端的部分。
<<: Day 18: 人工智慧在音乐领域的应用 (AI作曲-基因演算法二)
>>: Kotlin Android 第28天,从 0 到 ML - TensorFlow Lite -姿态估计 (Pose estimation)
本节是以 Golang 上游 4b654c0eeca65ffc6588ffd9c99387a7e4...
DBA Bootcamp 大多数 SQL Server 的服务器验证都是设为 mixed mode,...
韩剧中总是会出现的辣炒年糕,随着机智的医生生活第2季完结,我爱辣炒年糕同好会也结束了,这次准备升级版...
要讲到 gulp 怎麽运作的就不得不讲到 vinyl 跟 Node.js 的 stream viny...
前面学习到了条件判断式,接着我们来学习有点危险的循环回圈,好啦!也没那麽夸张~只是写不好,容易进入...