sickit-learn データセットの使い方

scikit-learnは機械学習、データ分析には必須のライブラリです。ここではデフォルトでscikit-learnに付随されているデータセットの使い方をメモしておきます。

sickit-learn 目次

  1. 公式データセット <= 本節
  2. データの作成
  3. 線形回帰
  4. ロジスティック回帰

github

  • jupyter notebook形式のファイルはこちら

google colaboratory

  • google colaboratory で実行する場合はこちら

環境

筆者のOSはmacOSです。LinuxやUnixのコマンドとはオプションが異なります。

筆者の環境

!sw_vers
ProductName:	Mac OS X
ProductVersion:	10.14.6
BuildVersion:	18G2022
!python -V
Python 3.7.3
import sklearn

sklearn.__version__
'0.20.3'

データ表示用にpandasもimportしておきます。

import pandas as pd

pd.__version__
'1.0.3'

画像表示用にmatplotlibもimportします。画像はwebでの見栄えを考慮して、svgで保存する事とします。

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib.pyplot as plt

概要

scikit-learnは機械学習に必要なデータセットを用意してくれています。ここでは公式サイトにそってサンプルデータの概要を説明します。

  1. toy dataset
  2. 実際のデータセット

toy datasets

toyというのは、おそらく簡易的なデータで、実際の機械学習のモデル生成には不十分な量という意味だと思います。

boston住宅価格のデータ

  • target: 住宅価格
  • 回帰問題
from sklearn.datasets import load_boston

boston = load_boston()

最初なので少し丁寧にデータを見ていきます。

type(boston)
sklearn.utils.Bunch

データタイプはsklearn.utils.Bunch型だとわかります。

dir(boston)
['DESCR', 'data', 'feature_names', 'filename', 'target']

DESCR, data, feature_names, filename, targetのプロパティを持つ事がわかります 一つ一つの属性値を見ていきます。DESCRは、データに関する説明、filenameはデータのファイルの絶対パスなので省略します。

boston.data

実際に格納されているデータです。分析対象とする各特徴量が格納されています。説明変数とも言うようです。

boston.data
array([[6.3200e-03, 1.8000e+01, 2.3100e+00, ..., 1.5300e+01, 3.9690e+02,
        4.9800e+00],
       [2.7310e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9690e+02,
        9.1400e+00],
       [2.7290e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9283e+02,
        4.0300e+00],
       ...,
       [6.0760e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
        5.6400e+00],
       [1.0959e-01, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9345e+02,
        6.4800e+00],
       [4.7410e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
        7.8800e+00]])

boston.feature_names

各特徴量の名前です。

boston.feature_names
array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
       'TAX', 'PTRATIO', 'B', 'LSTAT'], dtype='<U7')

boston.target

予測するターゲットの値です。公式サイトによるとbostonの場合は価格の中央値(Median Value)となります。

