[CVPR 2019] Class-Balanced Loss Based on Effective Number of Samples

Paper url : https://arxiv.org/abs/1901.05555

Author and affiliation

Background

우리가 흔히 classification 문제를 풀 때 real data set은 long-tailed distribution을 띄는 경향을 보인다. Long-tail distribution이란 class 별로 sample size가 다르며, 소수 sample을 보유한 class가 많은 경우를 말한다.

이러한 분포를 띄는 data set, 즉 class imbalance가 심한 경우 deep learning의 성능이 좋지 않은 경향을 보인다 [1]. 따라서 본 논문은 class imbalance 상황에서 neural net의 성능을 높이고자 하였다.

Two strategies in class imbalanced problem

Class imbalance problem을 다루는 방법은 두 가지로 re-samplingre-weighting(=cost-sensitive learning)이 존재한다. 먼저 re-sampling의 경우 sampling을 통해서 class imbalance를 해결하고자 한다. Sampling을 통해서 데이터 균형을 맞추는 것은 다시 두 가지로 나뉘는데, minority class를 over sampling하여 majority와의 균형을 맞추는 것과 반대로 majority class를 under sampling하여 균형을 맞추는 방법이 있다. 통상적으로 over sampling method의 성능이 높다고 알려져 있다 [2]. 아래 그림은 관련 내용이다.

참고로 위 그림의 two-phase training이란 balanced data (sampling 기법으로 균형을 맞춘)로 학습 후 output layer만 original data로 fine-tuning 한 것을 말한다.

다음으로 두 번째 방법인 re-weighting(이하 cost-sensitive learning)은 data를 변형하는 것이 아닌 loss에 차등(weighting)을 주어 학습하는 방법을 말하며 논문에서 제안한 class balanced loss는 이에 해당한다.

Cost-sensitive learning은 sample별로 다른 weight를 주는 방법과 class 별로 다른 weight를 주는 경우로 나눌 수 있다. 먼저 sample 별로 다른 error를 반영하는 loss function으로는 대표적으로 focal loss가 존재한다[3].

Focal loss

논문의 내용을 잠시 벗어나서 focal loss를 다뤄보자. Focal loss의 핵심은 쉬운(=confidence가 높은) sample일 수록 더 적은 loss를 반영하여 학습하는 것인데 수식을 살펴보면 아래와 같다.

일반적인 cross entropy를 위와 같이 정의하면 focal loss는 아래와 같다.

기존의 cross entropy에 추가 된 term은 (1pt)r(1-p_{t})^r이다. y=1y=1인 경우를 예로 들면, 학습이 잘 된 sample의 경우 model의 logit(pt)(p_{t})값이 1에 가까울 것이다. 따라서 focal loss term인 (1pt)r(1-p_{t})^r값이 0에 가깝게 형성될 것이고 cross entropy를 거의 반영하지 않게 된다. 학습이 잘 되지 않은 sample의 경우의 logit은 0에 가까운 값을 갖는다. 따라서 focal loss term은 1에 가깝게 형성되어 기존의 cross entropy 값을 그대로 반영하게 된다. 즉 focal loss는 sample 별로 다른 weight를 주는 cost-sensitive learning이다.

다시 논문 내용으로 돌아와서, cost-sensitive learning 중 class 별로 weight를 달리 주는 방법은 majority class와 minority class의 error를 다르게 반영하여 불균형을 완화시키는 것이 목표이다. 대표적으로 class size 비율의 inverse 값을 weight로 삼아서 학습한다. 예를 들어, majority class와 minority class간의 sample size 비율이 2:1라면 minority class는 loss를 그대로 반영하는 대신 majority class의 loss는 0.5만큼 감소시켜 반영한다.

앞서 언급한 방법들은 모두 단점을 갖고 있는데, re-sampling에서 under sampling의 경우 data의 손실이 발생하며, over sampling의 경우 sample size가 커져 time cost가 증가하며 overfitting의 가능성이 있다. Class별로 weighting을 주는 방법의 경우 under sampling과 유사한 문제가 발생한다. 단순 class proportion의 inverse 값으로 weight를 사용하기 때문에 model의 majority class 분류능력을 감소시킬 가능성이 존재한다. 아래의 그림을 보면 이해하기 쉽다.

먼저 majority & minority class에 weight를 주지 않고 학습할 경우 검은색 decision boundary가 형성된다고 가정하자. 만약 class 비율의 역수를 weight로 주어서 학습한다면 붉은색 점선의 decision boundary가 형성된다. 예를 들어 다수와 소수 class 데이터 비율이 1,000:1인 경우 weight는 각각 0.001, 1로 반영하게 된다. 분류 모델은 일부 majority sample의 손해를 보더라도 minority class의 margin을 더 크게 주게되기 때문에 majority class 분류능력을 감소시킬 수 있다. 이와 같은 문제는 weight가 sample size에 따라 선형적으로 변하기 때문에 발생한다.

논문의 저자들은 이러한 방법을 개선시키고자 class balanced loss(이하 CB loss)를 제안한다. CB loss의 핵심은 sample size를 사용하는 것이 아닌 effective number에 기반하여 weight를 주고자 한다. 그렇다면 effective number란 무엇인가?

Effective number of samples

Sample의 수가 많으면 어떤 현상이 발생하는가? 중복되는 sample이 늘어나고 다른 sample로 인해 충분히 표현되어지는, 즉 영향력이 없는 sample의 수도 증가할 것이다. Data sample의 effective number란 N개의 data가 있을 때 중복, 유사한 sample을 제외한 영향력 있는 sample들의 개수를 뜻한다.

