[CVPR 2019] SpotTune:Transfer Learning through Adaptive Fine-tuning

Paper url: https://arxiv.org/pdf/1811.08737.pdf

Author and affiliation

Introduction

Transfer learning이란 source task를 해결하면서 얻어진 knowledge를 target task로 전이하여 사후적으로 학습하는 개념을 말합니다. 본 논문을 이해함에 있어 필요한 transfer learning의 background knowledge는 다음과 같습니다.

  • Donahue et al.은 pre-trained AlexNet을 이용하여 image의 특성을 추출 후 SVM(support vector machine)의 input으로 활용하였습니다 [1]. Two stage object detection의 시초인 R - CNN이 이에 해당합니다.

  • Yosinski et al.은 pre-trained AlexNet을 target task에 맞게 fine-tuning하는 방법을 사용하였습니다 [2].

    앞에서 제시한 방법인 feature 추출 후 다른 classifier를 사용하는 것 보다 좋은 성능을 보인 바 있습니다.

  • Azizpour et al은 전체 parameter를 갱신하는 것은 target dataset이 작고, model capacity가 큰 경우 overfitting을 경고하였습니다 [3]. 이에 전체 layer가 아닌 특정 layer를 freeze, 소수 layer를 fine-tuning하는 방법이 제안되었습니다.

정리하면 transfer learning은 source task에서 학습한 모델을 불러와 target task의 data set에 맞게 fine-tuning을 진행하는 방식입니다. 이 때 전체 layer를 재학습 하거나 특정 layer만 재학습하는 경우도 있지만 주로 output layer와 가까운 layer들만 fine-tuning하는 방법을 사용합니다.

다만 어떤 layer를 freeze시키고 fine-tuning해야하는지는 사람이 정할 수 밖에 없어서 최적화가 힘든 문제가 남아 있겠죠. 따라서 SpotTune 저자들은 별도의 policy network를 두어 layer를 freeze or fine-tuning하는 것을 사람이 아니라 모델이 결정하게 만들었습니다.

Overview and notations of SpotTune

앞서 말했듯 SpotTune은 fine-tuning strategy를 사람이 아닌 model이 결정하는데 이를 policy network라고 칭합니다. Policy network는 예측 model의 performance를 증가시키는 방향으로 학습을 진행하며 원문은 다음과 같습니다.

Given a pre-trained network model on a source task (e.g., ImageNet pre-trained model), and a set of training examples with associated labels in the target domain, our goal is to create an adaptive fine-tuning strategy that decides, per training example, which layers of the pre-trained model should be fine-tuned (adapted to the target task) and which layers should have their parameters frozen (shared with the source task) during training, in order to improve the accuracy of the model in the target domain.

그렇다면 fine-tuning strategy는 어떻게 작성될까요? 저자들이 제안한 policy network는 target task의 instance마다 strategy를 출력합니다. Instance란 아래 그림의 경우 이미지 한 장을 의미합니다.

다음과 같이 생각하면 쉽게 이해할 수 있습니다.

  1. Source task (ImageNet)를 통해 ResNet을 학습합니다.

  2. Target task (CIFAR-10)에 pre-trained weight를 transfer하여 fine-tuning을 진행합니다. 이 때 ResNet은 trainable network와 fixed network 두 가지로 구성되며 각각은 pre-trained weight로 initialize 됩니다.

  3. Target task의 instance (CIFAR-10의 image 한 장)마다 policy network가 fine-tuning strategy를 학습 및 output으로 출력합니다.

  4. 출력된 fine-tuning strategy에 따라 layer의 freeze여부가 결정되며 freeze layer의 경우 fixed network를 통과하고 fine-tuning layer의 경우 trainable network를 통과합니다.

  5. CIFAR-10에 대해서 model accuracy가 최대화 되도록 policy network가 학습됩니다.

여기서 source task와 target task를 학습하는 model architecture는 ResNet이며 fine-tuning과 freeze의 단위인 layer는 ResNet block을 뜻합니다. 세부적으로 policy network이 어떻게 학습하는지는 후술하도록 하고 우선 notation을 알아보겠습니다.

Notation

SpotTune의 policy network는 fine-tuning block과 frozen block을 결정한다고 설명하였습니다. 저자들은 이를 쉽게 구현하기 위해 두 개의 ResNet을 사용하는데, 하나는 source task의 가중치를 학습하지 않는 original network이며 다른 하나는 fine-tuning을 진행하는 trainable network입니다. 이를 block 단위로 바라보면 policy network의 fine-tuning strategy를 받은 ResNet은 다음 식으로 표기할 수 있습니다.

Gumbel distribution은 극치분포의 일종으로 극치분포란 정규분포에서 최소 또는 최대구간에 밀집된 데이터 분포를 모사하는데 사용됩니다. 이와 자세한 내용은 후술할 Gumbel max trick을 통해서 알아보도록 하겠습니다.

SpotTune의 동작과정을 예시를 통해 확인해보겠습니다.

SpotTune forward & backward pass

모델의 forward pass 과정은 다음과 같습니다.

  1. Target task의 image 한 장을 input으로 받습니다.

  2. 이후 back propagation을 통해 target task의 accuracy가 최대화 되는 방향으로 학습합니다.

그런데 여기서 다음과 같은 의문점이 존재합니다.

다음 절을 통해서 위 의문점을 해결해보도록 하겠습니다.

Optimizing stochastic computation graph

먼저 우리가 흔히 알고 있는 neural net 구조는 Figure 05의 (a)와 같으며 검은색의 forward pass를 따라 정해진 연산(weight를 곱하고 bias를 더한 후 activation function 통과)을 수행하며 그 후 back propagation을 통해서 weight를 update를 진행합니다.