boston.target
array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
       18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
       15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
       13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
       21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
       35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
       19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
       20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
       23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
       33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
       21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
       20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
       23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
       15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21.5, 19.6, 15.3, 19.4,
       17. , 15.6, 13.1, 41.3, 24.3, 23.3, 27. , 50. , 50. , 50. , 22.7,
       25. , 50. , 23.8, 23.8, 22.3, 17.4, 19.1, 23.1, 23.6, 22.6, 29.4,
       23.2, 24.6, 29.9, 37.2, 39.8, 36.2, 37.9, 32.5, 26.4, 29.6, 50. ,
       32. , 29.8, 34.9, 37. , 30.5, 36.4, 31.1, 29.1, 50. , 33.3, 30.3,
       34.6, 34.9, 32.9, 24.1, 42.3, 48.5, 50. , 22.6, 24.4, 22.5, 24.4,
       20. , 21.7, 19.3, 22.4, 28.1, 23.7, 25. , 23.3, 28.7, 21.5, 23. ,
       26.7, 21.7, 27.5, 30.1, 44.8, 50. , 37.6, 31.6, 46.7, 31.5, 24.3,
       31.7, 41.7, 48.3, 29. , 24. , 25.1, 31.5, 23.7, 23.3, 22. , 20.1,
       22.2, 23.7, 17.6, 18.5, 24.3, 20.5, 24.5, 26.2, 24.4, 24.8, 29.6,
       42.8, 21.9, 20.9, 44. , 50. , 36. , 30.1, 33.8, 43.1, 48.8, 31. ,
       36.5, 22.8, 30.7, 50. , 43.5, 20.7, 21.1, 25.2, 24.4, 35.2, 32.4,
       32. , 33.2, 33.1, 29.1, 35.1, 45.4, 35.4, 46. , 50. , 32.2, 22. ,
       20.1, 23.2, 22.3, 24.8, 28.5, 37.3, 27.9, 23.9, 21.7, 28.6, 27.1,
       20.3, 22.5, 29. , 24.8, 22. , 26.4, 33.1, 36.1, 28.4, 33.4, 28.2,
       22.8, 20.3, 16.1, 22.1, 19.4, 21.6, 23.8, 16.2, 17.8, 19.8, 23.1,
       21. , 23.8, 23.1, 20.4, 18.5, 25. , 24.6, 23. , 22.2, 19.3, 22.6,
       19.8, 17.1, 19.4, 22.2, 20.7, 21.1, 19.5, 18.5, 20.6, 19. , 18.7,
       32.7, 16.5, 23.9, 31.2, 17.5, 17.2, 23.1, 24.5, 26.6, 22.9, 24.1,
       18.6, 30.1, 18.2, 20.6, 17.8, 21.7, 22.7, 22.6, 25. , 19.9, 20.8,
       16.8, 21.9, 27.5, 21.9, 23.1, 50. , 50. , 50. , 50. , 50. , 13.8,
       13.8, 15. , 13.9, 13.3, 13.1, 10.2, 10.4, 10.9, 11.3, 12.3,  8.8,
        7.2, 10.5,  7.4, 10.2, 11.5, 15.1, 23.2,  9.7, 13.8, 12.7, 13.1,
       12.5,  8.5,  5. ,  6.3,  5.6,  7.2, 12.1,  8.3,  8.5,  5. , 11.9,
       27.9, 17.2, 27.5, 15. , 17.2, 17.9, 16.3,  7. ,  7.2,  7.5, 10.4,
        8.8,  8.4, 16.7, 14.2, 20.8, 13.4, 11.7,  8.3, 10.2, 10.9, 11. ,
        9.5, 14.5, 14.1, 16.1, 14.3, 11.7, 13.4,  9.6,  8.7,  8.4, 12.8,
       10.5, 17.1, 18.4, 15.4, 10.8, 11.8, 14.9, 12.6, 14.1, 13. , 13.4,
       15.2, 16.1, 17.8, 14.9, 14.1, 12.7, 13.5, 14.9, 20. , 16.4, 17.7,
       19.5, 20.2, 21.4, 19.9, 19. , 19.1, 19.1, 20.1, 19.9, 19.6, 23.2,
       29.8, 13.8, 13.3, 16.7, 12. , 14.6, 21.4, 23. , 23.7, 25. , 21.8,
       20.6, 21.2, 19.1, 20.6, 15.2,  7. ,  8.1, 13.6, 20.1, 21.8, 24.5,
       23.1, 19.7, 18.3, 21.2, 17.5, 16.8, 22.4, 20.6, 23.9, 22. , 11.9])

pandasで読み込みます。

df = pd.DataFrame(data=boston.data, columns=boston.feature_names)
df['MV'] = pd.DataFrame(data=boston.target)

df.shape
(506, 14)
df.head()

CRIMZNINDUSCHASNOXRMAGEDISRADTAXPTRATIOBLSTATMV
00.0063218.02.310.00.5386.57565.24.09001.0296.015.3396.904.9824.0
10.027310.07.070.00.4696.42178.94.96712.0242.017.8396.909.1421.6
20.027290.07.070.00.4697.18561.14.96712.0242.017.8392.834.0334.7
30.032370.02.180.00.4586.99845.86.06223.0222.018.7394.632.9433.4
40.069050.02.180.00.4587.14754.26.06223.0222.018.7396.905.3336.2

となり、データ数が506個である事がわかります。また、各特徴量の統計量は以下の通りです。

