Day 21 : SVM

原理说明

SVM (support vector machine 支援向量机),是在特徵空间中找到一个分离超平面,也就是「决策边界」(decision boundary)。

我们可以透过这个决策边界(红色虚线)将资料分成不同类别,最佳化的目标是「边界」(margin)。

离灰线最近的是支援向量(support vector)。

当然不是每个情况都可以靠这种一刀两断的方式进行分类,SVM 可以将资料投影到高维度空间,在高维度空间找到超平面进行分割。

通常使用的 kernal 核心转换有 RBF Kernal 高斯转换、高次方转换等等

SVM 操作示范

资料来源
这次我想要预测顾客最後会不会购买书,可以使用的特徵因子有性别、年龄、薪水及是否为VIP注记。这边搭配 PCA 做降维,选择 n_components=2(最後才可以画图呀XDD)不做降维不好画图啊!

线性 SVM

# 训练线性 SVM 并预测结果
kernel = 'linear'
model = SVC(kernel=kernel)
model.fit(dx_train, dy_train)
predict = model.predict(dx_test)
test_score = model.score(dx_test, dy_test) * 100
plt.figure(figsize=(8, 8))
plt.rcParams['font.size'] = 14
plt.title(f'SVM {kernel} (accuracy={test_score:.1f}%)')
plt.scatter(*dx_test.T, c=predict, cmap='tab10', s=100)
plt.scatter(*dx_test.T, c=dy_test, cmap='Set3', s=35)
# 求出超平面与边界
x_min = np.amin(dx_test.T[0])
x_max = np.amax(dx_test.T[0])
y_min = np.amin(dx_test.T[1])
y_max = np.amax(dx_test.T[1])
XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
Z = model.decision_function(
    np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
# 画出超平面与边界
plt.contour(XX, YY, Z, colors=['grey', 'coral', 'grey'],
            linestyles=['--', '-', '--'], linewidths=[2, 2, 2],
            levels=[-1, 0, 1])
plt.grid(True)
plt.xlim([x_min, x_max])
plt.ylim([y_min, y_max])
plt.tight_layout()
plt.show()

非线性 SVM(RBF)

在 kernel 还可以选择 poly、sigmoid,最後发现 rbf 成效最好!

# 训练非线性 SVM 并预测结果
kernel = 'rbf'
model = SVC(kernel=kernel)
model.fit(dx_train, dy_train)
predict = model.predict(dx_test)
test_score = model.score(dx_test, dy_test) * 100
plt.figure(figsize=(8, 8))
plt.rcParams['font.size'] = 14
plt.title(f'SVM {kernel} (accuracy={test_score:.1f}%)')
plt.scatter(*dx_test.T, c=predict, cmap='tab10', s=100)
plt.scatter(*dx_test.T, c=dy_test, cmap='Set3', s=35)
# 求出超平面与边界
x_min = np.amin(dx_test.T[0])
x_max = np.amax(dx_test.T[0])
y_min = np.amin(dx_test.T[1])
y_max = np.amax(dx_test.T[1])
XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
Z = model.decision_function(
    np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
# 画出超平面与边界
plt.contour(XX, YY, Z, colors=['grey', 'coral', 'grey'],
            linestyles=['--', '-', '--'], linewidths=[2, 2, 2],
            levels=[-1, 0, 1])
plt.grid(True)
plt.xlim([x_min, x_max])
plt.ylim([y_min, y_max])
plt.tight_layout()
plt.show()

github 程序码

更详细可以请参考连结


<<:  【在 iOS 开发路上的大小事-Day21】透过 Firebase 来管理使用者 (Sign in with Apple 篇) Part1

>>:  EP18 - [TDD] 订单 API 串接 (1/2)

I Want To Know React - 中场休息

铁人炼成,回顾三十天 三十天过去了,没想到我竟然成功完成铁人赛了! 上次铁人赛完赛心得的第一句话是 ...

D-12, Ruby 正规表达式(二) 量词 、锚 && Reverse Vowels of a String

昨天的重点复习/./就是一个最简单的正规表达式。 先认识一下match与=~。 match回传匹配的...

容器化基本概念

容器映像(container image)是开发人员创建并注册的程序包(package),包含在容...

使用Quartz.Net达成Asp.Net Core长时程执行

Web应用程序本身的机制并不适合用来作为执行需要长时程运行的需求,而这类需求却很常见,而常见的解决方...

Day 06 Python 的特点

经过了前面几天的基本教学,相信大家都对 Python 有了基本的认识,也应该有点累了,所以今天来讲一...