[CVPR 2019] SpotTune:Transfer Learning through Adaptive Fine-tuning

Paper url: https://arxiv.org/pdf/1811.08737.pdf

Author and affiliation

Figure 01 : paper snapshot

Introduction

Transfer learning이란 source task를 해결하면서 얻어진 knowledge를 target task로 전이하여 사후적으로 학습하는 개념을 말합니다. 본 논문을 이해함에 있어 필요한 transfer learning의 background knowledge는 다음과 같습니다.
  • Donahue et al.은 pre-trained AlexNet을 이용하여 image의 특성을 추출 후 SVM(support vector machine)의 input으로 활용하였습니다 [1]. Two stage object detection의 시초인 R - CNN이 이에 해당합니다.
Figure 02 : R-CNN architecture
  • Yosinski et al.은 pre-trained AlexNet을 target task에 맞게 fine-tuning하는 방법을 사용하였습니다 [2].
    앞에서 제시한 방법인 feature 추출 후 다른 classifier를 사용하는 것 보다 좋은 성능을 보인 바 있습니다.
  • Azizpour et al은 전체 parameter를 갱신하는 것은 target dataset이 작고, model capacity가 큰 경우 overfitting을 경고하였습니다 [3]. 이에 전체 layer가 아닌 특정 layer를 freeze, 소수 layer를 fine-tuning하는 방법이 제안되었습니다.
정리하면 transfer learning은 source task에서 학습한 모델을 불러와 target task의 data set에 맞게 fine-tuning을 진행하는 방식입니다. 이 때 전체 layer를 재학습 하거나 특정 layer만 재학습하는 경우도 있지만 주로 output layer와 가까운 layer들만 fine-tuning하는 방법을 사용합니다.
다만 어떤 layer를 freeze시키고 fine-tuning해야하는지는 사람이 정할 수 밖에 없어서 최적화가 힘든 문제가 남아 있겠죠. 따라서 SpotTune 저자들은 별도의 policy network를 두어 layer를 freeze or fine-tuning하는 것을 사람이 아니라 모델이 결정하게 만들었습니다.

Overview and notations of SpotTune

앞서 말했듯 SpotTune은 fine-tuning strategy를 사람이 아닌 model이 결정하는데 이를 policy network라고 칭합니다. Policy network는 예측 model의 performance를 증가시키는 방향으로 학습을 진행하며 원문은 다음과 같습니다.
Given a pre-trained network model on a source task (e.g., ImageNet pre-trained model), and a set of training examples with associated labels in the target domain, our goal is to create an adaptive fine-tuning strategy that decides, per training example, which layers of the pre-trained model should be fine-tuned (adapted to the target task) and which layers should have their parameters frozen (shared with the source task) during training, in order to improve the accuracy of the model in the target domain.
그렇다면 fine-tuning strategy는 어떻게 작성될까요? 저자들이 제안한 policy network는 target task의 instance마다 strategy를 출력합니다. Instance란 아래 그림의 경우 이미지 한 장을 의미합니다.
Figure 03 : SpotTune architecture 01
다음과 같이 생각하면 쉽게 이해할 수 있습니다.
  1. 1.
    Source task (ImageNet)를 통해 ResNet을 학습합니다.
  2. 2.
    Target task (CIFAR-10)에 pre-trained weight를 transfer하여 fine-tuning을 진행합니다. 이 때 ResNet은 trainable network와 fixed network 두 가지로 구성되며 각각은 pre-trained weight로 initialize 됩니다.
  3. 3.
    Target task의 instance (CIFAR-10의 image 한 장)마다 policy network가 fine-tuning strategy를 학습 및 output으로 출력합니다.
  4. 4.
    출력된 fine-tuning strategy에 따라 layer의 freeze여부가 결정되며 freeze layer의 경우 fixed network를 통과하고 fine-tuning layer의 경우 trainable network를 통과합니다.
  5. 5.
    CIFAR-10에 대해서 model accuracy가 최대화 되도록 policy network가 학습됩니다.
여기서 source task와 target task를 학습하는 model architecture는 ResNet이며 fine-tuning과 freeze의 단위인 layer는 ResNet block을 뜻합니다. 세부적으로 policy network이 어떻게 학습하는지는 후술하도록 하고 우선 notation을 알아보겠습니다.

Notation

  • ll
    -th residual block in a pre-trained ResNet model
