1.7 トレーナー

「ゼロから作るDeep Learning」では、学習の効率化のためにTrainerクラスを導入しています。IvoryライブラリでもTrainerクラスを実装し、少ないコードで学習ができるようにしています。

データセットを用意します。

from ivory.datasets.mnist import load_dataset

data = load_dataset(train_only=True)
data.length = 1000  # データセットの大きさを制限します。
data.batch_size = 100
data.epochs = 20
data.random = True
data

[1] 2019-06-12 20:00:33 (2.00s) python3 (2.00s)

mnist_train(batch_size=100, epochs=20, len=10, column=0, size=(1000,))

Trainerクラスはレイヤ表現を引数にとるivory.common.trainerモジュールのsequential関数で作成できます。

from ivory.core.trainer import sequential

net = [("input", 784), ("affine", 50, "relu"), ("affine", 10, "softmax_cross_entropy")]
trainer = sequential(net)
print(trainer.model)
print(trainer.optimizer)

[2] 2019-06-12 20:00:35 (31.3ms) python3 (2.03s)

<ivory.core.model.Model object at 0x00000252D96F4DA0>
SGD(learning_rate=0.01, name='SGD')

Trainerインスタンスのレイヤパラメータを初期化するには、initメソッドを呼び出します。オプショナルのstdキーワード引数を指定すると、標準偏差を設定できます。

from ivory.common.context import np

W = trainer.model.weight_variables[0]
print(W.data.std(), np.sqrt(2 / W.shape[0]))
trainer.init(std=100)
print(W.data.std())
trainer.init(std="he")
print(W.data.std())

[3] 2019-06-12 20:00:35 (229ms) python3 (2.26s)

0.05050477 0.050507627227610534
100.32395
0.050630372

レイヤの状態変数は、名前を引数に渡すことで値を設定できます。キーワードを指定しなければ、初期値に戻ります。

wd = trainer.model.state_variables[0]
print(wd.data)
trainer.init(weight_decay=1)
print(wd.data)
trainer.init()
print(wd.data)

[4] 2019-06-12 20:00:36 (9.04ms) python3 (2.27s)

0
1
0

実際に学習してみます。initメソッドは自分自身を返すので、呼び出しをチェインできます。fitメソッドも同様に訓練データの設定をした後、自分自身を返します。

trainer = sequential(net, metrics=["accuracy"]).init(std=0.1)
trainer = trainer.fit(data, epoch_data=data[:])
trainer

[5] 2019-06-12 20:00:36 (14.0ms) python3 (2.28s)

Trainer(inputs=[(784,), ()], optimizer='SGD', metrics=['accuracy'])

実際の訓練はイタレータを作って行います。

it = iter(trainer)
print(next(it))
print(next(it))
print(next(it))

[6] 2019-06-12 20:00:36 (58.2ms) python3 (2.34s)

(0, 0.146)
(1, 0.178)
(2, 0.214)

to_frameメソッドは訓練を行った後に結果をデータフレームで返します。

df = trainer.to_frame()
df.tail()

[7] 2019-06-12 20:00:36 (315ms) python3 (2.66s)

epoch accuracy
16 16 0.619
17 17 0.633
18 18 0.649
19 19 0.653
20 20 0.674