BERT-embedding
A simple wrapper class for extracting features(embedding) and comparing them using BERT
How to Use
Installation
git clone https://github.com/seriousmac/BERT-embedding.git
cd BERT-embedding
pip install -r requirements.txt
wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip
unzip multi_cased_L-12_H-768_A-12.zip -d bert/
Run a test
python bert_embedding.py
Major functions
-
bert.init()
#์ด๊ธฐํ -
bert.extract(sentence)
#๋ชจ๋ ๊ฒฐ๊ณผ ์ถ์ถ, ์๋์ input and output์์ ์ ์ถ๋ ฅ ๊ตฌ์กฐ ์์ธํ ์ค๋ช -
bert.extracts(sentences)
#string list๋ฅผ ์ ๋ ฅ๋ฐ์ -
bert.extract_v1(sentence)
#embedding ๊ฐ๋ง ์ถ์ถ -
bert.extracts_v1(sentences)
-
bert.cal_dif_cls(result1, result2)
#extract ํน์ extracts์ ์ถ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ์ด์ฉํ์ฌ distance ๊ณ์ฐ -
bert.cal_dif_cls_layer(result1, result2, layer_num)
#์์ ํจ์์์ ํน์ layer์ ๋ํด์๋ง ๊ณ์ฐ
Input and output
- bert.extracts(sentences)
- input: list of string
- output: list of dict
- 'features': ์
๋ ฅํ ๋ฌธ์ฅ ๋ด ํ ํฐ ๊ฐฏ์ ๋งํผ์ list
- 'token': ํ ํฐ ๊ฐ
- 'layers': list of layer dict
- 'index': layer ๋ฒํธ
- 'values': 768๊ธธ์ด์ float๊ฐ list =extracting features(embedding)
- 'features': ์
๋ ฅํ ๋ฌธ์ฅ ๋ด ํ ํฐ ๊ฐฏ์ ๋งํผ์ list
Examples
Example 1 - ํ ๋ฌธ์ฅ์์ embedding ์ถ์ถํ๊ธฐ
from bert_embedding import BERT
bert = BERT()
bert.init()
sentence = "[OBS ๋
ํนํ ์ฐ์๋ด์ค ์กฐ์ฐ์ ๊ธฐ์] ๊ฐ์ ๊ฒธ ๋ฐฐ์ฐ ์์ง๊ฐ '๊ตญ๋ฏผ' ํ์ดํ์ ๊ฑฐ๋จธ์ฅ ์คํ๋ก ๊ผฝํ๋ค."
result = bert.extract(sentence)
Example 2 - ์ฌ๋ฌ ๋ฌธ์ฅ์์ embedding ์ถ์ถํ๊ธฐ
from bert_embedding import BERT
bert = BERT()
bert.init()
sentences = ['โ์ธ๊ณ์ ๊ณต์ฅโ์ผ๋ก ๋ง๋ํ ๋ฌ๋ฌ๋ฅผ ์ธ์ด๋ด์ผ๋ฉฐ ๊ฒฝ์ ๋ ฅ์ ํค์ ๋ ์ค๊ตญ์ ์ข์ ์์ ๋ ์ค๋๊ฐ์ง ์์ ๋ฏ>ํ๋ค.',
'์๋ณธ ์ ์ถ๊ณผ ์๋น์ค ์์ง ์ ์ ํญ์ด ์ปค์ง๋ฉฐ ๊ฒฝ์์์ง ์ ์๋ฅผ ํฅํด ๋น ๋ฅด๊ฒ ๋ค๊ฐ๊ฐ๊ณ ์์ด์๋ค.',
"[OBS ๋
ํนํ ์ฐ์๋ด์ค ์กฐ์ฐ์ ๊ธฐ์] ๊ฐ์ ๊ฒธ ๋ฐฐ์ฐ ์์ง๊ฐ '๊ตญ๋ฏผ' ํ์ดํ์ ๊ฑฐ๋จธ์ฅ ์คํ๋ก ๊ผฝํ๋ค.",
"OBS '๋
ํนํ ์ฐ์๋ด์ค'(๊ธฐํยท์ฐ์ถยท๊ฐ์ ์ค๊ฒฝ์ฒ , ์๊ฐ ๋ฐ์๊ฒฝยท๊นํ์ )๊ฐ '๊ตญ๋ฏผ ์ ๋๋กฌ'์ ์ผ์ผํจ ์ฒซ์ฌ๋์ ์์ด์ฝ >๊น์ฐ์, ์์ง, ์คํ์ ๊ทผํฉ์ ์ดํด๋ดค๋ค."]
results = bert.extracts(sentences)
Example 3 - CLS๋ง ์ด์ฉํด distance๊ฐ ๊ฐ์ฅ ๊ฐ๊น์ด ๋ฌธ์ฅ ์ฐพ๊ธฐ
from bert_embedding import BERT
bert = BERT()
bert.init()
sentences = ['โ์ธ๊ณ์ ๊ณต์ฅโ์ผ๋ก ๋ง๋ํ ๋ฌ๋ฌ๋ฅผ ์ธ์ด๋ด์ผ๋ฉฐ ๊ฒฝ์ ๋ ฅ์ ํค์ ๋ ์ค๊ตญ์ ์ข์ ์์ ๋ ์ค๋๊ฐ์ง ์์ ๋ฏ>ํ๋ค.',
'์๋ณธ ์ ์ถ๊ณผ ์๋น์ค ์์ง ์ ์ ํญ์ด ์ปค์ง๋ฉฐ ๊ฒฝ์์์ง ์ ์๋ฅผ ํฅํด ๋น ๋ฅด๊ฒ ๋ค๊ฐ๊ฐ๊ณ ์์ด์๋ค.',
'[OBS ๋
ํนํ ์ฐ์๋ด์ค ์กฐ์ฐ์ ๊ธฐ์] ๊ฐ์ ๊ฒธ ๋ฐฐ์ฐ ์์ง๊ฐ ๊ตญ๋ฏผ ํ์ดํ์ ๊ฑฐ๋จธ์ฅ ์คํ๋ก ๊ผฝํ๋ค.',
'OBS ๋
ํนํ ์ฐ์๋ด์ค(๊ธฐํยท์ฐ์ถยท๊ฐ์ ์ค๊ฒฝ์ฒ , ์๊ฐ ๋ฐ์๊ฒฝยท๊นํ์ )๊ฐ ๊ตญ๋ฏผ ์ ๋๋กฌ์ ์ผ์ผํจ ์ฒซ์ฌ๋์ ์์ด์ฝ >๊น์ฐ์, ์์ง, ์คํ์ ๊ทผํฉ์ ์ดํด๋ดค๋ค.',
'์ค๋์ ๋ ์จ๊ฐ ์ข์ต๋๋ค. ๋ง์ง์ ์ฐพ์ ๊ฐ๋ณผ๊น์? ์์ด๋ค์ด ์ข์ํ๋๋ผ๊ตฌ์.',
'๋ณด์์ง์์๋ ๋ณด์์ ๋ง์๊ฒ ํ๋ฉด ๊ทธ๋ง์
๋๋ค.ใ
ใ
']
results = bert.extracts(sentences)
distances = []
for i in range(len(results)):
distance = []
for j in range(len(results)):
if i == j:
distance.append(99999)
else:
distance.append(bert.cal_dif_cls(results[i], results[j]))
distances.append(distance)
for idx in range(len(sentences)):
print(sentences[idx])
print(sentences[distances[idx].index(min(distances[idx]))])
print()
์ถ๋ ฅ ๊ฒฐ๊ณผ
โ์ธ๊ณ์ ๊ณต์ฅโ์ผ๋ก ๋ง๋ํ ๋ฌ๋ฌ๋ฅผ ์ธ์ด๋ด์ผ๋ฉฐ ๊ฒฝ์ ๋ ฅ์ ํค์ ๋ ์ค๊ตญ์ ์ข์ ์์ ๋ ์ค๋๊ฐ์ง ์์ ๋ฏ>ํ๋ค.
์๋ณธ ์ ์ถ๊ณผ ์๋น์ค ์์ง ์ ์ ํญ์ด ์ปค์ง๋ฉฐ ๊ฒฝ์์์ง ์ ์๋ฅผ ํฅํด ๋น ๋ฅด๊ฒ ๋ค๊ฐ๊ฐ๊ณ ์์ด์๋ค.
์๋ณธ ์ ์ถ๊ณผ ์๋น์ค ์์ง ์ ์ ํญ์ด ์ปค์ง๋ฉฐ ๊ฒฝ์์์ง ์ ์๋ฅผ ํฅํด ๋น ๋ฅด๊ฒ ๋ค๊ฐ๊ฐ๊ณ ์์ด์๋ค.
โ์ธ๊ณ์ ๊ณต์ฅโ์ผ๋ก ๋ง๋ํ ๋ฌ๋ฌ๋ฅผ ์ธ์ด๋ด์ผ๋ฉฐ ๊ฒฝ์ ๋ ฅ์ ํค์ ๋ ์ค๊ตญ์ ์ข์ ์์ ๋ ์ค๋๊ฐ์ง ์์ ๋ฏ>ํ๋ค.
[OBS ๋
ํนํ ์ฐ์๋ด์ค ์กฐ์ฐ์ ๊ธฐ์] ๊ฐ์ ๊ฒธ ๋ฐฐ์ฐ ์์ง๊ฐ ๊ตญ๋ฏผ ํ์ดํ์ ๊ฑฐ๋จธ์ฅ ์คํ๋ก ๊ผฝํ๋ค.
OBS ๋
ํนํ ์ฐ์๋ด์ค(๊ธฐํยท์ฐ์ถยท๊ฐ์ ์ค๊ฒฝ์ฒ , ์๊ฐ ๋ฐ์๊ฒฝยท๊นํ์ )๊ฐ ๊ตญ๋ฏผ ์ ๋๋กฌ์ ์ผ์ผํจ ์ฒซ์ฌ๋์ ์์ด์ฝ >๊น์ฐ์, ์์ง, ์คํ์ ๊ทผํฉ์ ์ดํด๋ดค๋ค.
OBS ๋
ํนํ ์ฐ์๋ด์ค(๊ธฐํยท์ฐ์ถยท๊ฐ์ ์ค๊ฒฝ์ฒ , ์๊ฐ ๋ฐ์๊ฒฝยท๊นํ์ )๊ฐ ๊ตญ๋ฏผ ์ ๋๋กฌ์ ์ผ์ผํจ ์ฒซ์ฌ๋์ ์์ด์ฝ >๊น์ฐ์, ์์ง, ์คํ์ ๊ทผํฉ์ ์ดํด๋ดค๋ค.
[OBS ๋
ํนํ ์ฐ์๋ด์ค ์กฐ์ฐ์ ๊ธฐ์] ๊ฐ์ ๊ฒธ ๋ฐฐ์ฐ ์์ง๊ฐ ๊ตญ๋ฏผ ํ์ดํ์ ๊ฑฐ๋จธ์ฅ ์คํ๋ก ๊ผฝํ๋ค.
์ค๋์ ๋ ์จ๊ฐ ์ข์ต๋๋ค. ๋ง์ง์ ์ฐพ์ ๊ฐ๋ณผ๊น์? ์์ด๋ค์ด ์ข์ํ๋๋ผ๊ตฌ์.
๋ณด์์ง์์๋ ๋ณด์์ ๋ง์๊ฒ ํ๋ฉด ๊ทธ๋ง์
๋๋ค.ใ
ใ
๋ณด์์ง์์๋ ๋ณด์์ ๋ง์๊ฒ ํ๋ฉด ๊ทธ๋ง์
๋๋ค.ใ
ใ
์ค๋์ ๋ ์จ๊ฐ ์ข์ต๋๋ค. ๋ง์ง์ ์ฐพ์ ๊ฐ๋ณผ๊น์? ์์ด๋ค์ด ์ข์ํ๋๋ผ๊ตฌ์.
Example 4 - ๋ฌธ์ฅ ๋ด ํน์ token์ embedding๋ง ๋น๊ตํ๊ธฐ
from bert_embedding import BERT
bert = BERT()
bert.init()
sentences = ["๋ง์น ํ๋ณด ์ปท์ ๋ฐฉ๋ถ์ผ ํ ์ด๋ฒ ์ด๋ฏธ์ง๋ ํด์ธ ๋ก์ผ์ดฌ์ ์ ์ดฌ์๋ ์ปท์ผ๋ก ํนํ, ์๋ก์ฐ ์ปฌ๋ฌ์ ๋ ํธ๋กํ ํดํธ์ ๊ธ๋ผ์ค๋ฅผ ์ฐฉ์ฉํ ์ฑ ์งํ์ฐจ๋ฅผ ์ด์ ํ๋ ์์ง์ ๋ชจ์ต์์ ๊ธฐ์กด์ ์ฒญ์ํ ๋ชจ์ต๊ณผ๋ ๋ค๋ฅธ ๋ํ์ ์ธ ๋ถ์๊ธฐ์ ํ์ธต ์ฑ์ํด์ง ๋ชจ์ต์ ๋ณด์ฌ์ฃผ๋ฉฐ ๊ทน ์ค ์บ๋ฆญํฐ์ ๋ํ ๊ธฐ๋๊ฐ์ ๋์๋ค.",
'์๋ณธ ์ ์ถ๊ณผ ์๋น์ค ์์ง ์ ์ ํญ์ด ์ปค์ง๋ฉฐ ๊ฒฝ์ ์์ง ์ ์๋ฅผ ํฅํด ๋น ๋ฅด๊ฒ ๋ค๊ฐ๊ฐ๊ณ ์์ด์๋ค.',
"[์กฐ์ฐ์ ๊ธฐ์] ๊ฐ์ ๊ฒธ ๋ฐฐ์ฐ ์์ง๊ฐ ๊ตญ๋ฏผ ํ์ดํ์ ๊ฑฐ๋จธ์ฅ ์คํ๋ก ๊ผฝํ๋ค."]
results = bert.extracts(sentences)
for i in range(len(results)):
for j in range(len(results)):
print(sentences[i])
print(sentences[j])
cal_dif_keyword(results[i], results[j], '์์ง')
To Do List
- Define class
- embedding ์ฝ๊ฒ ์ถ์ถํ๊ธฐ
- CLS๋ง์ ์ด์ฉํด ๋ฌธ์ฅ์ distance ๊ณ์ฐํ๊ธฐ
- ๋ฌธ์ฅ ๋ด ๋ชจ๋ token๋ค์ embedding์ ์ด์ฉํด distance ๊ณ์ฐํ๊ธฐ
- ๋ฌธ์ฅ ๋ด ํน์ token๋ง ๋น๊ตํ๊ธฐ (์: ๊ฒฝ์ ์ '์์ง'์ ์ฐ์์ '์์ง' ๊ฐ์ ์ฐจ์ด ํ์ธํ๊ธฐ)