df.describe()

CRIMZNINDUSCHASNOXRMAGEDISRADTAXPTRATIOBLSTATMV
count506.000000506.000000506.000000506.000000506.000000506.000000506.000000506.000000506.000000506.000000506.000000506.000000506.000000506.000000
mean3.61352411.36363611.1367790.0691700.5546956.28463468.5749013.7950439.549407408.23715418.455534356.67403212.65306322.532806
std8.60154523.3224536.8603530.2539940.1158780.70261728.1488612.1057108.707259168.5371162.16494691.2948647.1410629.197104
min0.0063200.0000000.4600000.0000000.3850003.5610002.9000001.1296001.000000187.00000012.6000000.3200001.7300005.000000
25%0.0820450.0000005.1900000.0000000.4490005.88550045.0250002.1001754.000000279.00000017.400000375.3775006.95000017.025000
50%0.2565100.0000009.6900000.0000000.5380006.20850077.5000003.2074505.000000330.00000019.050000391.44000011.36000021.200000
75%3.67708312.50000018.1000000.0000000.6240006.62350094.0750005.18842524.000000666.00000020.200000396.22500016.95500025.000000
max88.976200100.00000027.7400001.0000000.8710008.780000100.00000012.12650024.000000711.00000022.000000396.90000037.97000050.000000

アヤメのデータ

  • target: アヤメの種類
  • 分類問題
from sklearn.datasets import load_iris

iris = load_iris()
print(type(iris))
print(dir(iris))
<class 'sklearn.utils.Bunch'>
['DESCR', 'data', 'feature_names', 'filename', 'target', 'target_names']
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['IRIS'] = pd.DataFrame(data=iris.target)
df.shape
(150, 5)
df.head()

sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)IRIS
05.13.51.40.20
14.93.01.40.20
24.73.21.30.20
34.63.11.50.20
45.03.61.40.20

最初の5個のデータだと0しかないので、ランダムサンプリングしてみると以下のようになります。

df.sample(frac=1, random_state=0).reset_index().head()

indexsepal length (cm)sepal width (cm)petal length (cm)petal width (cm)IRIS
01145.82.85.12.42
1626.02.24.01.01
2335.54.21.40.20
31077.32.96.31.82
475.03.41.50.20

各特徴量は以下の通りです。日本語に訳しましたが、あまりぴんと来ませんね。

英語名日本名
sepal lengthがく片の長さ
sepal widthがく片の幅
petal length花びらの長さ
petal width花びらの幅

また、IRISというターゲットの値は0,1,2となっており、それらはiris.target_namesで確認する事ができます。

iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

このリストのインデックスと対応しており、表にすると以下の様になります。

indexIRIS
0setosa
1versicolor
2virginica

糖尿病患者のデータ

  • target: 基準時から糖尿病の状態
  • 回帰問題
from sklearn.datasets import load_diabetes

diabetes = load_diabetes()
print(type(diabetes))
print(dir(diabetes))
print(diabetes.feature_names)
print(diabetes.data.shape)
<class 'sklearn.utils.Bunch'>
['DESCR', 'data', 'data_filename', 'feature_names', 'target', 'target_filename']
['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']
(442, 10)
df = pd.DataFrame(data=diabetes.data, columns=diabetes.feature_names)
df['QM'] = diabetes.target # QM : quantitative measure
df.head()

agesexbmibps1s2s3s4s5s6QM
00.0380760.0506800.0616960.021872-0.044223-0.034821-0.043401-0.0025920.019908-0.017646151.0
1-0.001882-0.044642-0.051474-0.026328-0.008449-0.0191630.074412-0.039493-0.068330-0.09220475.0
20.0852990.0506800.044451-0.005671-0.045599-0.034194-0.032356-0.0025920.002864-0.025930141.0
3-0.089063-0.044642-0.011595-0.0366560.0121910.024991-0.0360380.0343090.022692-0.009362206.0
40.005383-0.044642-0.0363850.0218720.0039350.0155960.008142-0.002592-0.031991-0.046641135.0

手書きデータ

  • target:0~9までの数字
  • 分類問題

