Day20-pytorch(3)设置model、使用pytorch设置loss function及updata参数

我们透过简单的regression来认识如何设置model及如何使用pytorch设置loss function及updata参数
先设个范例资料,总共4笔资料
x为输入,特徵值有两个
y为要训练的目标
https://ithelp.ithome.com.tw/upload/images/20210904/20140416gjvIBINJbD.png

设置模型时会先import torch.nn并命名为nn
https://ithelp.ithome.com.tw/upload/images/20210904/20140416qrXpmPCF2Y.png

设置模型:
建立一个class并继承nn.Module
在__init__需输入super(Model,self).init()
设置self.net为我们自己的模型
透过nn.Sequential方法来建立,之後在後方传入你想要的模型样子
nn.Linear用来加入线性的模型,後方参数传入模型输入值与输出值
第一个输入并符合输入资料的资料个数,所以我一开始先传入x.shape[1]
线性模型後接上activation function,这里我都使用ReLU
def forward後方必须传入输入参数,会回传model计算结果
https://ithelp.ithome.com.tw/upload/images/20210904/20140416Y6JkAKMKcX.png

设置迭代次数及learning rate:
n_iter为我们想要训练的次数
learning rate在这里设置,之後会传入optimizer後方参数,等一下你就看得到了
https://ithelp.ithome.com.tw/upload/images/20210904/20140416vVdXyPbWSG.png

设置loss function及optimizer:
我将critirion设为loss function的变数名称,後方传入pytroch内已有的方法nn.MSELoss()
表示此loss function是计算mean-square error
optimizer设为updata参数的变数名称,後方传入pytorch内已有的方法,这里我选择adam演算法
後方需传入model参数及learning rate
https://ithelp.ithome.com.tw/upload/images/20210904/201404161nrZN3EBR8.png

开始训练资料
pre为model计算完结果
将pre与训练目标y传入critirion来算出loss
optimizer.zero_grad()一定要传入,要归零上一次的微分结果
loss.backward算出这次的微分结果
optimizer.step()根据传入的演算法updata参数,这里就是使用上面设置的adam演算法
https://ithelp.ithome.com.tw/upload/images/20210904/2014041661SMSGKPfY.png

查看训练结果
可看出我们训练结果已与目标非常相近
https://ithelp.ithome.com.tw/upload/images/20210904/20140416XfvKQFWaJ5.png

送上colab连结,可自行在上面多做点练习更加熟悉pytorch
https://colab.research.google.com/drive/1CiullFEJa1vGMrH-qNeWgDUwXHJ_PrfB?usp=sharing


<<:  AI ninja project [day 22] 变分自动编码器 Variational Autoencoder

>>:  Day8_HTML语法5

ui 框架説明

我比较熟悉的ui是qt的,但是框架类似,下面就分几步讲解,我是如何在一个自动化项目中使用UI的: 首...

Ruby on Rails 模组(Module)

如果我有一个小猫类别,我想要这个小猫类别有飞行功能,你会怎麽做? 直接写一个有飞行功能的小鸟类别,然...

Day-14 Pytorch 的 Gradient 计算

之前我们看过用 Python 计算 Gradient 必须要手动计算偏微分之後,才有办法算出 那如...

InnoDB的表格空间-Part1(区、段、区的分类、段的结构)

透过前面的内容大家知道表格空间是一个抽象的概念,对系统表格空间来说,对应着档案系统中一个或多个档案;...

D2 - 环境安装 (Miniconda & PyCharm)

之前装Anaconda实在太占空间我看了一下我现在大概占了快5G 这次想来试试看轻量安装的Minic...