본문 바로가기

데이터사이언스/머신러닝

[강화학습] CartPole에서 Actor-Critic 구현하기

 

1. Actor-Critic

Actor Critic은 Deep RL에서 정책함수와 가치함수를 모두 학습하는 방식이다. 이 방식은 확률적 정책을 취하기 때문에 결정론적 방식을 사용하는 가치기반 방식에 비해 변화하는 주변 상황에 적용하기 용이하다. 또한 액션 공간이 연속적인 경우에도 적용할 수 있다는 장점이 있다.

 

2. import 및 hyperparameter 정의

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

#Hyperparameters
learning_rate = 0.0002
gamma         = 0.98
n_rollout     = 10 # 몇번의 step 마다 데이터를 업데이트 할지

3. main 함수

Actor Critic 모델을 먼저 선언한다. 그리고 정책함수 pi에서 모델의 행동을 고르고 해당 상황의 데이터를 모델에 입력해 학습을 진행한다.

def main():  
    env = gym.make('CartPole-v1')
    model = ActorCritic()    
    print_interval = 20
    score = 0.0

    for n_epi in range(10000):
        done = False
        s = env.reset()
        while not done:
            for t in range(n_rollout):
                prob = model.pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample().item()
                s_prime, r, done, info = env.step(a)
                model.put_data((s,a,r,s_prime,done))
                
                s = s_prime
                score += r
                
                if done:
                    break                     
            
            model.train_net()
            
        if n_epi%print_interval==0 and n_epi!=0:
            print("# of episode :{}, avg score : {:.1f}".format(n_epi, score/print_interval))
            score = 0.0
    env.close()

if __name__ == '__main__':
    main()

4. Actor-Critic class

 이 클래스에서는 먼저 __init__을 통해 학습을 진행할 신경망을 정의한다. 다음으로 pi를 통해 현재 상황이 x로 들어오면, 각 action의 가능성을 출력하는 prob를 리턴하는 정책함수를 정의하였다. 또한 v를 통해 현재 상황이 x로 들어오면, 현재 상황에 대한 가치함수를 출력한다.

 

class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.data = []
        
        self.fc1 = nn.Linear(4,256)
        self.fc_pi = nn.Linear(256,2)
        self.fc_v = nn.Linear(256,1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        
    def pi(self, x, softmax_dim = 0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob
    
    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v
    
    def put_data(self, transition):
        self.data.append(transition)
        
    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []
        for transition in self.data:
            s,a,r,s_prime,done = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r/100.0])
            s_prime_lst.append(s_prime)
            done_mask = 0.0 if done else 1.0
            done_lst.append([done_mask])
        
        s_batch, a_batch, r_batch, s_prime_batch, done_batch = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
                                                               torch.tensor(r_lst, dtype=torch.float), torch.tensor(s_prime_lst, dtype=torch.float), \
                                                               torch.tensor(done_lst, dtype=torch.float)
        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch

 

5. train_net

 Actor Critc의 train_net()에서는 실제 agent의 학습을 진행한다. 먼저 다음 시점 s_prime의 state vlaue값에서 지금 현재의 state value 값의 차이인 delta를 구한다. 이제 pi 함수를 호출해 주어진 상황에서 각 액션의 확률 pi를 구한다. 이제 batch에서 뽑았던 action의 확률을 구한다. 이제 정책함수와 가치함수의 loss function을 한번에 구한다.loss 함수의 앞에 있는 정책함수의 loss는 -(각 행동을 할 값의 로그 값) * (행동으로 인해 변할 가치)로 정의하여, 높은 가치를 가지는 행동을 최대화하도록 한다. delta는 상수값으로 고정해두기 위해 detach()를 사용했다. value funtion의 loss는  (다음 시점의 가치함수인 td_target)  - (현재 상태의 가치함수)으로 설정하였다. td_target은 detach()로 상수로 설정하였고, 이를 통해서 현재 상태의 가치를 최대화하여 다음 상태의 가치와 차이를 줄이는 방향으로 학습을 진행한다.

def train_net(self):
    s, a, r, s_prime, done = self.make_batch()
    td_target = r + gamma * self.v(s_prime) * done
    delta = td_target - self.v(s)

    pi = self.pi(s, softmax_dim=1)
    pi_a = pi.gather(1,a)
    loss = -torch.log(pi_a) * delta.detach() + F.smooth_l1_loss(self.v(s), td_target.detach())

    self.optimizer.zero_grad()
    loss.mean().backward()
    self.optimizer.step()

 

 

 

https://github.com/seungeunrho/RLfrombasics

 

GitHub - seungeunrho/RLfrombasics: provides all the codes from the book "RL from basics(바닥부터 배우는 강화학습)"

provides all the codes from the book "RL from basics(바닥부터 배우는 강화학습)" - GitHub - seungeunrho/RLfrombasics: provides all the codes from the book "RL from basics(바닥부터 배우는 강화학습)"

github.com

이 책의 내용을 참고하여 작성했습니다.