import torch
# f = w * x
# f = 2 * x, we set w as 2
x = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float32)
y = torch.tensor([2, 4, 6, 8, 10, 12], dtype=torch.float32)
# init weight
# 这边要注意,我们希望 Pytorch 帮我们计算更新的 Gradient 变数是 w,所以一定要开 requires_grad 在这个变数上
w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
# model prediction
def forward(x):
return w * x
# set up loss function as mean square error
def loss(y, y_predicted):
return ((y_predicted-y) ** 2).mean()
# Training
learning_rate = 0.01
n_iters = 10
for epoch in range(n_iters):
# perdiction = forward pass
y_pred = forward(x)
# loss
l = loss(y, y_pred)
# gradient descent is where calculate gradient and update parameters
# so gradient descent here includes gradients and update weights
# 原本在 Python 的 example 还需要自己建立 Gradient 函式
# gradients = backward pass
l.backward() # calculate dl/dw
# update weights
with torch.no_grad():
w -= learning_rate * w.grad
if epoch % 1 == 0:
print(f'epoch {epoch + 1}: w = {w:.3f}, loss = {l:.8f}')
# zero gradients,要记得归零每次运算的 gradients,否则会累加
w.grad.zero_()
print(f'Prediction after training: f(5) = {forward(5): .3f}')
<<: 【Day 18】Complexity & Graphs
>>: 用React刻自己的投资Dashboard Day15 - 投资Dashboard 2.0版 Wireframe
这个暑假就像开头第一篇说的,应该是大部分人度过最长的一个暑假,我原本也没什麽目标,打算好好休养生息,...
元件介绍 Breadcrumb 是一个导航元件,用於显示当前系统层级结构中的路径位置,并且点击路径能...
今天我们要介绍的是机器学习,所谓的机器学习是指着重於训练电脑从资料中学习,并根据经验改进且在机器学习...
现在有资料,只差介面了。 建立 base-window 组件 虽然每个视窗功能都不同,但是视窗外框功...
Buttons template message = { "type": &qu...