간단한 예로, 1부터 10까지 5번 복원 추출을 진행했을 때 [1,8,2,9,3]이 추출될 경우 sample size와 effective number는 모두 5이다. 반면 [1,1,2,2,3]이 추출될 경우 sample size는 5이지만 effective number는 3으로 계산된다.

그렇다면 우리가 마주하는 실제 데이터 상에서 영향력이 있는 sample과 없는 sample을 구분하는 방법에 대한 의문이 든다. 저자들은 random covering 개념이라는 말을 사용한다. 이는 전체 데이터(N)에서 n번째 sample을 추출할 경우 p의 확률로 sampled data와 중복(=coverage 안에 포함)되며 1-p의 확률로 새로운 sample이 추출된다고 가정한다. 그림으로 표현하면 아래와 같다.

Effective number는 이러한 새로운 sample이 나올 경우 counting 되며 식은 아래와 같다.

En=pEn1+(1p)(En+1)E_{n} = pE_{n-1}+(1-p)(E_{n}+1)

여기서 기존에 추출된 데이터(sampled data)의 coverage에 포함 될 확률 p는 n-1번째 effective number를 전체 sample의 개수 N으로 나눈 값이며 아래와 같다.

p=En1Np = {E_{n-1} \over N}

즉 앞선 예제에서 6번째 sample을 뽑을 때 확률 p는 3/10 = 0.3이다.

Class balanced loss

위에서 작성한 effective number 공식에 p의 값을 대입하여 정리해보면 아래와 같다.

En=1+N1NEn1E_{n} = 1+{N-1 \over N}E_{n-1}

그 후 β=(N1)N\beta = {(N-1) \over N}로 설정하고, 논문에 나온대로 En1=(1βn1)(1β)E_{n-1}={(1-\beta^{n-1}) \over (1-\beta)} 를 들고있다고 가정하면 effective number는 최종적으로 아래와 같이 정리된다.

En=1βn1βE_n = {1-\beta^n \over 1-\beta}

구해진 effective number의 경우 sample size가 커질 수록 값이 커진다. 즉 majority class일 수록 큰 값을 갖고 minority class일 수록 작은 값을 갖게 된다. 이는 class imbalance 문제를 다루기에 알맞지 않음으로 effective number의 역수를 weight로 삼고 이를 class balanced term이라고 부른다. 따라서 최종적으로 논문에서 제안하는 class balanced loss는 아래와 같다.

우리가 흔히 사용하는 cross entropy, soft max, focal loss에 class 별로 class balance term을 곱해서 학습을 진행한다.

Effect of class balanced term

앞서 우리는 sample size의 역수를 weight로 삼을 경우 선형적으로 loss를 반영하기 때문에 majority class의 분류 능력을 침해한다고 하였다. 그렇다면 effective number를 기반으로 구해진 class balanced term은 이 문제를 해결할 수 있을까?

위 그림을 보면 sample size가 아무리 커지더라도 beta 값에 따라서 일정 수준 이상이 되면 더 이상 class balanced term이 감소하지 않고 일정 값을 유지하는 것을 볼 수 있다. 만약 beta를 1로 둘 경우 sample size에 따라 비례해서 감소하는 일반적인 방법이 된다.

Experiments

실험은 총 5가지 data set을 사용하였으며 CIFAR-10, 100의 경우 인위적으로 imbalance data를 만들어 학습하였다(test는 기존 데이터와 동일). 아래 도표의 imbalance factor는 largest class의 sample size를 smallest class의 sample size로 나눈 값을 의미한다.

아래 도표는 CIFAR-10,100에 대한 실험 결과이다.

다음은 나머지 3개 data set에 대한 실험 결과이다.

아래 그래프는 iNaturalist 2018 데이터 셋에서 epoch별 Top-1 error 그래프이다. 보면 60 epoch 이후로 기존 soft max loss보다 좋은 성능을 보이는 것을 확인할 수 있다. 이는 class balanced term의 영향인데, sample 숫자가 클 수록 class balanced term의 값은 매우 작아진다. 즉 기존의 loss에 매우 작은 값을 곱해서 모델이 학습을 진행하기 때문에 모델의 성능을 보장하기 위해서는 많은 epoch이 필요한 것으로 판단된다.

Class imbalance 상황은 우리가 일상생활에서 쉽게 접할 수 있는 문제이다. CB Loss의 등장 이후 비슷한 시기에 더 좋은 성능을 보인 Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss (LDAM-DRW, NIPS '19) 논문도 존재하니 참고하면 좋을 듯 하다.

Reference

[1] Van Horn, G., & Perona , P. (2017). The devil is in the tails: Fine grained classification in the wild. arXiv preprint arXiv:1709.01450

[2] Buda, M., Maki, A., & Mazurowski , M. A. (2018). A systematic study of the class imbalance problem in convolutional neural networks. Neural Networks, 106, 249 259

[3] Lin, T. Y., Goyal , P., Girshick , R., He, K., & Dollár , P. (2017). Focal loss for dense object detection. In Proceedings of the IEEE international conference on computer vision (pp. 2980 2988).

\checkmark 논문의 concept 및 idea 위주로 정리하여 자세한 수식이나 내용에 오류가 있을 수 있습니다.

\checkmark CB Loss 구현 github url : https://github.com/yjchoi-95/Class-Balanced-Loss-tf2.1

Last updated