5.7 Training and Evaluating the RNNLM

Load the PTB dataset. To keep this example fast, only the first 1,000 words of the training corpus are used.

from ivory.common.dataset import TimeDataset
from ivory.utils.repository import import_module

ptb = import_module("scratch2/dataset/ptb")
corpus, word_to_id, id_to_word = ptb.load_data("train")
corpus = corpus[:1000]  # use only the first 1,000 words
vocab_size = int(max(corpus) + 1)
x, t = corpus[:-1], corpus[1:]  # inputs and next-word targets
data = TimeDataset((x, t), time_size=5, batch_size=10)
data.epochs = 100
data

[1] 2019-06-19 12:26:46 (39.0ms) python3 (625ms)

TimeDataset(time_size=5, batch_size=10, epochs=100, len=19, column=0, size=(999,))
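
Each item of data should be one mini-batch pair (x, t), each of shape (batch_size, time_size) = (10, 5). The sketch below is an assumption based on how data[0] is unpacked into trainer.set_data later in this section; it only prints the shapes.

# A minimal sketch, assuming data[i] yields exactly one (x, t) pair of arrays.
xb, tb = data[0]
print(xb.shape, tb.shape)  # expected: (10, 5) (10, 5)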

Set the hyperparameters.

wordvec_size = 100
hidden_size = 100

[2] 2019-06-19 12:26:46 (7.00ms) python3 (632ms)

Build the model: an embedding layer, an RNN layer, and an affine layer followed by a softmax cross-entropy loss.

from ivory.core.trainer import sequential

net = [
    ("input", vocab_size),
    ("embedding", wordvec_size),
    ("rnn", hidden_size),
    ("affine", vocab_size, "softmax_cross_entropy"),
]
trainer = sequential(net, optimizer="sgd", metrics=["loss"])
trainer.optimizer.learning_rate = 0.1
model = trainer.model
for layer in model.layers:
    print(layer)

[3] 2019-06-19 12:26:46 (27.0ms) python3 (659ms)

<Embedding('Embedding.1', (418, 100)) at 0x247e876e908>
<RNN('RNN.2', (100, 100)) at 0x247e87c87f0>
<Affine('Affine.1', (100, 418)) at 0x247e87c8978>
<SoftmaxCrossEntropy('SoftmaxCrossEntropy.3', (418,)) at 0x247e87c8b38>
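
Reading off the shapes above: the embedding matrix is (vocab_size, wordvec_size) = (418, 100), the RNN layer holds two 100x100 matrices named W and U, and the affine layer projects the 100-dimensional hidden state back to 418 vocabulary scores. Assuming W is the input-to-hidden matrix and U the hidden-to-hidden matrix (my reading of the names above, not stated explicitly by the library), the hidden state is updated per time step as

h_t = \tanh(x_t W + h_{t-1} U + b)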

Set the initial weight values with Xavier initialization and compare the measured standard deviation of each weight matrix with the target value.

from ivory.common.context import np

model.init(std="xavier")
for p in model.weights:
    if p.name != "b":  # skip bias vectors
        # measured std vs. the Xavier target sqrt(1 / fan_in)
        std1, std2 = f"{p.d.std():.03f}", f"{np.sqrt(1/p.d.shape[0]):.03f}"
        print(p.layer.name, p.name, std1, std2)

[4] 2019-06-19 12:26:46 (29.1ms) python3 (688ms)

Embedding.1 W 0.049 0.049
RNN.2 W 0.100 0.100
RNN.2 U 0.101 0.100
Affine.1 W 0.100 0.100
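
The two columns agree: Xavier initialization draws each weight matrix with standard deviation

\sigma = \sqrt{1 / n_\mathrm{in}}

where n_in is the fan-in (the first dimension of the weight), e.g. sqrt(1/418) ≈ 0.049 for the embedding and sqrt(1/100) = 0.1 for the RNN and affine weights.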

Feed the first batch into the model and compute the perplexity.

trainer.set_data(*data[0])
model.forward()
print(model.perplexity)

[5] 2019-06-19 12:26:46 (11.0ms) python3 (699ms)

422.05288041862997
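
Before training, the model spreads probability almost uniformly over the vocabulary, so the cross-entropy is close to ln(vocab_size) and the perplexity is close to the vocabulary size itself. The observed value (≈ 422) is indeed near 418; a minimal check:

print(vocab_size)          # 418
print(np.log(vocab_size))  # about 6.04, the loss of a uniform predictor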

Run the training, then aggregate the per-iteration loss into a per-epoch mean and convert it to perplexity.

trainer.fit(data)
df = trainer.to_frame()
df["epoch"] = df.iteration // len(data)  # len(data) iterations per epoch
df = df.groupby("epoch").mean().reset_index()  # mean loss per epoch
df["ppl"] = np.exp(df.loss)  # perplexity = exp(mean loss)
df.tail()

[6] 2019-06-19 12:26:46 (5.18s) python3 (5.88s)

   epoch  iteration      loss       ppl
95    95       1814  1.719615  5.582377
96    96       1833  1.594855  4.927614
97    97       1852  1.583730  4.873099
98    98       1871  1.560221  4.759871
99    99       1890  1.460754  4.309210
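
Perplexity drops from around the vocabulary size to below 5, so the model has learned this small corpus well. The trained model can also be re-evaluated with the same forward pass used above. The sketch below reuses only the calls already shown (set_data, forward, perplexity) and assumes data[i] still returns the i-th batch after training; since only the training split was loaded here, this measures perplexity on the training data.

# A minimal sketch: geometric mean of per-batch perplexity over the whole corpus.
logs = []
for i in range(len(data)):
    trainer.set_data(*data[i])  # one (x, t) mini-batch
    model.forward()
    logs.append(np.log(model.perplexity))  # per-batch cross-entropy
print(np.exp(np.mean(logs)))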

Visualize the perplexity curve.

import altair as alt

alt.Chart(df).mark_line().encode(x="epoch", y="ppl").properties(width=300, height=200)

[7] 2019-06-19 12:26:51 (41.0ms) python3 (5.92s)