Day-15 Pytorch 的 Regression 实作

  • 那我们之前看过了 Python 的 Easy Regression 实作,昨天也看过了 Pytorch 如何做到 Gradient Calculation,那我们今天就拿一样的 Example 来看看如果事 Pytorch 会长怎样吧~
  • 本篇范例是对应 Day-05 的 Easy Regression Example 去做 Framework 上面的比较

直接上 Code

  • 在这边我们做了一个大更新,就是把 Graient 的计算交给了 Pytorch,可以发现复杂的微分工作已经交给了 Backpropagation 来处理
  • 这次的程序跟 Day-05 的最大差异就差在我们少了手工微分 Gradient function 了
import torch

# f = w * x
# f = 2 * x, we set w as 2
x = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float32)
y = torch.tensor([2, 4, 6, 8, 10, 12], dtype=torch.float32)

# init weight
# 这边要注意,我们希望 Pytorch 帮我们计算更新的 Gradient 变数是 w,所以一定要开 requires_grad 在这个变数上
w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)

# model prediction
def forward(x):
    
    return w * x
    
# set up loss function as mean square error
def loss(y, y_predicted):

    return ((y_predicted-y) ** 2).mean()

# Training
learning_rate = 0.01
n_iters = 10

for epoch in range(n_iters):
    # perdiction = forward pass
    y_pred = forward(x)

    # loss
    l = loss(y, y_pred)

    # gradient descent is where calculate gradient and update parameters
    # so gradient descent here includes gradients and update weights
    # 原本在 Python 的 example 还需要自己建立 Gradient 函式
    # gradients = backward pass
    l.backward() # calculate dl/dw

    # update weights
    with torch.no_grad():
        w -= learning_rate * w.grad

    if epoch % 1 == 0:
        print(f'epoch {epoch + 1}: w = {w:.3f}, loss = {l:.8f}')
    
    # zero gradients,要记得归零每次运算的 gradients,否则会累加
    w.grad.zero_()

print(f'Prediction after training: f(5) = {forward(5): .3f}')

每日小结

  • 从上述的 Code 可以发现大体上概念并没有变化,但是这边的 Gradient 计算交给了 Backpropagation 的反向传递的概念,因此省略了我们自己微分 loss function 的部分
  • 从上面的程序可以发现,利用 Pytorch Framework 确实在程序的撰写上变得更加简洁,我们需要自己特别操作运算建立的函式也变少了,但是基本元素和运算概念是和纯 Python code 并无区别,这也是为甚麽我们前面要花那麽多的时间在介绍基本概念,因为就算是 Framework 也没有跳脱这个框架,因此基本概念仍然是非常重要的,我们後面会示范从头开始建立一个类神经网路的,可以敬请期待,我们明天会在示范 Pytorch 实作 Backpropagation

<<:  【Day 18】Complexity & Graphs

>>:  用React刻自己的投资Dashboard Day15 - 投资Dashboard 2.0版 Wireframe

DAY30-参赛心得

这个暑假就像开头第一篇说的,应该是大部分人度过最长的一个暑假,我原本也没什麽目标,打算好好休养生息,...

【Day18】导航元件 - Breadcrumb

元件介绍 Breadcrumb 是一个导航元件,用於显示当前系统层级结构中的路径位置,并且点击路径能...

Day 17 机器学习

今天我们要介绍的是机器学习,所谓的机器学习是指着重於训练电脑从资料中学习,并根据经验改进且在机器学习...

D09 - 打开第一扇窗

现在有资料,只差介面了。 建立 base-window 组件 虽然每个视窗功能都不同,但是视窗外框功...

DAY29 linebot message api-Template 介绍-2

Buttons template message = { "type": &qu...