Handling the RE task with BERT

  • Task : KLUE-RE

  • Author: 김보석

  • Last updated: 2021-09-15

  • This notebook was written as part of the Pseudo Lab (가짜연구소) 3rd-cohort "Evaluating Models with KLUE" crew activity.

01 init

This notebook explains how to perform relation extraction with a pretrained BERT model from Hugging Face.
It follows the paper Simple BERT Models for Relation Extraction (reference: the Relation Extraction paper).

  • What is the relation extraction task?
    Given an input sentence, the task is to classify the relation between two entities into one of 30 relation classes.

    Input
    - sentence : 이날 보고회에는 권오봉 여수시장과 전문가 자문위원, 전남도와 여수시 관계공무원  20 명이 참석했다.
    - subject entity : 여수시
    - object entity : 권오봉

    Output
    - label : 10
    

Installing the required libraries

  • datasets : the load_dataset method in Hugging Face's datasets library (link to the relevant page) makes it easy to download the data.

  • sklearn : used to evaluate the trained model.

  • transformers : used to load BERT and other models from Hugging Face (link to the relevant page).

!pip install datasets
!pip install sklearn
!pip install transformers
Collecting datasets
  Downloading datasets-1.12.1-py3-none-any.whl (270 kB)
  ...
Successfully installed aiohttp-3.7.4.post0 async-timeout-3.0.1 datasets-1.12.1 fsspec-2021.9.0 huggingface-hub-0.0.17 multidict-5.1.0 xxhash-2.0.2 yarl-1.6.3
Requirement already satisfied: sklearn in /usr/local/lib/python3.7/dist-packages (0.0)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from sklearn) (0.22.2.post1)
Collecting transformers
  Downloading transformers-4.11.0-py3-none-any.whl (2.9 MB)
  ...
Successfully installed pyyaml-5.4.1 sacremoses-0.0.46 tokenizers-0.10.3 transformers-4.11.0

Importing the required libraries and setting the device

import torch
import torch.nn as nn
import sklearn.metrics

from tqdm import tqdm
from datasets import load_dataset
from datasets.arrow_dataset import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

02 Data Loading

Download the data with load_dataset from Hugging Face's datasets library.

Data Download

  • train dataset : 32,470 examples

  • validation dataset : 7,765 examples

dataset = load_dataset('klue', 're')
dataset
Downloading and preparing dataset klue/re (download: 5.41 MiB, generated: 13.07 MiB, post-processed: Unknown size, total: 18.48 MiB) to /root/.cache/huggingface/datasets/klue/re/1.0.0/55ff8f92b7a4b9842be6514ce0b4b5295b46d5e493f8bb5760da4be717018f90...
Dataset klue downloaded and prepared to /root/.cache/huggingface/datasets/klue/re/1.0.0/55ff8f92b7a4b9842be6514ce0b4b5295b46d5e493f8bb5760da4be717018f90. Subsequent calls will reuse this data.
DatasetDict({
    train: Dataset({
        features: ['guid', 'sentence', 'subject_entity', 'object_entity', 'label', 'source'],
        num_rows: 32470
    })
    validation: Dataset({
        features: ['guid', 'sentence', 'subject_entity', 'object_entity', 'label', 'source'],
        num_rows: 7765
    })
})

Data View

Let's look at the structure of the dataset through a data sample.

# layout of a single example
dataset['train'][0]
{'guid': 'klue-re-v1_train_00000',
 'label': 0,
 'object_entity': {'end_idx': 18,
  'start_idx': 13,
  'type': 'PER',
  'word': '조지 해리슨'},
 'sentence': '〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.',
 'source': 'wikipedia',
 'subject_entity': {'end_idx': 26,
  'start_idx': 24,
  'type': 'ORG',
  'word': '비틀즈'}}

The fields of a relation extraction (RE) example are as follows.

  1. guid : unique example id

  2. label : relation label (0–29); see the sketch after this list for decoding a label id into its relation name

  3. object_entity : object entity information (start character index, end character index, word, entity type such as ORG or PER)

  4. subject_entity : subject entity information (start character index, end character index, word, entity type such as ORG or PER)

  5. sentence : the given sentence
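
As a quick way to see what a numeric label means, the label column can be decoded through the datasets ClassLabel feature. This is only a hedged sketch: it assumes the label feature of klue/re is a ClassLabel (if it is a plain integer column, int2str will not exist), and the relation name in the comment is illustrative.

# Hedged sketch: map a numeric label to its relation name via the ClassLabel feature.
label_feature = dataset['train'].features['label']
print(label_feature.num_classes)                            # expected: 30
print(label_feature.int2str(dataset['train'][0]['label']))  # e.g. 'no_relation' (illustrative)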

Data Sample

Let's inspect each column by drawing three random samples.

import random
for i in random.sample(range(len(dataset['train'])), 3):
    print('Example ' + str(i))
    print('guid : ' + dataset['train'][i]['guid'])
    print('label : ' + str(dataset['train'][i]['label']))
    print('object_entity : ' + str(dataset['train'][i]['object_entity']))
    print('subject_entity : ' + str(dataset['train'][i]['subject_entity']))
    print('sentence : ' + dataset['train'][i]['sentence'])
    print('=============================================', end='\n\n')
Example 17273
guid : klue-re-v1_train_17273
label : 5
object_entity : {'word': '네덜란드', 'start_idx': 45, 'end_idx': 48, 'type': 'ORG'}
subject_entity : {'word': '퀴라소', 'start_idx': 16, 'end_idx': 18, 'type': 'ORG'}
sentence : 2007년부터 2013년까지 퀴라소 U-20 대표팀, 네덜란드 U-19 대표팀, 네덜란드 U-21 대표팀에서 25경기 1골을 기록했고 이후 2016년 퀴라소 A대표팀에 처음으로 발탁되어 같은 해 3월 바베이도스와의 친선 경기에서 국제 A매치 첫 경기를 치뤘으며 현재까지 A매치 28경기에서 11골을 뽑아내면서 퀴라소의 2017년 카리브컵 우승, 2019년 킹스컵 우승, 2019년 CONCACAF 골드컵 8강 진출에 공헌했다.
=============================================

Example 12499
guid : klue-re-v1_train_12499
label : 18
object_entity : {'word': '베르놀라크 육군', 'start_idx': 0, 'end_idx': 7, 'type': 'PER'}
subject_entity : {'word': '페르디난트 차틀로시', 'start_idx': 24, 'end_idx': 33, 'type': 'PER'}
sentence : 베르놀라크 육군은 슬로바키아 국방부 장관인 페르디난트 차틀로시가 지휘했으며, 초기 지휘부는 스피슈스카노바베스였으나 9월 8일 이후 프레쇼우 근처로 옮겼다.
=============================================

Example 27819
guid : klue-re-v1_train_27819
label : 10
object_entity : {'word': '김영삼', 'start_idx': 15, 'end_idx': 17, 'type': 'PER'}
subject_entity : {'word': '문민정부', 'start_idx': 3, 'end_idx': 6, 'type': 'ORG'}
sentence : 이는 문민정부가 들어섰다는 김영삼 정권 때 632명보다 두 배에 가까운 수치이다.
=============================================

The goal of the RE task is to classify the relation between the two entities in a given sentence into one of the 30 relation classes.

The input is the sentence with each entity wrapped in tokens that mark its entity type, and the label is used as the target.

Why add type tokens?

After passing through all the transformer layers, the added tokens end up carrying a combined representation of the sequence, so attaching a simple classifier on top makes single-sentence classification straightforward.

In other words, the type tokens are added so that the marked entities carry a special meaning when the sentence is classified.
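
Conceptually, the classifier on top of BERT can be as small as one linear layer over the pooled sentence representation. The sketch below only illustrates that idea with hypothetical names; the actual fine-tuning later in this notebook uses AutoModelForSequenceClassification, which already bundles such a head.

import torch.nn as nn

class SimpleREClassifier(nn.Module):
    # Illustrative only: a BertModel (assumed) plus a single linear head over the pooled output.
    def __init__(self, bert, num_labels=30, hidden_size=768):
        super().__init__()
        self.bert = bert
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        pooled = outputs.pooler_output    # [CLS]-based sentence representation
        return self.classifier(pooled)    # logits over the 30 relation classes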

03 Data Processing

This part covers the KLUE-RE pre-processing: adding the marker tokens and encoding the data with the tokenizer.

Adding marker tokens

To mark the two related words in each sentence, the corresponding entity types are inserted as special tokens.

def add_Token(dataset):
    sentences = []
    labels = []

    for data in dataset:
        sentence = data['sentence']

        object_start =  int(data['object_entity']['start_idx'])
        object_end =  int(data['object_entity']['end_idx'])
        subject_start =  int(data['subject_entity']['start_idx'])
        subject_end =  int(data['subject_entity']['end_idx'])
        otype = data['object_entity']['type']
        stype = data['subject_entity']['type']

        if object_start < subject_start:
            new_sentence = sentence[:object_start] + '<O-' + str(otype) + '>' + sentence[object_start:object_end+1] + '</O-' + str(otype)+'>' + sentence[object_end+1:subject_start] + '<S-'+str(stype)+'>' + sentence[subject_start:subject_end+1] + '</S-'+str(stype)+'>' + sentence[subject_end+1:]
        else:
            new_sentence = sentence[:subject_start] + '<S-'+str(stype)+'>' + sentence[subject_start:subject_end+1] + '</S-'+str(stype)+'>' + sentence[subject_end+1:object_start] + '<O-'+str(otype)+'>' + sentence[object_start:object_end+1] + '</O-'+str(otype)+'>' + sentence[object_end+1:]

        # store the marked sentence
        sentences.append(new_sentence)

        # store the label
        labels.append(data['label'])

    return sentences, labels
# Keep only the (sentence, label) pairs from the train and validation sets.
train_sentences, train_labels = add_Token(dataset['train'])
val_sentences, val_labels = add_Token(dataset['validation'])
# Check the added tokens.
for sentence in train_sentences[:5]:
    print(sentence, '\n')
〈Something〉는 <O-PER>조지 해리슨</O-PER>이 쓰고 <S-ORG>비틀즈</S-ORG>가 1969년 앨범 《Abbey Road》에 담은 노래다. 

호남이 기반인 바른미래당·<O-ORG>대안신당</O-ORG>·<S-ORG>민주평화당</S-ORG>이 우여곡절 끝에 합당해 민생당(가칭)으로 재탄생한다. 

K리그2에서 성적 1위를 달리고 있는 <S-ORG>광주FC</S-ORG>는 지난 26일 <O-ORG>한국프로축구연맹</O-ORG>으로부터 관중 유치 성과와 마케팅 성과를 인정받아 ‘풀 스타디움상’과 ‘플러스 스타디움상’을 수상했다. 

균일가 생활용품점 (주)<S-ORG>아성다이소</S-ORG>(대표 <O-PER>박정부</O-PER>)는 코로나19 바이러스로 어려움을 겪고 있는 대구광역시에 행복박스를 전달했다고 10일 밝혔다. 

<O-DAT>1967</O-DAT>년 프로 야구 드래프트 1순위로 <S-ORG>요미우리 자이언츠</S-ORG>에게 입단하면서 등번호는 8번으로 배정되었다. 

Tokenizer load

For pre-processing, the data has to be encoded with a tokenizer. The tokenizer module in the transformers library converts text into the format the model expects as input.

In this notebook we use AutoTokenizer.from_pretrained to encode the data with the tokenizer used by the KLUE BERT-base pretrained model.

model_name = 'klue/bert-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
/usr/local/lib/python3.7/dist-packages/transformers/configuration_utils.py:337: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "

Viewing a tokenization result

ex_encoding = tokenizer(dataset['train'][0]['sentence'],
                        max_length=128,
                        padding='max_length',
                        truncation=True)
ex_encoding
{'input_ids': [2, 168, 30985, 14451, 7088, 4586, 169, 793, 8373, 14113, 2234, 2052, 1363, 2088, 29830, 2116, 14879, 2440, 6711, 170, 21406, 26713, 2076, 25145, 5749, 171, 1421, 818, 2073, 4388, 2062, 18, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

Registering the marker tokens with the tokenizer

The subject/object marker tokens inserted into the sentences must be registered with the tokenizer; otherwise they are treated as ordinary text and split into sub-word pieces instead of being kept as single tokens.

  1. Collect every entity type that appears

# train
types = []
for i in range(len(dataset['train'])):
    if dataset['train'][i]['subject_entity']['type'] not in types:
        types.append(dataset['train'][i]['subject_entity']['type'])

    if dataset['train'][i]['object_entity']['type'] not in types:
        types.append(dataset['train'][i]['object_entity']['type'])

# validation
for i in range(len(dataset['validation'])):
    if dataset['validation'][i]['subject_entity']['type'] not in types:
        types.append(dataset['validation'][i]['subject_entity']['type'])

    if dataset['validation'][i]['object_entity']['type'] not in types:
        types.append(dataset['validation'][i]['object_entity']['type'])

# build subject/object variants of every entity type
entity_types = []

for i in types:
    entity_types.append('S-' + str(i))
    entity_types.append('O-' + str(i))

entity_types
['S-ORG',
 'O-ORG',
 'S-PER',
 'O-PER',
 'S-DAT',
 'O-DAT',
 'S-LOC',
 'O-LOC',
 'S-POH',
 'O-POH',
 'S-NOH',
 'O-NOH']

entity type : ORG, PER, DAT, LOC, POH, NOH

Pairing each type with subject and object gives 12 type tokens in total.

  2. Register the type tokens

new_enrollment_tokens = {'additional_special_tokens': entity_types}
enrollment_tokens = tokenizer.add_special_tokens(new_enrollment_tokens)  # returns the number of tokens actually added
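
To confirm that the registered tokens are kept intact rather than being split into sub-word pieces, a quick check like the following can be used. This is only a sanity-check sketch; the expected pieces noted in the comments are illustrative.

# Sanity check (illustrative): added special tokens should survive tokenization as single tokens.
print(enrollment_tokens)  # number of tokens added: 12 if none of them existed before
print(tokenizer.tokenize('<S-ORG>비틀즈</S-ORG>가 부른 노래'))
# '<S-ORG>' and '</S-ORG>' should each appear as one token in the output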

04 Data Loader

Build the dataset and data loader used for training.

Training settings

#dataLoader 
batch_size = 8

# model
num_labels = 30

# train
learning_rate = 1e-5
weight_decay = 0.0

Explanation of the settings

  • batch_size : batch size used when building the data loaders

  • num_labels : number of labels, required when loading the model

  • learning_rate : learning rate used during training

  • weight_decay : penalty term applied when weights grow large

Building the training dataset and data loaders

class makeDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, sentences, labels, max_length=128):
        # tokenize every sentence up front, padded/truncated to a fixed length
        self.encodings = tokenizer(sentences,
                                   max_length=max_length,
                                   padding='max_length',
                                   truncation=True)
        self.labels = labels

    def __getitem__(self, idx):
        # return the encoded inputs of one example as tensors, plus its label
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item
    
    def __len__(self):
        return len(self.labels)

train_dataset = makeDataset(tokenizer, train_sentences, train_labels)
val_dataset = makeDataset(tokenizer, val_sentences, val_labels)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
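
Before training, it can help to pull a single batch from the loader and check the tensor shapes. This is an optional sanity check; the shapes in the comments assume batch_size=8 and max_length=128.

# Sanity check (optional): inspect one batch from the train loader.
batch = next(iter(train_loader))
print(batch['input_ids'].shape)       # expected: torch.Size([8, 128])
print(batch['attention_mask'].shape)  # expected: torch.Size([8, 128])
print(batch['labels'].shape)          # expected: torch.Size([8])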

05 Fine-tuning

Fine-tune a pretrained model on the KLUE-RE task.

Model load

We download the pretrained model and fine-tune it. Since RE is a classification task, we use the AutoModelForSequenceClassification class, which requires setting the number of labels.

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)
/usr/local/lib/python3.7/dist-packages/transformers/configuration_utils.py:337: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
model
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1)-(11): eleven further BertLayer blocks, identical in structure to layer (0) above
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=30, bias=True)
)

Embedding layer resize

BERT has an embedding layer that maps each token id to an embedding vector.

The loaded embedding layer, however, knows nothing about the newly added type tokens, so passing the added tokens as input would raise an index error.

We therefore resize BERT's embedding layer so that its input dimension grows from the original vocabulary size of 32,000 to the new tokenizer length, which now includes the 12 added type tokens.

model.resize_token_embeddings(len(tokenizer))
Embedding(32016, 768)
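
After the resize, the embedding matrix should cover every id the tokenizer can produce. A quick check (the exact numbers depend on the tokenizer, as the output above shows):

# Verify that the embedding matrix now covers every tokenizer id.
print(len(tokenizer), model.get_input_embeddings().num_embeddings)
# both values should match; with a smaller embedding, the added tokens would trigger an index error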

To keep the computation and book-keeping of loss and accuracy during training simple, we use an AverageMeter class.

class AverageMeter():
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
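
A tiny usage example (with made-up values) shows what the meter tracks: val holds the latest value, while avg is the running average weighted by n.

# Illustrative usage of AverageMeter with made-up values.
meter = AverageMeter()
meter.update(0.9, n=8)   # e.g. a batch loss of 0.9 over 8 examples
meter.update(0.5, n=8)
print(meter.avg)         # 0.7 = (0.9 * 8 + 0.5 * 8) / 16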

model fine-tuning

Fine-tune the BERT-base model.

def train_model(data_loader, model, criterion, optimizer, train=True):
    loss_save = AverageMeter()
    acc_save = AverageMeter()
    
    # iterate over the batches with a progress bar
    for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'token_type_ids': batch['token_type_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device),
        }
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = model(**inputs)
        logits = outputs['logits']
        
        loss = criterion(logits, labels)

        if train:
            loss.backward()
            optimizer.step()
        
        preds = torch.argmax(logits, dim=1)
        acc = ((preds == labels).sum().item() / labels.shape[0])
        
        loss_save.update(loss.item(), labels.shape[0])  # store the scalar loss, not the tensor
        acc_save.update(acc, labels.shape[0])
        
    results = {
        'loss': loss_save.avg,
        'acc': acc_save.avg,
    }
    
    return results
        

epochs = 1

# set the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(epochs):
    print(f'< Epoch {epoch+1} / {epochs} >')
    
    # Train
    model.train()
    train_results = train_model(train_loader, model, criterion, optimizer)
    train_loss, train_acc = train_results['loss'], train_results['acc']
    
    # Validation
    with torch.no_grad():
        model.eval()
        
        val_results = train_model(val_loader, model, criterion, optimizer, False)
        val_loss, val_acc = val_results['loss'], val_results['acc']
    
    
    print(f'train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}')
    print('=' * 100)
< Epoch 1 / 1 >
100%|██████████| 4059/4059 [27:03<00:00,  2.50it/s]
100%|██████████| 971/971 [02:04<00:00,  7.81it/s]
train_loss: 0.5199, train_acc: 0.8220, val_loss: 0.7974, val_acc: 0.7378
====================================================================================================

06 Test

Let's check some results.

model.eval()
with torch.no_grad():
    for i in range(20):
        val = tokenizer(val_sentences[i], max_length=128, padding='max_length', truncation=True, return_tensors='pt')
        val_input = {
            'input_ids': val['input_ids'].to(device),
            'token_type_ids': val['token_type_ids'].to(device),
            'attention_mask': val['attention_mask'].to(device),
        }
        output = model(**val_input)
        label = torch.argmax(output['logits'], dim=1)
        print('label : ' + str(val_labels[i]) + ', prediction : ' + str(label))
label : 0, prediction : tensor([0], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 0, prediction : tensor([10], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 18, prediction : tensor([18], device='cuda:0')
label : 17, prediction : tensor([17], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 10, prediction : tensor([10], device='cuda:0')
label : 10, prediction : tensor([0], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 6, prediction : tensor([0], device='cuda:0')
label : 3, prediction : tensor([3], device='cuda:0')
label : 8, prediction : tensor([8], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 29, prediction : tensor([29], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 0, prediction : tensor([0], device='cuda:0')
label : 18, prediction : tensor([18], device='cuda:0')

Comparing the first 20 validation examples, all but 3 of the predictions matched the labels.

Evaluating the model with the F1 score

def calc_f1_score(preds, labels):
    # micro-averaged F1 over all 30 classes
    preds_relation = []
    labels_relation = []
    
    for pred, label in zip(preds, labels):
        preds_relation.append(pred)
        labels_relation.append(label)

    f1_score = sklearn.metrics.f1_score(labels_relation, preds_relation, average='micro', zero_division=1)
    
    return f1_score * 100
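
Note that calc_f1_score above keeps every class in the micro average. The official KLUE-RE evaluation instead reports micro F1 with the no_relation class excluded; a hedged sketch of that variant is shown below, assuming no_relation corresponds to label 0 in this label scheme.

# Hedged sketch: micro F1 over the relation classes only, ignoring label 0 (assumed to be no_relation).
def calc_f1_score_no_relation_excluded(preds, labels):
    relation_label_ids = list(range(1, 30))
    f1 = sklearn.metrics.f1_score(labels, preds, labels=relation_label_ids,
                                  average='micro', zero_division=1)
    return f1 * 100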
with torch.no_grad():
    model.eval()
    
    label_all = []
    pred_all = []
    for batch in tqdm(val_loader):
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'token_type_ids': batch['token_type_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device),
        }
        labels = batch['labels'].to(device)
        
        outputs = model(**inputs)
        logits = outputs['logits']
        
        preds = torch.argmax(logits, dim=1)
        
        label_all.extend(labels.detach().cpu().numpy().tolist())
        pred_all.extend(preds.detach().cpu().numpy().tolist())
    
    f1_score = calc_f1_score(pred_all, label_all)
100%|██████████| 971/971 [02:00<00:00,  8.09it/s]

F1 score : 73.7797

f1_score
73.77978106889891