データはdigits.imagesdigits.dataの中に入っていますが、imagesは二次元配列でdataは8x8の一次元配列で格納されています。

from sklearn.datasets import load_digits

digits = load_digits()

print(type(digits))
print(dir(digits))
print(digits.data.shape)
print(digits.images.shape)
print(digits.target_names)
<class 'sklearn.utils.Bunch'>
['DESCR', 'data', 'images', 'target', 'target_names']
(1797, 64)
(1797, 8, 8)
[0 1 2 3 4 5 6 7 8 9]

一番最初に格納されているデータは以下の様になっています。

print(digits.images[0])
[[ 0.  0.  5. 13.  9.  1.  0.  0.]
 [ 0.  0. 13. 15. 10. 15.  5.  0.]
 [ 0.  3. 15.  2.  0. 11.  8.  0.]
 [ 0.  4. 12.  0.  0.  8.  8.  0.]
 [ 0.  5.  8.  0.  0.  9.  8.  0.]
 [ 0.  4. 11.  0.  1. 12.  7.  0.]
 [ 0.  2. 14.  5. 10. 12.  0.  0.]
 [ 0.  0.  6. 13. 10.  0.  0.  0.]]

digits.images[0]を画像化してみます。

plt.imshow(digits.images[0], cmap='gray')
plt.grid(False)
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x125974fd0>

何となく0に見えますね。色合いを変えてみます。

plt.imshow(digits.images[0])
plt.grid(False)
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x12594af28>

グレースケールより見やすいでしょうか?変わらないですかね・・・ もちろん、これらのデータに対して、正解のデータが与えられています。

print(digits.target[0])
0

生理学的データと運動能力のデータ

  • target: 生理学的データ(体重、ウェスト、脈拍) (日本語訳の不正確かもしれません)
  • 回帰問題

運動能力から体重やウェストなどの身体的特徴を求める問題

from sklearn.datasets import load_linnerud

linnerud = load_linnerud()

print(type(linnerud))
print(dir(linnerud))
<class 'sklearn.utils.Bunch'>
['DESCR', 'data', 'data_filename', 'feature_names', 'target', 'target_filename', 'target_names']
df1 = pd.DataFrame(data=linnerud.data, columns=linnerud.feature_names)
df2 = pd.DataFrame(data=linnerud.target, columns=linnerud.target_names)
df1.head()

ChinsSitupsJumps
05.0162.060.0
12.0110.060.0
212.0101.0101.0
312.0105.037.0
413.0155.058.0
df2.head()

WeightWaistPulse
0191.036.050.0
1189.037.052.0
2193.038.058.0
3162.035.062.0
4189.035.046.0

ワインのデータ

  • target: ワインの種類
  • 分類問題
from sklearn.datasets import load_wine

wine = load_wine()

print(type(wine))
print(dir(wine))
print(wine.feature_names)

print(wine.target)
print(wine.target_names)
<class 'sklearn.utils.Bunch'>
['DESCR', 'data', 'feature_names', 'target', 'target_names']
['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline']
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
['class_0' 'class_1' 'class_2']

pandasで読み込んでみます。ターゲットの名前をWINEとして、wine.targetを追加します。

df = pd.DataFrame(data=wine.data, columns=wine.feature_names)
df['WINE'] = pd.DataFrame(data=wine.target)
df.head()

alcoholmalic_acidashalcalinity_of_ashmagnesiumtotal_phenolsflavanoidsnonflavanoid_phenolsproanthocyaninscolor_intensityhueod280/od315_of_diluted_winesprolineWINE
014.231.712.4315.6127.02.803.060.282.295.641.043.921065.00
113.201.782.1411.2100.02.652.760.261.284.381.053.401050.00
213.162.362.6718.6101.02.803.240.302.815.681.033.171185.00
314.371.952.5016.8113.03.853.490.242.187.800.863.451480.00
413.242.592.8721.0118.02.802.690.391.824.321.042.93735.00

先頭から5個のサンプリングだとWINEの列がすべて0になってしまったので、ランダムサンプリングしてみます。

df.sample(frac=1, random_state=0).reset_index().head()

