6.5 Hyperparameter Validation

6.5.3 Implementing Hyperparameter Optimization

Following ch06/hyperparameter_optimization.py, we optimize the hyperparameters.

First, define the task. The code below uses 500 MNIST samples, split 8:2 into training and validation data.

from ivory.core.trainer import sequential
from ivory.datasets.mnist import load_dataset

data = load_dataset(train_only=True)
data.length = 500
data.shuffle()
data.split((8, 2))
data.batch_size = 100
data.epochs = 50
epoch_data = {"train": data[0, :], "val": data[1, :]}

net = [
    ("input", 784),
    (6, "affine", 100, "relu"),
    ("affine", 10, "softmax_cross_entropy"),
]

trainer = sequential(net, metrics=["accuracy"])
trainer = trainer.fit(data, epoch_data=epoch_data)
trainer

[1] 2019-06-12 17:37:44 (2.08s) python3 (2.08s)

Trainer(inputs=[(784,), ()], optimizer='SGD', metrics=['accuracy'])

Next, define a generator that performs a random search over the hyperparameters. The search ranges below are slightly modified from those in 「ゼロから作るDeep Learning」.

import numpy as np

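def random_search():
    while True:
        # Sample both hyperparameters uniformly on a log10 scale
        weight_decay = 10 ** np.random.uniform(-10, -2)
        learning_rate = 10 ** np.random.uniform(-4, -1)
        # Apply the sampled values to the trainer defined above
        trainer.optimizer.learning_rate = learning_rate
        trainer.init(weight_decay=weight_decay)
        # Per-epoch metrics as a DataFrame (columns: epoch, data, accuracy)
        df = trainer.to_frame()
        columns = list(df.columns)
        # Tag every row with the hyperparameters of this trial
        df["wd"] = weight_decay
        df["lr"] = learning_rate
        yield df[["wd", "lr"] + columns]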

[2] 2019-06-12 17:37:46 (31.3ms) python3 (2.11s)
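
The expression `10 ** np.random.uniform(a, b)` samples on a log scale, so every power of ten between `10**a` and `10**b` is equally likely to be visited. The following is a minimal sketch of that idea in plain NumPy, independent of the trainer. The helper name `sample_log_uniform` and the fixed seed are assumptions made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed: an assumption, for reproducibility only


def sample_log_uniform(low_exp, high_exp, size, rng):
    """Draw values whose base-10 exponent is uniform on [low_exp, high_exp)."""
    return 10.0 ** rng.uniform(low_exp, high_exp, size=size)


# Same exponent range as the weight decay above: 1e-10 up to 1e-2
samples = sample_log_uniform(-10, -2, 10_000, rng)

# Fraction of samples per decade; each of the 8 decades should hold about 1/8
counts, _ = np.histogram(np.log10(samples), bins=np.arange(-10, -1))
fractions = counts / samples.size
print(fractions)
```

Sampling the raw value uniformly instead would place almost every draw in the top decade, which is why the exponent is sampled rather than the value itself.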

Let's try it out.

searcher = random_search()
print(next(searcher).tail(4))
print(next(searcher).tail(4))
print(next(searcher).tail(4))

[3] 2019-06-12 17:37:46 (3.31s) python3 (5.42s)

          wd        lr  epoch   data  accuracy
98   0.00053  0.019038  49     train  0.985   
99   0.00053  0.019038  49     val    0.790   
100  0.00053  0.019038  50     train  0.985   
101  0.00053  0.019038  50     val    0.790   
           wd        lr  epoch   data  accuracy
98   0.001066  0.011962  49     train  0.9200  
99   0.001066  0.011962  49     val    0.7100  
100  0.001066  0.011962  50     train  0.9225  
101  0.001066  0.011962  50     val    0.7200  
           wd        lr  epoch   data  accuracy
98   0.003083  0.003186  49     train  0.4725  
99   0.003083  0.003186  49     val    0.3400  
100  0.003083  0.003186  50     train  0.4725  
101  0.003083  0.003186  50     val    0.3500

Now run 100 trials.

import pandas as pd

df = pd.concat([df for _, df in zip(range(100), searcher)])
len(df)

[4] 2019-06-12 17:37:49 (1min54s) python3 (1min60s)

10200
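
The row labels in the earlier output run up to 101, so each trial contributes 102 rows, and 100 trials give 100 × 102 = 10200 rows. Drawing a fixed number of items from an endless generator can also be written with `itertools.islice`. The sketch below uses a hypothetical stand-in generator, `toy_search`, rather than the real `random_search`, so that it runs without the trainer.

```python
from itertools import islice

import pandas as pd


def toy_search(rows_per_trial=102):
    """Hypothetical stand-in for random_search(): yields one frame per trial, forever."""
    trial = 0
    while True:
        yield pd.DataFrame({"trial": trial, "row": list(range(rows_per_trial))})
        trial += 1


n_trials = 100
# islice stops after exactly n_trials items without exhausting the generator
frames = list(islice(toy_search(), n_trials))
toy_df = pd.concat(frames, ignore_index=True)
print(len(toy_df))  # 10200
```

The argument order in `zip(range(100), searcher)` matters: `zip` asks `range` first and stops as soon as it is exhausted, so no extra trial is run. With the arguments swapped, one additional trial would be computed and then thrown away.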

Rank the trials by their best validation accuracy and visualize the top five.

acc = df.query("data == 'val'").groupby(["wd", "lr"])["accuracy"].max()
best = acc.sort_values(ascending=False).to_frame().reset_index()
best.index.name = "best"
best = best.reset_index()
best["best"] += 1
best[:5]

[5] 2019-06-12 17:39:44 (31.2ms) python3 (1min60s)

   best            wd        lr  accuracy
0     1  1.134661e-05  0.094884      0.86
1     2  1.448067e-05  0.091481      0.86
2     3  5.933917e-09  0.076393      0.86
3     4  3.368884e-05  0.082893      0.85
4     5  2.266553e-08  0.052939      0.85

import altair as alt
df_best = pd.merge(df, best[["best", "wd", "lr"]][:5])
alt.Chart(df_best).mark_line().encode(
    x="epoch", y="accuracy", color="data", column="best"
).properties(width=100, height=120)

[6] 2019-06-12 17:39:44 (93.7ms) python3 (1min60s)
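
The ranking and the merge above can be followed more easily on a small table. The sketch below reproduces the same steps on three hypothetical trials. All values are invented for illustration and are not results from the runs above.

```python
import pandas as pd

# Toy results for three hypothetical trials, two epochs each (invented values)
toy = pd.DataFrame(
    {
        "wd": [1e-5] * 4 + [1e-3] * 4 + [1e-7] * 4,
        "lr": [0.09] * 4 + [0.01] * 4 + [0.05] * 4,
        "epoch": [1, 1, 2, 2] * 3,
        "data": ["train", "val"] * 6,
        "accuracy": [0.70, 0.60, 0.90, 0.80,
                     0.50, 0.40, 0.60, 0.45,
                     0.80, 0.70, 0.85, 0.75],
    }
)

# Best validation accuracy reached by each (wd, lr) pair
acc = toy.query("data == 'val'").groupby(["wd", "lr"])["accuracy"].max()

# Sort in descending order and attach a 1-based rank column
ranked = acc.sort_values(ascending=False).reset_index()
ranked.insert(0, "best", list(range(1, len(ranked) + 1)))

# With no `on` argument, merge joins on the shared columns (wd, lr);
# the default inner join keeps only the rows of the top-2 trials
top = pd.merge(toy, ranked[["best", "wd", "lr"]].head(2))
print(ranked)
print(len(top))
```

Because `wd` and `lr` are sampled from continuous ranges, each pair identifies a single trial in practice, which is what makes them usable as a grouping and join key here.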

x = alt.X("wd", scale=alt.Scale(type="log"))
y = alt.Y("lr", scale=alt.Scale(type="log"))
alt.Chart(best).mark_point().encode(
    x=x, y=y, color="accuracy", size="accuracy"
).properties(width=200, height=200)

[7] 2019-06-12 17:39:44 (46.9ms) python3 (1min60s)