2.4 Improving the Count-Based Method

We use the functions provided in 「ゼロから作るDeep Learning ❷」 as-is.

2.4.1 Mutual Information

import pandas as pd

from ivory.common.util import create_co_matrix, ppmi, preprocess

text = "You say goodbye and I say hello."
corpus, word_to_id, id_to_word = preprocess(text)
vocab_size = len(word_to_id)
C = create_co_matrix(corpus, vocab_size)
W = ppmi(C)
v = list(id_to_word.values())
df = pd.DataFrame(W, index=v, columns=v)
df

[1] 2019-06-12 17:40:43 (203ms) python3 (203ms)

               you       say   goodbye       and         i     hello         .
you       0.000000  1.807355  0.000000  0.000000  0.000000  0.000000  0.000000
say       1.807355  0.000000  0.807355  0.000000  0.807355  0.807355  0.000000
goodbye   0.000000  0.807355  0.000000  1.807355  0.000000  0.000000  0.000000
and       0.000000  0.000000  1.807355  0.000000  1.807355  0.000000  0.000000
i         0.000000  0.807355  0.000000  1.807355  0.000000  0.000000  0.000000
hello     0.000000  0.807355  0.000000  0.000000  0.000000  0.000000  2.807355
.         0.000000  0.000000  0.000000  0.000000  0.000000  2.807355  0.000000
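The ppmi function computes PPMI(x, y) = max(0, log2(C(x, y) · N / (S(x) · S(y)))). Here N is the sum of all entries of the co-occurrence matrix C, and S holds the per-word counts. The warning lines in the PTB run further down show the exact expression, `np.log2(C[i, j] * N / (S[j] * S[i]) + eps)`. The following is a minimal vectorized NumPy sketch of the same computation. It is not ivory's code, and the helper names `preprocess_text`, `co_matrix` and `ppmi_vectorized` are hypothetical. It should reproduce the table above. It also casts C to float64 before multiplying. The "overflow encountered in long_scalars" warning in the PTB run appears to come from the integer product `C[i, j] * N` exceeding the int32 range, and floating-point arithmetic avoids that.

```python
import numpy as np

def preprocess_text(text):
    # Lowercase, split the final period off as its own token, assign word IDs
    words = text.lower().replace(".", " .").split()
    word_to_id = {}
    for w in words:
        if w not in word_to_id:
            word_to_id[w] = len(word_to_id)
    id_to_word = {i: w for w, i in word_to_id.items()}
    corpus = np.array([word_to_id[w] for w in words])
    return corpus, word_to_id, id_to_word

def co_matrix(corpus, vocab_size, window_size=1):
    # Count every context word within window_size positions on each side
    C = np.zeros((vocab_size, vocab_size), dtype=np.int64)
    for idx, word_id in enumerate(corpus):
        for offset in range(1, window_size + 1):
            if idx - offset >= 0:
                C[word_id, corpus[idx - offset]] += 1
            if idx + offset < len(corpus):
                C[word_id, corpus[idx + offset]] += 1
    return C

def ppmi_vectorized(C, eps=1e-8):
    C = C.astype(np.float64)   # float64 avoids the integer overflow in C[i, j] * N
    N = C.sum()                # total number of co-occurrences
    S = C.sum(axis=0)          # per-word counts (S in the book's notation)
    pmi = np.log2(C * N / np.outer(S, S) + eps)
    return np.maximum(0, pmi)  # clip negative PMI values to 0

corpus, word_to_id, id_to_word = preprocess_text("You say goodbye and I say hello.")
W = ppmi_vectorized(co_matrix(corpus, len(word_to_id)))
print(np.round(W, 6))
```

Because `np.outer(S, S)` builds every denominator at once, there is no Python-level double loop. For a PTB-sized vocabulary this matters far more than it does for this 7-word toy corpus.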

2.4.3 Dimensionality Reduction with SVD

import numpy as np

U, S, V = np.linalg.svd(W)  # S: singular values in descending order; V is already transposed
print(C[0])
print(W[0])
print(U[0])

[2] 2019-06-12 17:40:43 (31.2ms) python3 (234ms)

[0 1 0 0 0 0 0]
[0.        1.8073549 0.        0.        0.        0.        0.       ]
[ 0.0000000e+00  3.4094876e-01 -1.2051624e-01 -3.8857806e-16
 -1.1102230e-16 -9.3232495e-01 -2.4257469e-17]
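Here, "dimensionality reduction" means keeping only the largest singular values. This sketch uses a random 7 × 7 matrix as a stand-in for the PPMI matrix W (an assumption, so the numbers will not match the output above). It shows that all singular values together reconstruct the matrix. Truncating to the top k gives the best rank-k approximation, and its Frobenius error equals the norm of the discarded singular values (Eckart-Young). The first k columns of U then serve as dense k-dimensional word vectors.

```python
import numpy as np

# Stand-in for the 7 x 7 PPMI matrix W (random values, so results differ from above)
rng = np.random.default_rng(0)
W = rng.random((7, 7))

U, S, V = np.linalg.svd(W)  # S sorted in descending order; V is already V^T

# Using every singular value reconstructs W (up to floating-point rounding)
W_full = U @ np.diag(S) @ V

# Dimensionality reduction: keep only the k largest singular values
k = 2
W_k = U[:, :k] @ np.diag(S[:k]) @ V[:k, :]
word_vecs_2d = U[:, :k]  # each row is a k-dimensional word vector

# Eckart-Young: the Frobenius error equals the norm of the dropped singular values
err = np.linalg.norm(W - W_k)
print(err, np.sqrt(np.sum(S[k:] ** 2)))
```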

Let's plot each word in 2D, using the first two columns of U as its coordinates.

import matplotlib.pyplot as plt

for word, word_id in word_to_id.items():
    plt.annotate(word, (U[word_id, 0], U[word_id, 1]))
plt.scatter(U[:, 0], U[:, 1], alpha=0.5)
plt.show()

[4] 2019-06-12 17:40:43 (141ms) python3 (500ms)

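(Figure: scatter plot of the seven words at (U[word_id, 0], U[word_id, 1]), each point labelled with its word)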

2.4.4 The PTB Dataset

from ivory.utils.repository import import_module

ptb = import_module("scratch2/dataset/ptb")
corpus, word_to_id, id_to_word = ptb.load_data("train")
len(corpus)

[5] 2019-06-12 17:40:44 (31.3ms) python3 (531ms)

929589

2.4.5 Evaluation on the PTB Dataset
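With PTB's 10,000-word vocabulary, W is a 10,000 × 10,000 matrix, so a full `np.linalg.svd` is expensive. The code below therefore calls scikit-learn's `randomized_svd` for only the top 100 components. The following minimal NumPy sketch illustrates the randomized-SVD idea from Halko, Martinsson and Tropp, which is roughly what that function builds on. It is not scikit-learn's implementation, and the name `randomized_svd_sketch` and its defaults are assumptions. The steps are: project onto a random subspace, refine it with a few power iterations, then take an exact SVD of the small projected matrix.

```python
import numpy as np

def randomized_svd_sketch(A, k, n_oversamples=10, n_iter=5, seed=0):
    """Approximate the top-k singular triplets of A (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # 1. Sample the range of A with a random Gaussian test matrix
    omega = rng.standard_normal((A.shape[1], k + n_oversamples))
    Q, _ = np.linalg.qr(A @ omega)
    # 2. Power iterations sharpen the subspace; re-orthonormalize for stability
    for _ in range(n_iter):
        Q, _ = np.linalg.qr(A.T @ Q)
        Q, _ = np.linalg.qr(A @ Q)
    # 3. Exact SVD of the small (k + p) x n matrix B = Q^T A
    B = Q.T @ A
    U_b, s, Vt = np.linalg.svd(B, full_matrices=False)
    U = Q @ U_b
    return U[:, :k], s[:k], Vt[:k]

# Usage: a 200 x 300 matrix with known, quickly decaying singular values
rng = np.random.default_rng(1)
U0, _ = np.linalg.qr(rng.standard_normal((200, 50)))
V0, _ = np.linalg.qr(rng.standard_normal((300, 50)))
s_true = 2.0 ** -np.arange(50)
A = U0 @ np.diag(s_true) @ V0.T
U, s, Vt = randomized_svd_sketch(A, k=10)
print(s[:3])
```

The cost is dominated by products of A with thin matrices of k + p columns, which is why this approach scales to a large PPMI matrix when only about 100 components are needed.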

from sklearn.utils.extmath import randomized_svd
from ivory.common.util import most_similar

window_size = 2
wordvec_size = 100
vocab_size = len(word_to_id)

print("counting  co-occurrence ...")
C = create_co_matrix(corpus, vocab_size, window_size)
print("calculating PPMI ...")
W = ppmi(C, verbose=True)

print("calculating SVD ...")
U, S, V = randomized_svd(W, n_components=wordvec_size, n_iter=5, random_state=None)

word_vecs = U[:, :wordvec_size]

queries = ["you", "year", "car", "toyota"]
for query in queries:
    most_similar(query, word_to_id, id_to_word, word_vecs, top=5)

[6] 2019-06-12 17:40:44 (5min29s) python3 (5min29s)

counting  co-occurrence ...
calculating PPMI ...
C:\Users\daizu\Documents\GitHub\ivory\ivory\common\util.py:115: RuntimeWarning: overflow encountered in long_scalars
  pmi = np.log2(C[i, j] * N / (S[j] * S[i]) + eps)
C:\Users\daizu\Documents\GitHub\ivory\ivory\common\util.py:115: RuntimeWarning: invalid value encountered in log2
  pmi = np.log2(C[i, j] * N / (S[j] * S[i]) + eps)
1.0% done
2.0% done
...
99.0% done
100.0% done
calculating SVD ...

[query] you
 i: 0.6961897611618042
 we: 0.6631303429603577
 anybody: 0.6072962284088135
 do: 0.5863798260688782
 'll: 0.5081530213356018

[query] year
 month: 0.6787183880805969
 quarter: 0.6236955523490906
 earlier: 0.6146764159202576
 next: 0.6078804731369019
 last: 0.5858464241027832

[query] car
 luxury: 0.5490508675575256
 auto: 0.5488603711128235
 cars: 0.537540853023529
 midsized: 0.459559828042984
 motor: 0.4571341276168823

[query] toyota
 motor: 0.7791914939880371
 honda: 0.6706029176712036
 nissan: 0.657393753528595
 motors: 0.6424781680107117
 lexus: 0.5998319983482361
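The neighbours above are ranked by cosine similarity between rows of `word_vecs`. This is a minimal sketch of that ranking, assuming ivory's `most_similar` behaves like the book's version (cosine similarity, query word excluded). The helper `most_similar_list` is hypothetical. Unlike `most_similar`, it returns a list of `(word, similarity)` pairs instead of printing them, and it scores all rows in one matrix-vector product.

```python
import numpy as np

def most_similar_list(query, word_to_id, id_to_word, word_matrix, top=5, eps=1e-8):
    """Return the `top` words closest to `query` by cosine similarity (hypothetical helper)."""
    if query not in word_to_id:
        return []
    query_vec = word_matrix[word_to_id[query]]
    # Cosine similarity of every row against the query, in one matrix-vector product
    norms = np.linalg.norm(word_matrix, axis=1) + eps
    sims = word_matrix @ query_vec / (norms * (np.linalg.norm(query_vec) + eps))
    result = []
    for i in np.argsort(-sims):  # highest similarity first
        if id_to_word[i] == query:
            continue  # skip the query word itself
        result.append((id_to_word[i], float(sims[i])))
        if len(result) == top:
            break
    return result

# Usage with a toy 2-dimensional embedding
word_to_id = {"car": 0, "auto": 1, "banana": 2}
id_to_word = {i: w for w, i in word_to_id.items()}
vecs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
result = most_similar_list("car", word_to_id, id_to_word, vecs, top=2)
print(result)
```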