[arXiv 2013] Playing Atari with Deep Reinforcement Learning
Paper url: https://arxiv.org/pdf/1312.5602.pdf
Last updated
Paper url: https://arxiv.org/pdf/1312.5602.pdf
Last updated
Q-learning 사례에서 언급하였듯이 고차원의 state-action space에서 table 기반의 강화학습은 적용할 수 없거나 많은 양의 memory resource를 필요로하기 때문에 적합하지 않다. 이를 위해 실제 q-value를 도출 후 table에 저장하는 방식을 사용하지 않고 특정 함수(e.g. linear combination, decision tree, support vector machine)를 이용하여 근사하는 방법이 등장하였다. 다양한 함수를 통해서 q-value를 근사할 수 있으나 본 논문이 발표되는 시점과 맞물려 여러 task (특히 computer vision)에서 두각을 보이던 neural network이 주로 사용되며 이는 심층 강화학습(deep reinforcement learning)의 개념을 널리 알리게 되었다.
본 논문은 2013년 arXiv preprint 후 세부 내용을 추가하여 2015년 Nature에 게재 [1]된 바 있다. 논문에서 제안된 deep q-network (DQN)은 가치 기반 강화학습(value-based reinforcement learning)의 대표적인 알고리즘으로 자리 잡고 있다. Figure 02는 Atari-Breakout을 플레이하는 DQN agent의 모습을 볼 수 있으며 약 400번의 학습 횟수(episode) 후에는 사람이 하는 것과 유사한 성능을 보인다. 더 놀라운 점은 600번의 학습을 가진 agent는 사람도 생각하기 힘든 전략을 터득하는데, 한쪽의 벽돌을 뚫은 후 공을 그 위로 보내서 최소한의 움직임으로 많은 벽돌을 깨는 방법이 그것이다.
그렇다면 이러한 agent는 어떻게 만들 수 있을까? DQN을 설명하기에 앞서 neural network를 사용하여 q-value를 근사하는 가장 기초적인 알고리즘을 알아보자.
DQN은 강화학습에 neural network를 접목시킨 최초의 알고리즘이 아니다. 기존에 neural network 기반의 여러 시도들이 있었고 강화학습에는 신경망이 적합하지 않다는 의견이 다수였다. 기존의 방법은 어떤 문제점이 있었을까? 우선 q-value를 근사하는 network는 state를 input으로 받고 action에 대한 q-value를 output으로 삼는다. 그림으로 표현하면 Figure 03과 같다.
Neural network를 학습하기 위해서는 모델의 예측값과 정답이 필요하다. 예측값은 input state를 feed forward하면 구할 수 있고 정답은 무엇이 될까? 복잡하게 생각할 필요가 없다. Q-network algorithm은 앞에서 배운 q-learning 알고리즘을 neural network로 옮기는 과정에 불과하다 것을 잊으면 안된다. Q-learning에서 q-value를 갱신할 때 다음 식을 사용하였다.
Q-target에 대한 notation을 확인해보면, next state가 종료 상태(terminal state)일 경우는 보상을 target으로, non-terminal state라면 q-learning과 동일하게 보상 + 감가된 다음 상태의 maximum q-value를 target으로 삼는 것을 확인할 수 있다. 한 가지 다른 점은 q-target의 maximum q-value 또한 q-network로 추정된 값이라는 사실이다.
Q-network 알고리즘은 q-target과 state를 feed forward 했을 때 출력되는 예측값 사이의 차이를 줄이는 방향으로 학습을 진행한다. 이는 일반적인 regression task를 풀 때와 동일하게 mean squared error (MSE)를 줄이는 것과 동일하다. Q-network 알고리즘의 목적함수를 적으면 다음과 같다.
현재 시점의 Q-value와 다음 시점의 q-target이 neural network의 output이라는 사실을 제외하면 q-learning agent와 100% 동일하다. Q-network 알고리즘의 pseudo code는 다음과 같다.
기존 Q-learning 알고리즘에서 q-value를 신경망으로 근사하는 부분만 변경된 이 알고리즘은 어떠한 단점이 있을까? Q-network 알고리즘의 단점이자, DQN 이전의 neural network 기반의 강화학습 단점은 크게 3가지로 1) 부족한 학습 데이터 및 낮은 데이터 효율성, 2) high correlation between samples, 3) non-stationary target problem을 들 수 있다.
부족한 학습 데이터 및 낮은 데이터 효율성
일반적으로 neural network는 학습 데이터가 많은 supervised, unsupervised learning에 효율적인 방법이다. 강화학습에서 시간 차 학습(temporal difference learning, TD)의 경우 한번 update에 사용한 transition은 다시 사용하지 않는다. 이 때문에 neural network를 학습시키기에 충분한 데이터를 확보하기 힘들며 과거의 좋은 transition이 휘발되는 단점이 있다.
High correlation between samples
강화학습의 학습 데이터에 해당하는 transition은 time-step 간 correlation이 굉장이 높다. 이는 직전 transition에서 action에 의해 현재 및 다음 transition이 결정되기 때문이다. 일상생활에서 저녁 메뉴를 고르는 데 있어 점심, 혹은 어제 저녁에 먹었던 메뉴가 영향을 주는 것과 유사하다. Sample 간 correlation이 높을 경우 network가 거쳐온 transition 양에 비해 유의미한 학습량이 많지 않거니와(유사한 sample, 비슷한 action, 낮은 error), 초기 transition에서 선택한 action에 의해 앞으로의 sample들이 종속되는 현상이 발생한다. 따라서 global optimum에 도달하지 못하고 local minimum에 빠질 위험이 발생한다.
Non-stationary target problem
매 time-step 마다 q-network를 update하는 상황을 생각해보자. 앞서 q-network 알고리즘은 동일한 network를 이용하여 q-target과 q-value를 근사한다고 설명하였다. 우리는 최적의 q-value를 근사하기 위해서 q-network를 갱신하게 되는데, 동일한 network를 사용하는 알고리즘의 특성상 network 갱신 주기에 따라서 q-target의 근사값 또한 변하게 된다. 즉 업데이트 대상이 되는 q-target이 고정되어 있지 않고 갱신할 때마다 값이 달라지는 현상이 발생한다. 이를 non-stationary target problem이라 칭하며 흔히 움직이는 과녁에 화살을 맞추는 것으로 비유한다.
위 3가지 문제는 비단 q-network 알고리즘의 한계 뿐만 아니라 강화학습에서 neural network의 한계로 받아들여지고 있었다. DQN은 비교적 간단한 idea를 통해 이를 해결하였는데 다음 절을 통해서 알아보도록 하자.
RL-background의 취지에 맞추어 세부적인 preprocessing 과정은 생략하고 3가지 단점을 어떻게 개선하였는지 중점적으로 알아보자.
DQN은 앞서 언급한 3가지 단점 중 1) 부족한 학습 데이터와 낮은 데이터 효율성, 2) high correlation between samples 문제를 replay buffer의 도입으로 해결하였다. 이를 experience replay라고 부르는데, agent가 매 time-step 마다 획득한 transition <state, action, reward, next state> tuple을 buffer에 저장 후 꺼내쓰는 방식이다. Replay buffer는 first in first out 형태로 기록되며 사용자가 설정한 학습 주기마다 buffer에 저장된 transition을 random sampling하여 network를 학습한다. 도식화 하면 Figure 05와 같다.
Figure 05에서 큰 화살표의 경우 environment와 상호작용하는 path를 뜻하고 실선 화살표의 경우 CNN 학습에 사용되는 transition path를 뜻한다.
Environment와 상호작용하는 transition들을 replay buffer에 저장하는 것을 통해 부족한 학습 데이터 양을 보충할 수 있다. 또한 CNN을 학습할 때 buffer로 부터 random sampled transition을 활용함으로써 sample간 높은 correlation 문제를 해결하였다.
이와 같이 두 개의 network를 두고 각각 q-target과 q-value를 근사함으로써 non-stationary target problem을 해결할 수 있다.
정리하면 DQN은 replay buffer의 도입(experience replay)과 q-value를 추정하는 network와 독립된 q-target network를 사용함으로써 기존 q-network 알고리즘의 단점을 개선할 수 있었다. 복잡한 문제를 비교적 간단한 방법으로 해결했다는 생각이 들지 않는가? DQN의 pseudo code는 다음과 같다.
Q-network pseudo code와 달리 target을 추정하는 network와 q-value를 근사하는 network가 다른 것을 확인할 수 있으며 일정 주기마다 target network로 가중치를 복사하는 모습을 볼 수 있다.
DQN은 Atari 환경의 대부분의 게임에서 사람의 성능을 넘거나 그에 준하는 성능을 보였다. Figure 08은 2015 Nature에 게재된 논문에 기록된 성능 도표이다.
앞서 DQN은 experience replay와 fixed q-target을 통해 높은 성능을 달성할 수 있었다고 밝혔다. 그렇다면 이 두 가지 방법의 contribution은 어느정도 될까? 이에 대한 ablation study는 Figure 09와 같다.
제일 우측이 기존의 q-network algorithm에 해당한다. Breakout을 기준으로 fixed q-target은 q-network algorithm 대비 약 3배의 성능 향상을, experience replay는 약 80배의 성능 향상을 보여준다. 마지막으로 이 둘을 모두 합한 DQN은 약 100배의 성능 향상을 보여준다.
DQN은 앞서 언급하였듯 복잡한 문제를 비교적 간단한 방법으로 해결한 알고리즘이다. 또한 deep reinforce-ment learning (심층 강화학습)의 포문을 열어젖힌 알고리즘으로 기여하는 바가 커 Nature에도 등재되었다. DQN이 처음 나온 지 8년여의 시간이 흐른 만큼 이를 개선한 알고리즘이 많이 등장하였다.
여기까지 읽은 독자들은 과연 DQN의 어떤 점을 더 개선할 수 있을것으로 보이는가? 다음 장으로 넘어가기 전에 잠깐의 생각을 가져보면 좋을 것 같다.
[1] Mnih, V., Kavukcuoglu, K., Silver, D., Rusu, A. A., Veness, J., Bellemare, M. G., ... & Hassabis, D. (2015). Human-level control through deep reinforcement learning. nature, 518(7540), 529-533.
위 식에서 갱신의 대상(q-target)은 이며 q-network algorithm도 이와 동일하다. Q-network 알고리즘의 학습 대상(q-target)의 notation은 다음과 같다.
Q-network agent 또한 의 확률로 random action을, 의 확률로 maximum q-value 행동을 취하는 greedy policy를 따른다는 것을 알 수 있다.
Non-stationary target problem은 q-target과 q-value를 동일한 network로 추정했기 때문에 발생하였다. 저자들은 이를 분리하여 각각을 따로 추정하도록 두 개의 network를 둠으로써 해결하였다. Q-target을 추정하는 network를 target network () 로 부르며 일정 주기마다 q-value를 근사하는 network ()의 가중치를 복사하여 사용한다. 만약 매 batch마다 q-value를 근사하는 network를 갱신하고, 32 batch 마다 target network로 가중치를 복사하면 Figure 06과 같은 형태를 띈다.