1.6 データセット

Ivoryライブラリでは、学習するデータのセットをイテラブルとして実装します。おもちゃのデータを用意します。

import numpy as np

from ivory.common.dataset import Dataset

x_train = np.arange(0, 201).reshape(-1, 3)
t_train = np.arange(0, 201 // 3).reshape(-1, 1)
data = Dataset([x_train, t_train])
data

[1] 2019-08-30 08:17:01 (237ms) python3 (237ms)

Dataset(batch_size=1, epochs=1, len=67, column=0, size=(67,))

Datasetクラスのreprは、内部状態を表示します。バッチサイズを変えてみます。

data.batch_size = 4
data

[2] 2019-08-30 08:17:01 (24.5ms) python3 (262ms)

Dataset(batch_size=4, epochs=1, len=16, column=0, size=(67,))

データの長さ(=バッチの個数)が16に減りました。実際、

len(data)

[4] 2019-08-30 08:17:01 (15.6ms) python3 (309ms)

16

となります。Datasetクラスのインスタンスは通常のリストのようにインデクシングができます。タプルが返されますので、各々の変数に代入するにはアンパックします。

x, t = data[0]
x

[5] 2019-08-30 08:17:01 (15.7ms) python3 (324ms)

array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]])
t

[6] 2019-08-30 08:17:01 (15.6ms) python3 (340ms)

array([[0],
       [1],
       [2],
       [3]])

取り出したデータのshapeDatasetshapeと一致します。

data.shape, x.shape, t.shape

[7] 2019-08-30 08:17:01 (15.6ms) python3 (355ms)

(((4, 3), (4, 1)), (4, 3), (4, 1))

先ほどの例では、データは先頭から取り出されていました。random属性をTrueにするとランダムにデータを取り出せます。

data.random = True
x, t = data[0]
x

[8] 2019-08-30 08:17:01 (15.6ms) python3 (371ms)

array([[168, 169, 170],
       [ 72,  73,  74],
       [156, 157, 158],
       [ 60,  61,  62]])

通常は学習用のデータを「訓練データ」と「検証データ」のサブセットに分けます。(テストデータはまた別に用意するべきです。)split関数を使えばデータをサブセットに分割できます。以下の例では3対1の大きさで2分割しています。

data.split((3, 1))
data

[9] 2019-08-30 08:17:01 (15.7ms) python3 (387ms)

Dataset(batch_size=4, epochs=1, len=12, column=0, size=(50, 17))

size属性の要素数が変化したことが分かります。また、

len(data)

[10] 2019-08-30 08:17:01 (15.6ms) python3 (402ms)

12

となりました。どのサブセットからデータを取得するかは、column属性で指定します。

data.column = 1
data

[11] 2019-08-30 08:17:01 (15.6ms) python3 (418ms)

Dataset(batch_size=4, epochs=1, len=4, column=1, size=(50, 17))

取り出せるデータ数(lenの値)が変わりました。実際に取り出してみます。

x, t = data[0]
x

[12] 2019-08-30 08:17:01 (15.6ms) python3 (434ms)

array([[174, 175, 176],
       [180, 181, 182],
       [165, 166, 167],
       [195, 196, 197]])

多次元配列のように扱うことができます。第1要素がサブセット番号、第2要素がインデックスです。

x, t = data[0, 0]
print(t)
x, t = data[1, 0]
print(t)

[13] 2019-08-30 08:17:01 (15.6ms) python3 (449ms)

[[46]
 [14]
 [32]
 [41]]
[[66]
 [65]
 [65]
 [64]]

columnが1のとき、元データの後半部分からデータを取得していることが分かります。サブセット間でデータを混ぜるには、shuffle関数を使います。

data.shuffle()
data.batch_size = 15
x, t = data[0, 0]
print(t.reshape(-1))
x, t = data[1, 0]
print(t.reshape(-1))

[14] 2019-08-30 08:17:01 (15.7ms) python3 (465ms)

[54  8  5 24 37 24 49 63  0  7 45 43  8 35 11]
[25 48  1 25 34 50 25 53 50 53  1 14  1 44 39]

Datasetはforループで使うことができます。

data.batch_size = 4
for k, (x, t) in enumerate(data):
    print(f"#{k}", t.reshape(-1))

[15] 2019-08-30 08:17:01 (15.6ms) python3 (480ms)

#0 [50 64 39 34]
#1 [39 14 39 44]
#2 [53 39 39  6]
#3 [48 20 25 44]

epochsを指定してイタレーション回数をコントロールできます。epoch属性、index属性、iteration属性がDatasetのイタレーション状態を保持します。epoch属性は、エポックの区切り以外では-1となります。また、state属性はこれらをまとめてタプル値を返します。

data.epochs = 2
for _ in data:
    print(data.epoch, data.index, data.iteration)

[16] 2019-08-30 08:17:01 (15.6ms) python3 (496ms)

0 0 0
0 1 1
0 2 2
0 3 3
1 0 4
1 1 5
1 2 6
1 3 7

epochsに-1を入力すると無限にループできます。

data.epochs = -1
for k, _ in enumerate(zip(range(1234), data)):
    pass
print(k, data.state)

[17] 2019-08-30 08:17:01 (47.1ms) python3 (543ms)

1233 (308, 1, 1233)

イテレータとして使うこともできます。分かりやすくするため、もう一度おもちゃのデータを作成します。

x_train = np.arange(0, 201).reshape(-1, 3)
t_train = np.arange(0, 201 // 3).reshape(-1, 1)
data = Dataset([x_train, t_train])
data.split((2, 3, 4))
data.batch_size = 4
data

[18] 2019-08-30 08:17:01 (15.4ms) python3 (559ms)

Dataset(batch_size=4, epochs=1, len=3, column=0, size=(14, 22, 31))
it = iter(data)
x, t = next(it)
print(t.reshape(-1))
x, t = next(it)
print(t.reshape(-1))
data.column = 2
it = iter(data)
x, t = next(it)
print(t.reshape(-1))

[19] 2019-08-30 08:17:01 (15.6ms) python3 (574ms)

[0 1 2 3]
[4 5 6 7]
[36 37 38 39]

Datasetはスライス表記をサポートしています。

x, t = data[2:4]
print(len(x))
x, t = data[:]
print(len(x))

[20] 2019-08-30 08:17:01 (15.8ms) python3 (590ms)

8
31

データの一部だけを使いたいとき、length属性で大きさを制限できます。

data.length = 20
data

[21] 2019-08-30 08:17:01 (15.5ms) python3 (605ms)

Dataset(batch_size=4, epochs=1, len=2, column=2, size=(4, 6, 10))

スライス表記の結果を見てみます。

x, t = data[:]
len(x)

[22] 2019-08-30 08:17:01 (15.6ms) python3 (621ms)

10

lengthの値を-1にすると、元々の全データを使う状態に戻ります。

data.length = -1
print(data.size)
data

[23] 2019-08-30 08:17:01 (15.6ms) python3 (637ms)

(13, 20, 34)
Dataset(batch_size=4, epochs=1, len=8, column=2, size=(13, 20, 34))
x, t = data[:]
len(x)

[24] 2019-08-30 08:17:01 (15.6ms) python3 (652ms)

34