Day 20 : 线性回归与罗吉斯回归

线性回归(Linear Regression)

如果我们有数据 (x, y) ,假设 x 是年资、y 是薪资,我们想找出其中的关联 w 和 b (y = w * x + b)

我们就可以依照这些数据绘制出一条线,来描述这些数据

而这些线是我们透过学习找到一个最小 error 去拟合训练资料

产生资料

import numpy as np
import matplotlib.pyplot as plt

# 我们自己随机产资料
np.random.seed(0)
noise = np.random.rand(100, 1)
x = np.random.rand(100, 1)
y = 8 * x + 100 + noise
# plot
plt.scatter(x, y, s=10)
plt.xlabel('x')
plt.ylabel('y')
plt.show()

建立 Linear Regression 模型

from sklearn.linear_model import LinearRegression
# 建立模型
linearMmodel = LinearRegression(fit_intercept=True)
# 使用训练资料训练模型
linearMmodel.fit(x, y)
# 使用训练资料预测
predicted = linearMmodel.predict(x)
from sklearn import metrics
print('R2 score: ', linearMmodel.score(x, y))
mse = metrics.mean_squared_error(y, predicted)
print('MSE score: ', mse)
>>> R2 score:  0.9831081424561687
    MSE score:  0.08275457812228725

模型预测长相

plt.scatter(x, y, s=10, label='True')
plt.scatter(x, predicted, color="r",s=10, label='Predicted')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()

# 分别储存在linearMmodel.coef_[0] 和 linearMmodel.intercept_中
coef = linearMmodel.coef_ 
intercept = linearMmodel.intercept_

print("斜率w = ", coef[0][0])
print("截距b = ", intercept[0])
>>> 斜率w =  7.931123354540897
    截距b =  100.50916633941445

罗吉斯回归(Logistic Regression

虽然跟这里探讨的是回归模型,但是大家必须厘清一点的是罗吉斯回归是应用在分类问题

用以下表格来说明

模型名称 预测标签 应用场景 公式
线性回归 数值 适用於预测数值型,例如预测物价指数 https://chart.googleapis.com/chart?cht=tx&chl=%24%24%5Csum_%5Climits%7Bi%7D(%7Bw%7D_i%7Bx%7D_i%20%2B%20b)%24%24
罗吉斯回归 介於0到1的机率、布林值 适用於二元分类,例如吸菸是否会得到癌症的机率、信用卡评分模型等等 https://chart.googleapis.com/chart?cht=tx&chl=%24%24%5Csigma(%5Csum_%5Climits%7Bi%7D(%7Bw%7D_i%7Bx%7D_i%20%2B%20b))%24%24

原理

在介绍罗吉斯的公式前,我们需要先了解「胜算比」(odds radio)是什麽,它是指对特定事件出现的比率。

公式是:

P表示「正事件」发生的机率,然而正事件不一定代表是好事情,也可以指的是出现癌症的事件(想预测的事件)。

罗吉斯回归的公式:
其中 sigmoid 的函数为:

我们利用线性回归输出的结果来进行二元分类(输出大於0.5分到1、小於0.5就分到0)。

图片来源网址

  • 优点:
    • 不需要假设分配类型
    • 快速可以得到结果
    • 了解各类别的分类机率
  • 缺点:
    • 无法解决非线性问题
    • 不太能处理大量的特徵,容易造成过度拟合

Sigmoid 函数

先来看看 Sigmoid 到底长什麽样子

import matplotlib.pyplot as plt
import numpy as np

def sigmoid(z):
    return 1/(1+np.exp(-z))

z = np.arange(-10, 10, 0.1)
phi_z = sigmoid(z)
plt.plot(z, phi_z)
plt.axhline(y=1, ls='dotted', color='black')
plt.axhline(y=0.5, ls='dotted', color='black')
plt.axhline(y=0, ls='dotted', color='black')
plt.show()

实作罗吉斯回归模型

接着来实作罗吉斯回归,这边会应用到鸢尾花资料集,使用 sklearn 就可以拿到资料集。

罗吉斯回归程序码

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import datasets
# 取出鸢尾花资料
iris = datasets.load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
iris_df['target'] = iris.target
iris_df['target_name'] = iris_df['target'].map({0: "setosa", 1: "versicolor", 2: "virginica"})
# 定义 X 和 Y
X = iris_df.drop(labels=['target_name', 'target'] ,axis=1)
y = iris_df['target'].values
# 进行资料集分类
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25, random_state=0)
# 特徵缩放
from sklearn.preprocessing import StandardScaler
sc_X = StandardScaler()
X_train = sc_X.fit_transform(X_train)
X_test = sc_X.transform(X_test)
# 模型拟合
from sklearn.linear_model import LogisticRegression
classifier = LogisticRegression(random_state=0)
classifier.fit(X_train, y_train)
# 模型预测
y_pred = classifier.predict(X_test)
# 预测成功的比例
print('训练集: ', classifier.score(X_train,y_train))
print('测试集: ', classifier.score(X_test,y_test))
>>> 训练集:  0.9642857142857143
    测试集:  1.0

混淆矩阵

# Making the Confusion Matrix
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
>>> array([[15,  0,  0],
          [ 0, 11,  0],
          [ 0,  0, 12]])

F1-Score

from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred))

github 程序码

更详细可以请参考连结


<<:  android studio 30天学习笔记-day 17-TabLayout+TabItem

>>:  虹语岚访仲夏夜-18(打杂的Allen篇)

DAY16 - [JS] 扩充功能 - 倒数计时,暂停、开始、结束

今日文章目录 需求说明 事前准备 遇到问题 需求说明 输入时间改成分钟 增加功能:暂停、开始、结束...

[30天 Vue学好学满 DAY7] 监听器(Watch)

Watch 监听器 具比较传(old & new) 无回传值(return) 监听变数发生异...

搞懂 P2P 技术 (2) - STUN x TURN x ICE

前言 上一篇介绍完中心化、去中心化、分布式网路以及 IPv4、NAT、NAT 类型,但我们依旧还有些...

[Day 26] 永和美食纪录-呈信传统鹅肉店 文化店

前言 转眼间,国庆连假已经要结束了,不晓得大家有没有好好的放松自己的身心,有些店家也因为连假的缘故而...

31.Module

当应用变得非常复杂时,store 对象就有可能变得相当臃肿。 为了解决以上问题,Vuex 允许我们将...