먼저 ResNet의 l번째 residual block은 다음 식과 같이 표기할 수 있습니다. Skip-connection은
xl1x_{l-1}
로 표기되어 있으며 본래의 weight 를 거치는 부분은
Fl(xl1)F_l(x_{l-1})
로 표기됩니다.
𝑥𝑙=F𝑙(𝑥𝑙1)+𝑥𝑙1𝑥_𝑙=F_𝑙 (𝑥_{𝑙−1} ) + 𝑥_{𝑙−1}
  • Original block (
    F𝑙F_𝑙
    ) & trainable block (
    F^l\hat{F}_l
    )
SpotTune의 policy network는 fine-tuning block과 frozen block을 결정한다고 설명하였습니다. 저자들은 이를 쉽게 구현하기 위해 두 개의 ResNet을 사용하는데, 하나는 source task의 가중치를 학습하지 않는 original network이며 다른 하나는 fine-tuning을 진행하는 trainable network입니다. 이를 block 단위로 바라보면 policy network의 fine-tuning strategy를 받은 ResNet은 다음 식으로 표기할 수 있습니다.
𝑥𝑙=I𝑙(x)F^ 𝑙(𝑥𝑙1)+(1I𝑙(x))F𝑙(𝑥𝑙1)+𝑥𝑙1𝑥_𝑙=I_𝑙 (x) \hat{F} _𝑙 (𝑥_{𝑙−1} ) + (1 − I_𝑙 (x)) F_𝑙 (𝑥_{𝑙−1} ) + 𝑥_{𝑙−1}
위 식에서
Il(x)I_l(x)
는 binary random variable로써 residual block이 freeze
(Il(x)=0)(I_l(x) = 0)
될지 fine-tuning
(Il(x)=1)(I_l(x) = 1)
될지 결정합니다. Trainable block
(F^l)(\hat{F}_l)
FlF_l
의 weight를 초기값으로 할당받으며 target task를 학습하며 최적화 됩니다.
  • Binary random variable
    Il(x)I_l(x)
    는 standard Gumbel distribution을 따름
G=log(log(U)) with UUnif[0,1]G =−log(−log(U)) \ with \ U \thicksim Unif [0,1]
Gumbel distribution은 극치분포의 일종으로 극치분포란 정규분포에서 최소 또는 최대구간에 밀집된 데이터 분포를 모사하는데 사용됩니다. 이와 자세한 내용은 후술할 Gumbel max trick을 통해서 알아보도록 하겠습니다.
SpotTune의 동작과정을 예시를 통해 확인해보겠습니다.

SpotTune forward & backward pass

Figure 04 : SpotTune architecture 02
우선 4개의 residual block을 갖고 있는 ResNet이 있다고 가정하겠습니다. Figure 04는 source task에 대한 pre-trained ResNet을 복사하여 trainable block
(F^l)(\hat{F}_l)
을 생성 후 target task를 학습하는 과정입니다. 앞에서 설명한 대로 fixed ResNet (연분홍색)과 trainable ResNet (적색)두 가지가 있는 것을 확인할 수 있습니다. 여기서 target task의 학습 대상은 trainable block과 policy network입니다.
모델의 forward pass 과정은 다음과 같습니다.
  1. 1.
    Target task의 image 한 장을 input으로 받습니다.
  2. 2.
    Policy network가 output
    (logit=log ai)(logit =log\ a_{i})
    을 출력하며 이 예제의 경우 (4, 2) 행렬을 출력합니다.
  3. 3.
    (log ai+Gi)(log\ a_{i} + G_{i})
    argmaxargmax
    값으로 fine-tuning strategy를 구성합니다.
  4. 4.
    이후 back propagation을 통해 target task의 accuracy가 최대화 되는 방향으로 학습합니다.
그런데 여기서 다음과 같은 의문점이 존재합니다.
  1. 1.
    왜 policy network의 output (logit)만을 사용하지 않고 random variable과 결합된
    (log ai+Gi)(log\ a_{i} + G_{i})
    argmaxargmax
    값으로 fine-tuning strategy를 구성하는가?
  2. 2.
    argmaxargmax
    operation은 미분이 불가할텐데 이를 어떻게 back propagation 할 수 있는가?
다음 절을 통해서 위 의문점을 해결해보도록 하겠습니다.

Optimizing stochastic computation graph