indexalcoholmalic_acidashalcalinity_of_ashmagnesiumtotal_phenolsflavanoidsnonflavanoid_phenolsproanthocyaninscolor_intensityhueod280/od315_of_diluted_winesprolineWINE
05413.741.672.2516.4118.02.602.900.211.625.850.923.201060.00
115112.792.672.4822.0112.01.481.360.241.2610.800.481.47480.02
26312.371.132.1619.087.03.503.100.191.874.451.222.87420.01
35513.561.732.4620.5116.02.962.780.202.456.250.983.031120.00
412313.055.802.1321.586.02.622.650.302.012.600.733.10380.01

乳がんのデータ

  • target: がんの良性/悪性
  • 分類問題
from sklearn.datasets import load_breast_cancer

bc = load_breast_cancer()

print(type(bc))
print(dir(bc))
print(bc.feature_names)
print(bc.target_names)
<class 'sklearn.utils.Bunch'>
['DESCR', 'data', 'feature_names', 'filename', 'target', 'target_names']
['mean radius' 'mean texture' 'mean perimeter' 'mean area'
 'mean smoothness' 'mean compactness' 'mean concavity'
 'mean concave points' 'mean symmetry' 'mean fractal dimension'
 'radius error' 'texture error' 'perimeter error' 'area error'
 'smoothness error' 'compactness error' 'concavity error'
 'concave points error' 'symmetry error' 'fractal dimension error'
 'worst radius' 'worst texture' 'worst perimeter' 'worst area'
 'worst smoothness' 'worst compactness' 'worst concavity'
 'worst concave points' 'worst symmetry' 'worst fractal dimension']
['malignant' 'benign']

属性がかなり多いです。悪性が良性かの分類問題です。pandasで読み込んでみます。

df = pd.DataFrame(data=bc.data, columns=bc.feature_names)
df['MorB'] = pd.DataFrame(data=bc.target) # MorB means maligant or benign
df.head()

mean radiusmean texturemean perimetermean areamean smoothnessmean compactnessmean concavitymean concave pointsmean symmetrymean fractal dimension...worst textureworst perimeterworst areaworst smoothnessworst compactnessworst concavityworst concave pointsworst symmetryworst fractal dimensionMorB
017.9910.38122.801001.00.118400.277600.30010.147100.24190.07871...17.33184.602019.00.16220.66560.71190.26540.46010.118900
120.5717.77132.901326.00.084740.078640.08690.070170.18120.05667...23.41158.801956.00.12380.18660.24160.18600.27500.089020
219.6921.25130.001203.00.109600.159900.19740.127900.20690.05999...25.53152.501709.00.14440.42450.45040.24300.36130.087580
311.4220.3877.58386.10.142500.283900.24140.105200.25970.09744...26.5098.87567.70.20980.86630.68690.25750.66380.173000
420.2914.34135.101297.00.100300.132800.19800.104300.18090.05883...16.67152.201575.00.13740.20500.40000.16250.23640.076780

5 rows × 31 columns

ランダムサンプリングしてみます。

df.sample(frac=1, random_state=0).reset_index().head()

indexmean radiusmean texturemean perimetermean areamean smoothnessmean compactnessmean concavitymean concave pointsmean symmetry...worst textureworst perimeterworst areaworst smoothnessworst compactnessworst concavityworst concave pointsworst symmetryworst fractal dimensionMorB
051213.4020.5288.64556.70.110600.146900.144500.081720.2116...29.66113.30844.40.157400.385600.510600.205100.35850.110900
145713.2125.2584.10537.90.087910.052050.027720.020680.1619...34.2391.29632.90.128900.106300.139000.060050.24440.067881
243914.0215.6689.59606.50.079660.055810.020870.026520.1589...19.3196.53688.90.103400.101700.062600.082160.21360.067101
329814.2618.1791.22633.10.065760.052200.024750.013740.1635...25.26105.80819.70.094450.216700.156500.075300.26360.076761
43713.0318.4282.61523.80.089830.037660.025620.029230.1467...22.8184.46545.90.097010.046190.048330.050130.19870.061691

5 rows × 32 columns

良性の陰性の結果がMorBに見て取れます。

参考資料