3.1 Negative Sampling

「ゼロから作るDeep Learning ❷」で4章で導入されるNegative Samplingを実装します。元になる関数は、「ゼロから作るDeep Learning ❷」をほぼ踏襲し、IvoryライブラリのDatasetのサブクラスContextDatasetとして使用できるようにします。

実験用のコーパスを準備します。

corpus = [0, 1, 2, 3, 4, 1, 2, 3, 2]

[1] 2019-06-12 20:09:50 (15.5ms) python3 (15.5ms)

コンテキスト用データセットを作成します。

from ivory.common.dataset import ContextDataset

data = ContextDataset(corpus)
data.batch_size = 2
data

[2] 2019-06-12 20:09:50 (194ms) python3 (209ms)

ContextDataset(batch_size=2, epochs=1, len=3, column=0, size=(7,))

ウィンドウサイズはデフォルトで1です。

data.window_size

[3] 2019-06-12 20:09:50 (15.6ms) python3 (225ms)

1

通常のDatasetと同様に動作します。コンテキストとターゲットを返します。

data[0]

[4] 2019-06-12 20:09:50 (12.0ms) python3 (237ms)

(array([[0, 2],
        [1, 3]]), array([1, 2]))

ウィンドウサイズは、後から変更することもできます。

data.set_window_size(2)
data

[5] 2019-06-12 20:09:50 (12.6ms) python3 (250ms)

ContextDataset(batch_size=2, epochs=1, len=2, column=0, size=(5,))

データの長さが変化しました。データを取得します。

data[0]

[6] 2019-06-12 20:09:50 (21.0ms) python3 (271ms)

(array([[0, 1, 3, 4],
        [1, 2, 4, 1]]), array([2, 3]))

Nagative Samplingを行うためには、negative_sampling_sizeを指定します。

data.negative_sample_size = 2
data[0]

[7] 2019-06-12 20:09:50 (15.6ms) python3 (286ms)

(array([[0, 1, 3, 4],
        [1, 2, 4, 1]]), array([2, 3]), array([3, 4]), array([1, 1]))

第3要素以降が負例になります。デフォルトでは、ターゲットとの重複を許しません。

data.negative_sample_size = 4
data[0]

[8] 2019-06-12 20:09:50 (9.00ms) python3 (295ms)

(array([[0, 1, 3, 4],
        [1, 2, 4, 1]]),
 array([2, 3]),
 array([4, 4]),
 array([3, 2]),
 array([1, 1]),
 array([0, 0]))

replace属性をTrueにすることで、重複を許します。すなわち、ターゲットが負例として現れることを(速度を優先した結果として)許容します。

data.replace = True
data.negative_sample_size = 10
data[0]

[9] 2019-06-12 20:09:50 (11.0ms) python3 (306ms)

(array([[0, 1, 3, 4],
        [1, 2, 4, 1]]),
 array([2, 3]),
 array([1, 1], dtype=int64),
 array([0, 2], dtype=int64),
 array([1, 3], dtype=int64),
 array([2, 1], dtype=int64),
 array([0, 2], dtype=int64),
 array([3, 3], dtype=int64),
 array([4, 2], dtype=int64),
 array([3, 3], dtype=int64),
 array([2, 1], dtype=int64),
 array([2, 2], dtype=int64))