Figure 05 : Deterministic node & stochastic node
먼저 우리가 흔히 알고 있는 neural net 구조는 Figure 05의 (a)와 같으며 검은색의 forward pass를 따라 정해진 연산(weight를 곱하고 bias를 더한 후 activation function 통과)을 수행하며 그 후 back propagation을 통해서 weight를 update를 진행합니다.
그런데 Figure 05의 (b)와 같이 중간에 stochastic node가 존재할 경우 위와 같은 방법으로 back propagati-on을 진행할 수 없습니다. Stochastic node의 output이 특정 distribution (e.g. Normal distribution)에서 sampling 된 경우가 이에 해당됩니다.
이를 해결하기 위한 방법으로 variational autoencoder에서 활용되는 reparameterization trick이 존재합니다.

Reparameterization trick

앞서 Figure 05의 (b)에서 back propagation을 진행하지 못하는 이유는 바로 stochastic한 성질 때문이었습니다. Reparameterization trick의 핵심은 gradient가 흘러갈 수 있는 deterministic part와 stochastic part를 구분짓는 것입니다. 예를 들어 정규분포를 가정하면 다음과 같이 설명할 수 있는데,
평균이
μ\mu
이고 분산이
σ2\sigma^2
인 정규분포에서 sampling한 random variable
XX
는 평균이 0이고 분산이 1인 표준정규분포에서 sampling한 random variable
ZZ
μ\mu
를 더하고
σ\sigma
를 곱한것과 완벽히 동치를 이룹니다.
Figure 06 : Reparameterization trick
이를 통해 더이상 optimization 대상인 distribution의 parameter
ϕ\phi
는 분포에 dependent하지 않게되며 (random variable
ZZ
는 표준 정규분포를 따르며
μ\mu
σ\sigma
를 update해도 distribution과는 독립적) back propagation을 통해 update가 가능해집니다.
그런데 갑자기 왜 reparameterization trick을 언급했을까요? 본디 continuous한 distribution에 적용이 가능한 trick을 discrete한 경우로 확장시킨 것이 바로 다음에 설명할 Gumbel trick이기 때문입니다.

Gumbel max trick

Reparameterization trick과 동일하게 stochastic node를 분리하여 back propagation으로 학습하는 것이 목표이며 도식화 하면 Figure 07과 같습니다.
Figure 07 : Gumbel max trick
위와 마찬가지로 gradient가 흘러가는 deterministic part와 back propagation을 위해 분리되어진 stochastic part로 구분되어 있는 것을 확인할 수 있습니다. Deterministic part의
log ailog\ a_i
는 앞 절에서 언급한 policy network의 output인 logit을 말하며 stochastic part는 standard Gumbel distribution을 따르는 random variable
GiG_i
를 의미합니다.
즉 첫번째 의문이였던 fine-tuning strategy를
(log ai+Gi)(log\ a_{i} + G_{i})
argmaxargmax
값으로 구성하는 이유를 알 수 있습니다. 두 가지 component는 사실 하나의 값에서 분리되어 나온 것으로 바라보는것이 마땅하며(network 학습을 위해), Gumbel distribution은 극치 분포의 일종으로 극치분포가 정규분포에서 최소와 최대 값에 분포되어 있는 데이터를 모사하는 것으로 미루어 보아 freeze or fine-tuning이란 이산적인 값을 표현할 수 있도록 policy network이 학습하도록 유도하는 것으로 판단됩니다.

SpotTune forward pass summary

즉 여기까지 내용이 SpotTune의 forward pass 과정이며 요약하자면 다음과 같습니다.
  1. 1.
    Source task를 통해 ResNet을 학습합니다.
  2. 2.
    Pre-trained ResNet을 복사하여 original block (fixed block)과 trainable block을 구축하며 target task를 학습합니다. 이 때 fine-tuning strategy를 내뱉는 policy network 또한 동시에 학습됩니다.
  3. 3.
    Policy network는 target task의 instance (image 한 장)마다 output으로 logit값을 내뱉으며 shape은 (residual block의 개수, 2)입니다.
  4. 4.
    Fine-tuning strategy는 network의 ouput과 standard Gumbel distribution에서 sampling한 random variable을 더한 값
    (log ai+Gi)(log\ a_{i} + G_{i})
    argmaxargmax
    값으로 결정합니다(Figure 04에선 4개의 vector가 나옴).
  5. 5.
    도출된 fine-tuning strategy대로 target task를 학습하며 policy network는 target task의 accuracy가 증가하는 방향으로 학습합니다.

