5.7 誤差逆伝搬法の実装

「ゼロから作るDeep Learning」 5章7節の内容に合わせて、Ivoryライブラリを使って実際に学習を行ってみます。

5.7.1 訓練データの取得

MNIST手書き文字の訓練データは、Datasetとして用意されています。

from ivory.datasets.mnist import load_dataset

data_train, data_test = load_dataset()
data_train.random = True
data_train

[1] 2019-06-12 17:18:26 (1.89s) python3 (1.89s)

mnist_train(batch_size=1, epochs=1, len=60000, column=0, size=(60000,))

5.7.2 誤差逆伝搬法に対応したニューラルネットワークの実装

from ivory.core.model import sequential

net = [
    ("input", 784),
    ("affine", 50, "relu"),
    ("affine", 10, "softmax_cross_entropy"),
]
model = sequential(net)
model

[2] 2019-06-12 17:18:28 (31.2ms) python3 (1.92s)

<ivory.core.model.Model at 0x25f32356b70>

5.7.3 誤差逆伝搬法を使った学習

最後に学習の実装です。「ゼロから作るDeep Learning」5章7節のコードを参考にします。

iters_num = 10000
train_size = data_train.data[0].shape[0]
data_train.batch_size = 100
iter_per_epoch = train_size // data_train.batch_size
data_train.epochs = -1
learning_rate = 0.1

for i, (x, t) in zip(range(iters_num), data_train):
    model.set_data(x, t)
    model.forward()
    model.backward()

    for variable in model.weight_variables:
        variable.data -= learning_rate * variable.grad

    if i % iter_per_epoch == 0:
        x, t = data_train[:]
        model.set_data(x, t)
        model.forward()
        train_acc = model.accuracy
        x, t = data_test[:]
        model.set_data(x, t)
        model.forward()
        test_acc = model.accuracy
        print(f"{train_acc:.3f}", f"{test_acc:.3f}")

[3] 2019-06-12 17:18:28 (13.2s) python3 (15.1s)

0.131 0.128
0.919 0.920
0.936 0.937
0.948 0.946
0.955 0.952
0.958 0.954
0.963 0.959
0.962 0.957
0.969 0.965
0.970 0.964
0.972 0.965
0.974 0.966
0.975 0.966
0.976 0.965
0.978 0.969
0.979 0.968
0.981 0.971

「ゼロから作るDeep Learning」レポジトリのコードを実行してみます。

from ivory.utils.repository import run

run("scratch/ch05/train_neuralnet.py")

[4] 2019-06-12 17:18:41 (25.6s) python3 (40.7s)

0.10726666666666666 0.1018
0.9051833333333333 0.9088
0.9274833333333333 0.9281
0.9384333333333333 0.9378
0.9466833333333333 0.9437
0.9536 0.9516
0.9579833333333333 0.9544
0.9616 0.9578
0.9643833333333334 0.9605
0.9674 0.9624
0.9702833333333334 0.9646
0.9710166666666666 0.9653
0.9728666666666667 0.9673
0.9728833333333333 0.9676
0.9763 0.9692
0.9778833333333333 0.9701
0.97925 0.9703

実行速度の違いは、Ivoryライブラリでのパラメータのビット精度32ビットであるためです。