Python很吃版本,我使用的版本如下:
Python 3.7.4
sklearn 0.21.3
|
以下介紹使用scikit-learn使用SVM(Support vector machine)進行訓練與預測的方式,所使用的資料集是scikit-learn內建的,是鳶尾花資料集,共分三類。
鳶尾花的資料可以在wikipedia上找到,而且第一筆資料跟scikit-learn提供的一模一樣。
https://zh.wikipedia.org/zh-tw/%E5%AE%89%E5%BE%B7%E6%A3%AE%E9%B8%A2%E5%B0%BE%E8%8A%B1%E5%8D%89%E6%95%B0%E6%8D%AE%E9%9B%86
Step 1: 首先,我們先匯入該使用的lib:
# reference: https://pyecontech.com/2020/04/11/python_svm/
from sklearn import svm
from sklearn import datasets
from sklearn.model_selection import train_test_split
|
Step 2: 將鳶尾花資料集叫出來使用,每筆資料的特徵向量放X,答案放y,這邊我把第一筆特徵向量印出來看看,檢查看看是不是跟wikipedia網站上寫的是一樣的:
#=== 下載練習資料 ===
#採用sklearn所提供的鳶尾花資料集
iris=datasets.load_iris()
X=iris.data
y=iris.target
print(X[0])
|
結果:
[5.1 3.5 1.4 0.2]
|
可以往上看一下,真的是一模一樣。
Step 3:將樣本分成訓練跟測試集,使用train_test_split函式來切割訓練與測試集。
#=== 將樣本分成訓練集與測試集 ===
#test_size=0.2代表的就是測試集佔所有樣本的20%, random_state參數設定為0的方式來固定種子,random_state可以給任意整數值。
X_train, X_test, y_train, y_test = \
train_test_split(X, y,test_size=0.2,random_state=0)
# x: 測試資料向量, y: 測試資料正確答案
print(X_train[0])
print(y_train[0])
|
結果:
[6.4 3.1 5.5 1.8]
2
|
Step 4:訓練與建立模型,使用svm的SVC是給定參數,fit是用來進行訓練
#=== 建立模型 ===
#C為懲罰係數
#gamma參數決定支援向量的多寡
clf=svm.SVC(kernel='rbf',C=1,gamma='auto')
clf.fit(X_train,y_train)
|
Step 5:模型預測與自我回歸測試
#=== 預測 ===
clf.predict(X_test)
#=== 準確度分析 ===
print(clf.score(X_train,y_train))
print(clf.score(X_test, y_test))
|
結果:
0.975
1.0
|
0.975代表訓練資料集自我回歸測試是97.5%的準確度,測試資料集則是100%的準確度。
完整程式碼如下:
# reference: https://pyecontech.com/2020/04/11/python_svm/
from sklearn import svm
from sklearn import datasets
from sklearn.model_selection import train_test_split
#=== 下載練習資料 ===
#採用sklearn所提供的鳶尾花資料集
iris=datasets.load_iris()
X=iris.data
y=iris.target
print(X[0])
#=== 將樣本分成訓練集與測試集 ===
#test_size=0.2代表的就是測試集佔所有樣本的20%, random_state的方式來固定種子
X_train, X_test, y_train, y_test = \
train_test_split(X, y,test_size=0.2,random_state=0)
# x: 測試資料向量, y: 測試資料正確答案
print(X_train[0])
print(y_train[0])
#=== 建立模型 ===
#C為懲罰係數
#gamma參數決定支援向量的多寡
clf=svm.SVC(kernel='rbf',C=1,gamma='auto')
clf.fit(X_train,y_train)
#=== 預測 ===
clf.predict(X_test)
#=== 準確度分析 ===
print(clf.score(X_train,y_train))
print(clf.score(X_test, y_test))
|
留言列表