Softmax relaxation

이제 backward pass를 알아보겠습니다. 두번째 의문이였던 미분 불가능한
argmaxargmax
를 처리하는 간단한 방법을 알아보겠습니다.
Figure 08: argmax & softmax
Figure 08과 같이 argmax를 softmax로 변환하여 gradient가 흘러갈 수 있도록 해결해줍니다. 이를 softmax relaxation이라고 하며 Figure 08의 (b)의 수식에서 분모 및 분자에 들어가는
λ\lambda
의 경우 softmax값을 변화시키는 역할을 하며 0인 경우
argmaxargmax
와 완전히 동일해집니다.
이후 SpotTune은 classification loss에 더해 k개의 block을 제한을 두는 loss, fine-tuning strategy가 0과 1로 나올 수 있도록 하는 loss를 추가하여 학습을 진행하게 됩니다. 해당 내용은 중요한 내용이 아닐 뿐더러, 필자가 아직 정확하게 동의하기 어려운 내용이여서 설명은 생락하였습니다.

Experiments

논문의 저자는 총 두가지 category의 experiment를 진행하였으며 여기서는 그 중 하나만 설명하였습니다. 먼저 datasets은 아래와 같이 5가지이며 training은 source task로 evaluation은 target task로 설정되었습니다.

Baseline

비교 대상은 다음과 같습니다.
  1. 1.
    Standard fine-tuning : 모든 파라미터 재학습
  2. 2.
    Feature extractor : add one classifier layer
  3. 3.
    Stochastic fine-tuning : randomly sample 50% of the blocks of the pre-trained network to fine-tune
  4. 4.
    Fine-tuning last-k block: k = 1,2,3
  5. 5.
    Fine-tuning ResNet 101: Spottune의 경우 ResNet 50을 사용
  6. 6.
    Random policy : always fine-tunes the last three layers and random decision for other layers
  7. 7.
    L2L^2
    -SP : ICML 2018, recently proposed SOTA regularization method for fine-tuning
결과는 아래 표를 통해서 확인할 수 있습니다.
한 가지 주목할만한 점은 SpotTune의 실험 model은 ResNet-50이었다는 것입니다. 중간의 fine-tuning ResNet-101의 경우 ResNet-50에 비해 model capacity가 더 높음에도 불구하고 SpotTune으로 transfer learning을 한 것이 WikiArt dataset을 제외하고 높음을 확인할 수 있습니다. 또한 일반적으로 사용하는 fine-tuning last-k block (초기 layer는 고정, output layer와 가까운 k개의 layer만 fine-tuning)에 비해서도 높은 성능을 보여주어 유의미한 성능 향상이 있음을 알 수 있습니다.
기존의 transfer learning의 경우 어떤 layer를 고정 시키고 fine-tuning 시켜야 최적의 성능을 낼 수 있는지를 사람이 찾았다면 SpotTune은 학습 과정에서 이를 자동화 시킨 방법입니다. Target task에 최적화 되도록 fine-tuning strategy를 결정하는 policy network를 도입한 것이지요. 위 실험을 통해서 computer vision에서 좋은 성능을 낸 것을 확인할 수 있습니다.
다만 computer vision 이외의 task에서도 효과를 보일지는 의문이 들긴 합니다. Convolutional neural network의 특성상, layer의 위치에 따라 추출하는 feature의 특징이 다르기 때문입니다. 만약 일반적인 feed forward neural network나 recurrent neural network 같이 layer의 위치가 무관한 architecture의 경우 적합한 방법은 아닐 것으로 생각이 됩니다.

Reference

[1] Girshick, R., Donahue, J., Darrell, T., & Malik, J. (2014). Rich feature hierarchies for accurate object detection and semantic segmentation. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 580-587).
[2] Yosinski, J., Clune, J., Bengio, Y., & Lipson, H. (2014). How transferable are features in deep neural networks?. In Advances in neural information processing systems (pp. 3320-3328).
[3] Azizpour, H., Razavian, A. S., Sullivan, J., Maki, A., & Carlsson, S. (2015). Factors of transferability for a generic convnet representation. IEEE transactions on pattern analysis and machine intelligence, 38(9), 1790-1802.
\checkmark
논문의 concept 및 idea 위주로 정리하여 자세한 수식이나 내용에 오류가 있을 수 있습니다.