본문 바로가기

데이터사이언스/머신러닝

BERT와 pytorch를 사용한 binary classification

colab 환경에서 BERT와 pytorch를 사용해서 text를 binary classification을 했다.

 

1. BERT의 기본 구성

https://ebbnflow.tistory.com/151

 

[BERT] BERT에 대해 쉽게 알아보기1 - BERT는 무엇인가, 동작 구조

● 언어모델 BERT BERT : Pre-training of Deep Bidirectional Trnasformers for Language Understanding 구글에서 개발한 NLP(자연어처리) 사전 훈련 기술이며, 특정 분야에 국한된 기술이 아니라 모든 자연어..

ebbnflow.tistory.com

 

 

BERT란 wikipedia와 BookCorpus의 데이터를 사전학습한 transformer 언어모델이다. BERT의 input은 다음 세가지로 구성된다.

1) Token Embedding

기본적으로 글자 단위로 임베딩을 하고, 자주 등장하는 글자는 합쳐 하나의 단위로 만든다. 

2) Segment Embedding

BERT에서는 두개의 문장을 하나의 Segment로 다루는데, 그 중에서 어던 문장에 포함되는지 나타낸다.

3) Position Embedding

전체  Input에서 해당 token이 해당하는 위치를 나타낸다.

 

이런 Input을 가지고, 임의의 토큰을 버리고 해당 토큰을 맞추는 MLM, 두 문장이 주어졌을 때 다음 문장을 예측하는 NSP 방식으로 사전학습을 진행한 모델이 BERT다. 여기에 자신의 데이터가 특정 분야에 특화되어 있다면 추가적으로 학습을 시켜 더 나은 성능을 보일 수 있다.

 

2. 라이브러리 import 및 데이터 읽어오기

!pip3 install libauc==1.2.0
!pip3 install pytorch-transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from google.colab import drive

import random
import numpy as np
import pandas as pd
import os
import timm
import re
from libauc.losses import AUCMLoss
from libauc.optimizers import PESG

from sklearn.metrics import roc_auc_score
from torchmetrics import AUROC
from tqdm import tqdm
from sklearn.model_selection import train_test_split

from transformers import AutoTokenizer, AutoModel
from pytorch_transformers import BertTokenizer, BertForSequenceClassification, BertConfig
drive.mount('/content/drive')
records = pd.read_csv('/content/drive/MyDrive/kium/TrainSet _1차.csv')
records = records.iloc[:, :].fillna('')
records.head()

3. dataloader 정의 & 토큰화

# hyperparameter 정의
BATCH_SIZE = 8
lr = 1e-5
EPOCHS = 4
NUM_CLASSES = 2

# 학습 DEVICE 정의
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("cpu와 cuda 중 다음 기기로 학습함:", DEVICE)

train_df, valid_df = train_test_split(records, test_size=0.2, stratify=records.AcuteInfarction)
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

class KiDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.padded_value = []
        for idx, row in df.iterrows():
            texts = ' '.join([row[0], row[1]])
            #토큰화 
            encoded_value = tokenizer.encode(texts, add_special_tokens=True)[:512]
            np_encoded = np.array(encoded_value)
            if np_encoded.ndim == 1:
                padded_list = torch.tensor(encoded_value + [0] * (512-len(encoded_value)))
            else:
              padded_list = torch.tensor(e + [0] * (512-len(e)) for e in encoded_value)
            self.padded_value.append(padded_list)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        # text = torch.tensor(self.padded_value[idx])
        text = self.padded_value[idx]
        label = torch.tensor(self.df.iloc[idx, 2])
        return text, label
        
#데이터 로더 정의
train_dataset = KiDataset(train_df)
valid_dataset = KiDataset(valid_df)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

 

 

4. 모델 학습 및 AUC를 통한 평가

# training
import collections

print('TRAIN START')

best_valid_auc = 0 
fprs = []
tprs = []
thresholds = []

for epoch in range(EPOCHS):
    for idx, (train_data, train_labels) in tqdm(enumerate(train_loader)):
        model.train()
        train_data, train_labels  = train_data.to(DEVICE), train_labels.to(DEVICE)
        outputs = model(train_data, labels=train_labels)
        
        pred = F.softmax(outputs[1])[:, 1]
        loss = criterion(pred, train_labels)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if idx == 10:
          print("idx 10")

        # validation  
        if idx % 100 == 0 and idx != 0:
          model.eval()
          with torch.no_grad():    
            valid_total_pred = []
            valid_total_labels = [] 
            for jdx, (valid_data, valid_labels) in enumerate(valid_loader):
                valid_data, valid_labels = valid_data.to(DEVICE), valid_labels.to(DEVICE)
                outputs = model(valid_data, labels=valid_labels)
                #pred = torch.argmax(F.softmax(outputs[1]), dim=1)
                pred = F.softmax(outputs[1])[:, 1]
                valid_total_pred.append(pred.cpu().detach().numpy())
                valid_total_labels.append(valid_labels.cpu().numpy())

            valid_total_labels = np.concatenate(valid_total_labels)
            valid_total_pred = np.concatenate(valid_total_pred)

                
            print('valid_total_labels')
            print(valid_total_labels)
            collections.Counter(valid_total_labels)

            print('valid_total_pred')
            print(valid_total_pred)

            try:
              fpr, tpr, threshold = roc_curve(valid_total_labels, valid_total_pred)
              fprs.append(fpr)
              tprs.append(tpr)
              thresholds.append(threshold)
              auc = roc_auc_score(valid_total_labels, valid_total_pred)
              valid_auc_mean = np.mean(auc) 
            except ValueError:
              valid_auc_mean = 0

            if best_valid_auc < valid_auc_mean:
                best_valid_auc = valid_auc_mean
                torch.save(model.state_dict(), f'pretrained_model_epoch_{epoch}.pth')

            print(f'Epoch: {epoch}\tIter: {idx}\tValid AUC: {valid_auc_mean:.3f}\tBest_Valid_AUC: {best_valid_auc:.3f}')

print('TRAIN END')