1. 기본 개념
GAN(Generative Adversarial Networks, 생성적 적대 신경망)이란 비지도학습에 사용되는 머신러닝 프레임워크의 한 종류이다. GAN은 다른 알고리즘과는 달리 이전에는 없던 새로운 데이터를 생성할 수 있다. 이 알고리즘은 흔히 경찰과 위조지폐범 사이의 게임에 비유된다. 위조지폐범은 진짜 같은 화폐를 만들어 경찰을 속이기 위해 노력하고, 경찰은 위조지폐를 잘 감별하기 위해 노력한다. GAN에서도 생성모델(generator)은 최대한 진짜 같은 데이터를 만들기 위한 학습을 진행하고, 분류모델(discriminator)은 진짜와 가짜를 판별하기 위한 학습을 진행한다.
GAN의 학습 과정에서는 분류모델을 먼저 학습시키고, 생성모델을 학습시킨다. 분류모델은 먼저 진짜 데이터를 진짜로 분류하도록 학습시킨다. 다음으로 생성모델이 생성한 데이터를 가짜로 분류하도록 학습시킨다. 마지막으로 학습된 분류모델을 속이는 방향으로 생성모델을 학습시킨다.
생성모델은 노이즈를 입력으로 받아 다수의 층을 통과하면서 특징 맵을 확장시켜나가는 구조로 이루어져있다. 마지막 층을 통과해서 나오는 특징 맵은 이미지 크기와 같다. 반대로 분류모델은 특징맵의 크기를 줄여나가는 구조로, 전통적인 인공신경망의 구조를 따르고 있다.
2. 손실함수
분류모델 D는 위 손실함수의 값을 최대화시켜야하고, 생성모델 G는 식의 값을 최소화시켜야한다. 위 수식에서 D(x)는 x가 모델에 입력되었을 때 분류모델이 판단한 진짜일 확률이고, 0~1의 범위로 표현된다. G(z)는 z라는 노이즈가 입력되면 이를 바탕으로 생성모델이 생성한 가짜 데이터이다. 그리고 X~Pdata(x)는 실제 데이터에서 샘플링한 데이터, Z~PZ(Z)는 정규분포를 사용하는 임의의 노이즈에서 샘플링한 데이터를 의미한다. 여기서 Z는 latent vector라고도 불리는데, 차원이 줄어든 채로 데이터의 분포를 잘 설명할 수 있는 잠재 공간에서의 벡터를 의미한다.
먼저 분류모델의 입장에서 본다면 식의 값을 최대화시키기 위해서는 D(x) = 1, D(G(z)) = 0이 되어야 한다. 결국 생성모델이 만들어낸 데이터를 가짜로, 진짜 데이터를 진짜로 판별해야하는 것이다. 다음으로 생성모델의 입장에서 보면 식의 값을 최소화시키기 위해서는 D(G(z)) = 1이 되어야 한다. 분류모델이 진짜로 판단할만한 데이터를 만들어야 하는 것이다.
3. DCGAN(Deep Convolutional Generative Adverial Networks)
GAN이 2014년 처음 발표된 이후로 학습이 불안정하다는 문제가 제기되었다. 이를 해결하기 위해 2016년 구글에서 발표한 모델로, 현재 개발되고 있는 GAN 구조의 기초가 되는 구조이다. 기존의 GAN에서 fully-connected로 구성 되어 있었던 생성모델과 분류모델을 convolution으로 대체 구성하여 성능과 안전성을 높인 구조를 가진다. 또한 배치 정규화 기법을 사용해서 더욱 빠른 학습을 가능하게 한다.
특이한 점은 생성모델에서 사용하는 fractional strided convolution이다. 이 방법은 input 사이에 padding을 더하고 convoluton을 하면서 위 그림과 같이 크기가 오히려 더 커진다. 생성모델은100차원의 latent vector를 입력받으면 64x64의 이미지를 반환한다.
DCGAN은 위 그림처럼 오른쪽과 왼쪽을 보는 데이터를 입력해서 그 중간을 바라보는 결과물도 만들어 낼 수 있다. 이는 왼쪽과 오른쪽을 바라보는 얼굴을 만들어내는 latent vector의 평균을 계산해서 그 중간의 값들을 입력해서 얻은 결과이다. 이를 통해서 생성모델이 데이터의 확률분포를 정확히 파악하는 것을 알 수 있다. 또한 안경 쓴 남자의 이미지에서 안경을 안쓴 남자의 이미지를 빼고 안경을 안 쓴 여자의 이미지를 더해서 안경을 쓴 여자의 이미지를 만드는 등 이미지 연산이 가능하다는 것을 보여준다. 이미지 생성의 결과물을 자유롭게 조작할 수 있는 가능성을 확인한 것이다.
'데이터사이언스 > 머신러닝' 카테고리의 다른 글
[강화학습] CartPole에서 Actor-Critic 구현하기 (0) | 2022.11.02 |
---|---|
BERT와 pytorch를 사용한 binary classification (0) | 2022.10.31 |
마르코프 프로세스란? (0) | 2022.09.13 |
RNN과 LSTM 이해하기 (0) | 2022.03.09 |
Gradient Boosting Algorithm (0) | 2022.03.01 |