模型的内容08 test()

这个章节,我们将谈到 test()的部分。

进入主题之前,我们要注意的是,test_loader是固定的1000笔资料直接使用(没epoch),所以不download参数。因此,它并不带有batch index 这栏位。

首先,先宣告这是validation的作业,而後将一些变数清0。

    model.eval()
    test_loss = 0
    correct = 0

由於我们在这个阶段,目的是测试model的准确性,也就利用model做一种推估,再比较和实际值的差异,进而得到其准确性。所以,我们不需要 backword and optimizer来优化 weights。因此,我们宣告暂时不需要gradient。

with torch.no_grad():

接下来,每次读一笔资料。读入的资料,放入device中。进而喂入model得到预测结果output。

            data, target = data.to(device), target.to(device)
            output = model(data)

接着,我们计算 output 和 target 之间的 loss,累计之。同时,也计算推论正确的次数 correct,亦累计之。

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
	 #
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
	 #
            correct += pred.eq(target.view_as(pred)).sum().item()

Pred and correct大致说明一下。假设我们预测手写数字为5,那麽output中机率最高的应为output[5],所以pred=5。若预测的答案和target (5)相同,correct就累加1。

以上章节,讲完本机的部分。
接下来会谈及云端的部分。


<<:  Day 18: 人工智慧在音乐领域的应用 (AI作曲-基因演算法二)

>>:  Kotlin Android 第28天,从 0 到 ML - TensorFlow Lite -姿态估计 (Pose estimation)

予焦啦!在 ethanol 中启用虚拟记忆体

本节是以 Golang 上游 4b654c0eeca65ffc6588ffd9c99387a7e4...

SQL Server 安全性设定 - 心得分享

DBA Bootcamp 大多数 SQL Server 的服务器验证都是设为 mixed mode,...

Day10韩国街头必吃小吃-韩式起士辣炒年糕

韩剧中总是会出现的辣炒年糕,随着机智的医生生活第2季完结,我爱辣炒年糕同好会也结束了,这次准备升级版...

Day 28: gulp 是怎麽运作的

要讲到 gulp 怎麽运作的就不得不讲到 vinyl 跟 Node.js 的 stream viny...

Day.16 「重复的事情,交给程序去做!」 —— JavaScript 循环回圈

前面学习到了条件判断式,接着我们来学习有点危险的循环回圈,好啦!也没那麽夸张~只是写不好,容易进入...