본문 바로가기

데이터사이언스/머신러닝

RNN과 LSTM 이해하기

1. RNN의 정의

 RNN(Recurrent Neural Networks, 순환신경망)은 음성, 문자 등 순차적 데이터 처리에 적합한 모델이다. 내부의 연결이 순환적이라는 특징을 가지고 있다.

 

2. RNN의 구조

RNN은 위 그림처럼 입력 벡터 $x$와 출력 벡터 $y$, 은닉층 cell로 구성되어 있다. 여기서 cell은 이전의 값을 기억하려고 하는 메모리 역할을 수행하므로 이를 메모리 셀이라고도 한다. 은닉층의 메모리 셀은 이전 시점 t-1에서 셀에서 나온 값을 현재 시점 t에서 자신의 입력으로 사용하는 재귀적 활동을 하고 있다. 여기서 메모리 셀이 출력층 $y$ 방향으로 보내는 값, 다음 시점의 자기 자신에게 보내는 값은 은닉 상태(hidden state)라고 한다.

 

 RNN을 표현할 때는 왼쪽처럼 재귀적 표현으로 나타내기도 하고, 오른쪽처럼 여러 시점을 펼쳐서 나타내기도 한다. 두 그림은 모두 같은 의미를 가지고 있다.

 

이제 RNN의 수식을 살펴보자. 현재 시점 t에서 은닉 상태값을 $h_{t}$라고 정의하자. 이 값은 이전 시점 t-1에서의 은닉 상태 값, 그리고 입력층 $x$에서 입력을 받아서 갱신된다. 그리고 두 가지 신호는 각각 학습을 위한 가중치를 가진다.  이전 시점 t-1의 은닉 상태 값 $h_{t-1}$은 가중치 $W_{hh}$를, 현재 시점의 입력값 $x_{t}$는 가중치 $W_{xh}$를 가진다. 각 가중치와 입력값을 곱하고, 편향을 더한 뒤 활성함수 tanh에 값을 넣으면 현재 시점의 은닉 상태 값 $h_{t}$를 구할 수 있다.

 

3. RNN의 예시

RNN은 다양한 분야에 활용될 수 있다. 위 그림처럼 출력 값이 여러개인 모델은 객체명을 인식하거나 번역에 사용될 수 있다.

위 그림은 출력값이 없고 마지막 층의 은닉 상태값을 출력으로 활용하는 경우이다.  이런 모델은 스팸메일을 분류하거나 주어진 문장이 긍정적인지 부정적인지 감정분석에 사용될 수 있다.

 

4. LSTM의 등장 배경

RNN으로 다음에 올 단어를 예측하는 모델을 생각해보자. "I grew up in France...(여러 문단).. I speak fluent French" 라는 문장이 주어졌다고 가정하자. 마지막에 올 단어를 예측하기 위해서는 한참 앞에 있는 France라는 단어가 필요하다. 하지만 일반적인 RNN에서는 입력값과 그 값을 사용하는 지점의 거리가 멀 수록 학습 능력이 크게 떨어지는 gradient vanishing problem이 존재한다.

 이런 문제를 해결하기 위해 1997년 LSTM(Long Short-Term Memory)이 등장했다. LSTM은 오랜 기간 정보를 저장할 수 있도록 설계되었다.

 

5. LSTM의 구조

LSTM은 RNN의 은닉 상태 값(hidden state)에 cell state $C_{t}$가 추가되었다.

cell state는 기존의 정보를 저장하기 위한 역할을 한다. 따라서 state가 오래 경과하더라도(input과 그 input이 사용되는 위치 사이의 거리가 길어도) 학습이 잘 진행되도록 한다. 

cell state에 영향을 주는 gate는 두 가지가 있다. 

 

1. forget gate 이전 cell state $C_{t-1}$에 여기서 입력된 0 ~ 1 사이의 값을 곱해 과거 정보를 얼마나 기억할지 결정하는 gate이다. $x_{t}$와 $h_{t-1}$을 입력 받아서 시그모이드를 취해준 값이 이 gate가 내보내는 값이 된다. 0에 가까울수록 기존 값을 많이 잊고, 1에 가까울수록 기존 값을 많이 기억한다.

 

2. input gate 현재 정보를 얼마나 기억할지 결정하는 게이트이다.  $x_{t}$와 $h_{t-1}$을 받아서 시그모이드를 취한다. 또 같은 입력으로 tanh를 취하고 두 값을 곱해준다. 구한 값을 $C_{t-1}$에 더한다.

LSTM의 output gate layer

마지막으로 hidden state에 시그모이드를 적용해 어떤 값을 내보낼지 정한다. 그리고 cell state에 tanh를 적용하고 두 값을 곱해 최종적인 출력을 구한다.