Day-17 Pytorch 的 Linear Regression

  • 在前面我们学习过了 Pytorch 的基础用法,今天我们来正式依照 Pytorch Model Class 的撰写规则,正式来撰写一个 Linear Regression Model
  • 上 Code~

Linear Regression Classes

  • Linear Regression 的整个架构会由几个东西组成?很简单,就是一个 linear function
  • 那 Pytorch Model 的写法会怎麽去写?
class LinearRegression(nn.Module):
    
    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        
        # define layers
        self.linear = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
    
        return self.linear(x)
  • 那为什麽要写成这样子呢?一般类神经网路是一个庞杂的结构,会有许多不同的 layers,我们一般会在 __init__ 里面定义了所有会使用的 layer functions,并且实际传递的方式会由 forward() 里面定义了类神经网路的结构和传递状况,因此才会这样撰写,我们会在之後的 CNN 和基础类神经网路的示范中示范到

Linear Regression Example

import torch
import torch.nn as nn
import numpy as np
from torch.optim import optimizer
from sklearn import datasets
import matplotlib.pyplot as plt

# 0) prepare data
feature_numpy, target_numpy = datasets.make_regression(n_samples=100, n_features=1, noise=20, random_state=1234)

feature = torch.from_numpy(feature_numpy.astype(np.float32))
target = torch.from_numpy(target_numpy.astype(np.float32))
target = target.view(target.shape[0], 1)

n_samples, n_features = feature.shape

# 1) model
class LinearRegression(nn.Module):
    
    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        
        # define layers
        self.linear = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
    
        return self.linear(x)

model = LinearRegression(n_features, 1)

# 2) loss and optimizer
learning_rate = 0.01
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 3) training loop
epochs = 100
for epoch in range(epochs):
    # forward pass and loss
    y_predicted = model(feature)
    loss = criterion(y_predicted, target)

    # backward pass
    loss.backward()

    # update
    optimizer.step()
    
    # init optimizer
    optimizer.zero_grad()

    if (epoch + 1) % 10 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item(): .4f}')

# show in image
predicted = model(feature).detach().numpy()
plt.plot(feature_numpy, target_numpy, 'ro')
plt.plot(feature_numpy, predicted, 'b')
plt.show()

outputs

epoch: 10, loss =  5616.6792
epoch: 20, loss =  3864.4285
epoch: 30, loss =  2691.9751
epoch: 40, loss =  1907.3096
epoch: 50, loss =  1382.0626
epoch: 60, loss =  1030.3954
epoch: 70, loss =  794.8961
epoch: 80, loss =  637.1581
epoch: 90, loss =  531.4833
epoch: 100, loss =  460.6736

每日小结

  • 现在会感觉写成 Class 是在多此一举,但是在後面的范例中应该就会感受到这样写的重要性,今天算是一个程序进化的过程,让大家感受一下从完全没有 Pytorch 套件的纯 python code ,到慢慢利用 Pytorch 解决每一个环节的问题,到最後完整 Pytorch 撰写的方式,这里就是希望大家感受一下使用 Pytorch 真的可以帮助使用者更专注在重要的结构上,而非基础的数学问题
  • 明天我们在来看看之前写过的 Logistic Regression 改写成 Pytorch 的过程

<<:  [Day_18]回圈与生成式 - (4)

>>:  滤镜-30天学会HTML+CSS,制作精美网站

【D22】制作讯号灯之反思:观察讯号灯与9/22大盘关系

前言 今天加权指数开低,维持一个大跌,来观察讯号灯和大盘、个股的关系,来验证我们的讯号灯能不能参考。...

电子书阅读器上的浏览器 [Day27] 无痕模式

原先的 browser 实作就已经包含了无痕模式的细部功能,像是禁止使用 Cookie,和不记录浏览...

## D21 - 彭彭的课程# Python 乱数与统计模组(1)

今天睡眠品质极差 整个人像殭屍一样实在是not my day 好今天是来看看这个乱数跟统计模组 感觉...

iOS APP 开发 OC 第十一天,使用 typedef 简化 block

tags: OC 30 day 问题:简化block变量的时候,要写好大一串,类型好长。 typed...

Android Studio初学笔记-Day6-EditText

EditText(输入框) 是个能与使用者互动的一个元件,我觉得也开始让程序变得稍微有点层次了,其实...