close

Python很吃版本,我使用的版本如下:

Python 3.7.4

sklearn 0.21.3

以下介紹使用scikit-learn使用SVM(Support vector machine)進行訓練與預測的方式,所使用的資料集是scikit-learn內建的,是鳶尾花資料集,共分三類。

 

鳶尾花的資料可以在wikipedia上找到,而且第一筆資料跟scikit-learn提供的一模一樣。

https://zh.wikipedia.org/zh-tw/%E5%AE%89%E5%BE%B7%E6%A3%AE%E9%B8%A2%E5%B0%BE%E8%8A%B1%E5%8D%89%E6%95%B0%E6%8D%AE%E9%9B%86

圖片1.png

 

Step 1: 首先,我們先匯入該使用的lib:

# reference: https://pyecontech.com/2020/04/11/python_svm/
from sklearn import svm
from sklearn import datasets
from sklearn.model_selection import train_test_split

 

Step 2: 將鳶尾花資料集叫出來使用,每筆資料的特徵向量放X,答案放y,這邊我把第一筆特徵向量印出來看看,檢查看看是不是跟wikipedia網站上寫的是一樣的:

#=== 下載練習資料 ===
#
採用sklearn所提供的鳶尾花資料集
iris=datasets.load_iris()
X=iris.data
y=iris.target

print(X[0])

結果:

[5.1 3.5 1.4 0.2]

可以往上看一下,真的是一模一樣。

 

Step 3:將樣本分成訓練跟測試集,使用train_test_split函式來切割訓練與測試集。

#=== 將樣本分成訓練集與測試集 ===
#test_size=0.2
代表的就是測試集佔所有樣本的20%, random_state參數設定為0的方式來固定種子,random_state可以給任意整數值。
X_train, X_test, y_train, y_test = \
    train_test_split(X, y,
test_size=0.2,random_state=0)
# x: 測試資料向量, y: 測試資料正確答案
print(X_train[0])
print(y_train[0])

結果:

[6.4 3.1 5.5 1.8]

2

 

Step 4:訓練與建立模型,使用svmSVC是給定參數,fit是用來進行訓練

#=== 建立模型 ===
#C
為懲罰係數
#gamma參數決定支援向量的多寡

clf=svm.SVC(kernel='rbf',C=1,gamma='auto')
clf.fit(X_train,y_train)

 

Step 5:模型預測與自我回歸測試

#=== 預測 ===
clf.predict(X_test)

#=== 準確度分析 ===
print(clf.score(X_train,y_train))
print(clf.score(X_test, y_test))

結果:

0.975

1.0

0.975代表訓練資料集自我回歸測試是97.5%的準確度,測試資料集則是100%的準確度。

 

完整程式碼如下:

# reference: https://pyecontech.com/2020/04/11/python_svm/
from sklearn import svm
from sklearn import datasets
from sklearn.model_selection import train_test_split

#=== 下載練習資料 ===
#
採用sklearn所提供的鳶尾花資料集
iris=datasets.load_iris()
X=iris.data
y=iris.target

print(X[0])

#=== 將樣本分成訓練集與測試集 ===
#test_size=0.2
代表的就是測試集佔所有樣本的20%, random_state的方式來固定種子
X_train, X_test, y_train, y_test = \
    train_test_split(X, y,
test_size=0.2,random_state=0)
# x: 測試資料向量, y: 測試資料正確答案
print(X_train[0])
print(y_train[0])
#=== 建立模型 ===
#C
為懲罰係數
#gamma參數決定支援向量的多寡

clf=svm.SVC(kernel='rbf',C=1,gamma='auto')
clf.fit(X_train,y_train)


#=== 預測 ===
clf.predict(X_test)

#=== 準確度分析 ===
print(clf.score(X_train,y_train))
print(clf.score(X_test, y_test))

 

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 葛瑞斯肯 的頭像
    葛瑞斯肯

    葛瑞斯肯樂活筆記

    葛瑞斯肯 發表在 痞客邦 留言(0) 人氣()