그런데 Figure 05의 (b)와 같이 중간에 stochastic node가 존재할 경우 위와 같은 방법으로 back propagati-on을 진행할 수 없습니다. Stochastic node의 output이 특정 distribution (e.g. Normal distribution)에서 sampling 된 경우가 이에 해당됩니다.

이를 해결하기 위한 방법으로 variational autoencoder에서 활용되는 reparameterization trick이 존재합니다.

Reparameterization trick

앞서 Figure 05의 (b)에서 back propagation을 진행하지 못하는 이유는 바로 stochastic한 성질 때문이었습니다. Reparameterization trick의 핵심은 gradient가 흘러갈 수 있는 deterministic part와 stochastic part를 구분짓는 것입니다. 예를 들어 정규분포를 가정하면 다음과 같이 설명할 수 있는데,

그런데 갑자기 왜 reparameterization trick을 언급했을까요? 본디 continuous한 distribution에 적용이 가능한 trick을 discrete한 경우로 확장시킨 것이 바로 다음에 설명할 Gumbel trick이기 때문입니다.

Gumbel max trick

Reparameterization trick과 동일하게 stochastic node를 분리하여 back propagation으로 학습하는 것이 목표이며 도식화 하면 Figure 07과 같습니다.

SpotTune forward pass summary

즉 여기까지 내용이 SpotTune의 forward pass 과정이며 요약하자면 다음과 같습니다.

  1. Source task를 통해 ResNet을 학습합니다.

  2. Pre-trained ResNet을 복사하여 original block (fixed block)과 trainable block을 구축하며 target task를 학습합니다. 이 때 fine-tuning strategy를 내뱉는 policy network 또한 동시에 학습됩니다.

  3. Policy network는 target task의 instance (image 한 장)마다 output으로 logit값을 내뱉으며 shape은 (residual block의 개수, 2)입니다.

  4. 도출된 fine-tuning strategy대로 target task를 학습하며 policy network는 target task의 accuracy가 증가하는 방향으로 학습합니다.

Softmax relaxation

이후 SpotTune은 classification loss에 더해 k개의 block을 제한을 두는 loss, fine-tuning strategy가 0과 1로 나올 수 있도록 하는 loss를 추가하여 학습을 진행하게 됩니다. 해당 내용은 중요한 내용이 아닐 뿐더러, 필자가 아직 정확하게 동의하기 어려운 내용이여서 설명은 생락하였습니다.

Experiments

논문의 저자는 총 두가지 category의 experiment를 진행하였으며 여기서는 그 중 하나만 설명하였습니다. 먼저 datasets은 아래와 같이 5가지이며 training은 source task로 evaluation은 target task로 설정되었습니다.

Baseline

비교 대상은 다음과 같습니다.

  1. Standard fine-tuning : 모든 파라미터 재학습

  2. Feature extractor : add one classifier layer

  3. Stochastic fine-tuning : randomly sample 50% of the blocks of the pre-trained network to fine-tune

  4. Fine-tuning last-k block: k = 1,2,3

  5. Fine-tuning ResNet 101: Spottune의 경우 ResNet 50을 사용

  6. Random policy : always fine-tunes the last three layers and random decision for other layers

결과는 아래 표를 통해서 확인할 수 있습니다.

한 가지 주목할만한 점은 SpotTune의 실험 model은 ResNet-50이었다는 것입니다. 중간의 fine-tuning ResNet-101의 경우 ResNet-50에 비해 model capacity가 더 높음에도 불구하고 SpotTune으로 transfer learning을 한 것이 WikiArt dataset을 제외하고 높음을 확인할 수 있습니다. 또한 일반적으로 사용하는 fine-tuning last-k block (초기 layer는 고정, output layer와 가까운 k개의 layer만 fine-tuning)에 비해서도 높은 성능을 보여주어 유의미한 성능 향상이 있음을 알 수 있습니다.

기존의 transfer learning의 경우 어떤 layer를 고정 시키고 fine-tuning 시켜야 최적의 성능을 낼 수 있는지를 사람이 찾았다면 SpotTune은 학습 과정에서 이를 자동화 시킨 방법입니다. Target task에 최적화 되도록 fine-tuning strategy를 결정하는 policy network를 도입한 것이지요. 위 실험을 통해서 computer vision에서 좋은 성능을 낸 것을 확인할 수 있습니다.

다만 computer vision 이외의 task에서도 효과를 보일지는 의문이 들긴 합니다. Convolutional neural network의 특성상, layer의 위치에 따라 추출하는 feature의 특징이 다르기 때문입니다. 만약 일반적인 feed forward neural network나 recurrent neural network 같이 layer의 위치가 무관한 architecture의 경우 적합한 방법은 아닐 것으로 생각이 됩니다.

Reference

[1] Girshick, R., Donahue, J., Darrell, T., & Malik, J. (2014). Rich feature hierarchies for accurate object detection and semantic segmentation. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 580-587).

[2] Yosinski, J., Clune, J., Bengio, Y., & Lipson, H. (2014). How transferable are features in deep neural networks?. In Advances in neural information processing systems (pp. 3320-3328).

[3] Azizpour, H., Razavian, A. S., Sullivan, J., Maki, A., & Carlsson, S. (2015). Factors of transferability for a generic convnet representation. IEEE transactions on pattern analysis and machine intelligence, 38(9), 1790-1802.

http://jaejunyoo.blogspot.com/2018/09/

https://data-newbie.tistory.com/263

https://hulk89.github.io/machine%20learning/2017/11/20/reparametrization-trick/

https://www.youtube.com/watch?time_continue=1227&v=ty3SciyoIyk&feature=emb_title

Last updated