漸く scikit-learn での iris dataset の分類を試みる。今回は scikit-learn 1.0.2 を使った。
といってもよくある実装の通りで
from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier # データセットのロードと訓練データとテストデータへの分割 iris = datasets.load_iris() X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.3, random_state=1, stratify=y) # 決定木による分類器の作成と訓練 tree = DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=0) tree.fit(X_train, y_train) # 精度確認 print(tree.score(X_test, y_test))
0.9555555555555556
で終わりである。精度については以下のように混同行列とその対角線上の要素の和、つまりはトレースを明示的に使って以下のようにも計算できる:
from sklearn.metrics import confusion_matrix y_pred = tree.predict(X_test) conf_mat = confusion_matrix(y_test, y_pred) print(np.trace(conf_mat) / len(y_test))
決定境界を可視化すると以下のようになる。この辺はライブラリのバージョンに多少依るかもしれないし、criterion='gini'
や max_depth
の値でも変化するとは思う。
或はコードで書くと以下のようにもできる。
r = export_text(tree, feature_names=iris['feature_names'][2:]) print(r)
|--- petal width (cm) <= 0.75 | |--- class: 0 |--- petal width (cm) > 0.75 | |--- petal length (cm) <= 4.75 | | |--- class: 1 | |--- petal length (cm) > 4.75 | | |--- petal length (cm) <= 5.15 | | | |--- class: 2 | | |--- petal length (cm) > 5.15 | | | |--- class: 2
訓練データとしては長さ 105 のデータセットであるが、これを使う量を幾らか調整すると、1% だけ使う場合は以下のように問答無用ですべてを setosa として分類してしまった。概ね均一なデータセットなので、1/3 が正解するということで、汎化性能は 0.33 であった。
|--- class: 0
データセットの 5% を使うくらいの時は成長過程が見えて面白い。汎化性能は 0.67 であった。
|--- petal width (cm) <= 1.15 | |--- class: 0 |--- petal width (cm) > 1.15 | |--- class: 2
データの 30% を使う時がとても分かりやすい決定境界が出力され、勉強になる。汎化性能は 0.96 程度にまで到達し、これ以降は対して変化しない。max_depth
を深くすれば向上しそうだが、ニューラルネットで層を深くするのと同様、説明しにくさは向上すると思われる。
|--- petal length (cm) <= 4.75 | |--- petal length (cm) <= 2.60 | | |--- class: 0 | |--- petal length (cm) > 2.60 | | |--- class: 1 |--- petal length (cm) > 4.75 | |--- class: 2
データは 50% くらい使う時が汎化性能が高く出ていて、0.98 くらいまで出ていた。本当にそういう傾向があるのかは交差検証などで確認しないと分からないが、たまたまそういう感じであった。
|--- petal length (cm) <= 4.75 | |--- petal length (cm) <= 2.60 | | |--- class: 0 | |--- petal length (cm) > 2.60 | | |--- class: 1 |--- petal length (cm) > 4.75 | |--- petal width (cm) <= 1.55 | | |--- petal length (cm) <= 5.00 | | | |--- class: 1 | | |--- petal length (cm) > 5.00 | | | |--- class: 2 | |--- petal width (cm) > 1.55 | | |--- class: 2