[ICLR 2020] Distance-Based Learning from Errors for Confidence Calibration
Paper url: https://arxiv.org/pdf/1912.01730.pdf
Last updated
Paper url: https://arxiv.org/pdf/1912.01730.pdf
Last updated
오늘날 neural net은 예전보다 정확도는 많이 향상되었지만 calibration은 좋지 않은 편에 속한다. Calibration이란 모형의 출력값이 실제 confidence를 반영하는 것을 말하며 calibrated confidence라고도 한다. 예를 들어 COVID-19의 양성과 음성을 분류하는 task가 주어졌다고 가정해보자. 환자 A에 대한 모형의 출력이 0.8이 나왔을 때, 실제로도 80% 확률로 양성이라면 calibration이 잘 이루어졌다고 본다. 즉 모형의 출력값과(confidence) 실제 확률을 동일하게 만드는 것이 model calibration의 목적이다.
Calibration이 잘 이루어졌다는 것을 어떻게 입증할 수 있을까? 만약 모형의 출력값이 실제 confidence를 반영한다면 confidence와 accuracy가 일치해야한다. 모델이 0.8의 confidence로 예측한 sample들의 경우 0.8의 accuracy를 가진다면 confidence와 실제 확률은 동일하다고 볼 수 있다. 강아지와 고양이를 분류하는 학습된 모델이 있을 때, 모델의 confidence가 0.8인 sample들을 모아서 accuracy를 측정할 경우 0.8에 근사한 값을 가져아한다는 것이다.
그러나 앞서 말했듯 오늘날 neural net은 over confident한 문제를 갖고 있다. 아래 그림은 CNN의 초창기 모델인 LeNet-5와 그에 비해 많은 parameter를 갖고 있는 ResNet에 대한 confidence-accuracy chart이다. 도표의 해석은 LeNet-5의 경우 CIFAR-100 data set에 대해서 모델의 confidence가 0.4인 sample들을 모아 accuracy를 측정해보았더니 0.4~0.5 수준이라고 판단한다. ResNet의 경우 model capacity가 낮은 LeNet-5와 달리 대부분의 confidence가 1에 가깝께 쏠려있는 것을 확인할 수 있으며, confidence 별로 accuracy와의 gap 또한 큰 것을 볼 수 있는데 이러한 현상을 over confident라고 한다.
Model confidence가 실제 확률을 반영하는 것이 왜 중요한가? 이는 현업에서 neural net이 인간을 대체할 때 악영향을 끼칠 수 있기 때문이다. 가령 의료 업계에서 질병을 진단하는데 60% 확률과 90% 확률은 엄연히 다르나 over confident model을 바탕으로 처방을 내린다면 과한 진료를 진행하게 될 것이다. 따라서 model calibration 향상을 위한 노력은 다양하게 존재해 왔으며 대표적으로 label smoothing, mix up 등이 있다.
먼저 label smoothing이란 label을 0과 1로 두어서 학습하는 것이 아닌 smooth하게 부여하여 과도하게 학습하는 것을 막는 방법이다. 강아지와 고양이에 대한 label을 0과 1로 할당하는 것이 아닌 0.1, 0.9로 두어 학습하는 것이 그 예이다. 이는 모델의 예측값이 극단에 치닫지 않게 해주어 regularization에 도움을 주면서 model generalization과 calibration에 도움이 된다.
그러나 위 두가지 방법 모두 objective function이 confidence estimation을 목표로 삼고있지 않다. 다른말로 classification의 성능을 높이는 방법의 부가적인 산물이 calibration에 도움을 준 것이지 직접적으로 calibration을 target으로 삼고있지 않다는 것이다. 이러한 confidence scoring을 직접적으로 학습에 활용하는 방법으로는 temperature scaling이 존재하지만 해당 방법의 경우 classification과 calibration 두 가지 task로 데이터를 분할해야하기 때문에 data의 분배량에 따라 개별 task의 성능이 trade-off 성질을 띄게 되는 단점이 있다. 따라서 본 논문은 model calibration을 직접적으로 학습의 목표로 삼으며 앞서 말한 단점을 개선하는 것을 목표로 삼는다.
요약하면 support set으로 class의 중심점을 구하고, query sample이 중심점과의 distance를 줄여나가며 classification model은 최적의 mapping을 학습한다. 그 후 classification에서 오분류된 데이터를 이용하여 confidence model을 학습한다. Confidence model의 역할은 ground-truth가 없는 테스트 데이터에 대해서도 confidence를 추정할 수 있게 만드는 것이며 자세한 설명은 후술하겠다.
학습 단계를 먼저 살펴보면, 일반적인 neural net의 학습 방식인 batch 단위 학습과는 달리 매 update마다 K-shot, N-way sample을 사용하는 episodic training 방식을 따른다. Episodic training이란 매 episode마다 전체 class M에서 N개의 sampled class를 가져온 후(반드시 N=M일 필요는 없다), N개의 class들을 K개의 데이터를 지닌 support set과 query set으로 분리한다. 즉 episode 마다 query와 support set은 바뀌며 충분히 많은 episode가 진행되면 전체 data를 사용하여 학습할 수 있다. Support set과 query set의 notation은 다음과 같다.
Episodic training의 loss function은 아래와 같으며 자세한 사항은 아래에서 설명을 이어가겠다.
Softmax output을 개념적으로 바라보자. 만약 개와 고양이를 분류하는 문제라면 우리는 두 개의 element를 보유한 vector를 output으로 얻게될 것이다. 각각의 element는 class에 해당하는 확률로 표기가 되는데, 위 수식을 사용하면 해당 확률은 sample과 class 중심점과의 거리로써 측정이 되는 것이다.
반대로 class와의 중심점과 distance가 멀면 해당 class의 softmax output은 0에 가까운 값을 갖게 된다. 이를 통해 학습을 거듭하면서 embedding space 상에서 inter-class distance는 커지고 intra-class distance는 작아지게 된다. 따라서 위 loss를 사용하면 같은 class 안의 sample들은 비슷한 공간에 mapping 되며 다른 class의 sample들은 밀어내는 효과를 얻을 수 있다.
Support set, query set 및 prototypical loss의 역할에 대해서 MNIST 예제를 통해 제대로 이해해보자. 10개 class 각각에 대해서 1000개를 support set으로, 나머지 5000개를 query set으로 두었다고 가정하자. 이는 1000 shot, 10-way samples setting에 해당한다.
즉 위 prototypical loss를 사용하면 일반적인 분류기와 동일한 형태로 출력을 낼 수 있으면서 개별 query sample들을 class 중심점으로 모이게 하는 효과를 가져올 수 있다. 이러한 과정은 episode가 반복되면서 query set과 support set에 해당하는 sample들이 랜덤하게 섞이면서 모든 데이터를 활용한다.
다만 위에서 제시한 방법의 가장 큰 문제점은 distance를 측정할 때 test sample에 대해서 ground-truth label이 필요하다는 것이다. 우리가 실제로 풀어야 하는 문제는 test sample의 label이 없는 경우다. 따라서 논문의 저자들은 classification을 학습할 때 confidence를 측정할 수 있는 모델을 동시에 학습하는 방법을 제안하며 이를 joint training으로 표현한다.
여기까지 읽었을 때 '도대체 이게 무슨 말이며 왜 이런 과정이 필요하지?' 라는 생각이 들 수 있는데, 후술할 내용들을 읽고나서 다시 돌아온다면 이해가 쉬울 것이다.
먼저 confidence를 추정하는 방법부터 살펴보자.
정리하면 confidence model은 오분류 된 sample에 대해서 sigma를 키움으로써 n번의 sampling 시 일관성을 유지하지 못하게 만들어 이를 평균 내었을 때 어느 class에도 속하지 못하도록 confidence를 낮추게 된다. 위 방법을 사용할 경우 test sample에 대한 label이 없더라도 confidence를 구할 수 있다.
저자들은 distance-based learning from errors (DBLE)의 calibration 효과를 비교하기 위한 baseline으로 vanilla training, MC-Dropout, Temperature scaling, Mixup, Label smoothing, Trust Score를 사용하였다.
Datasets은 MNIST, CIFAR-10, CIFAR-100, Tiny-ImageNet을 사용하였으며 하기 표에 적인 -MLP, VGG11은 사용 모델을 뜻한다.
평가지표는 accuracy, expected calibration error (ECE), negative log likelihood (NLL)를 사용하였다. ECE와 NLL의 수식은 다음과 같다. ECE는 accuracy와 confidence간의 차이로 정의되며 본 포스팅의 서두에 다루었던 내용과 일맥상통한다.
Baseline과의 성능 비교표는 다음과 같다.
Confidence model을 학습할 때 오분류된 sample 만을 사용하는 것이 월등히 좋은 성능을 보인다는 것은 아래 표를 통해서 확인할 수 있다.
논문을 작성하는 사람의 입장으로 중요한 stance인 것 같다. 본인들이 제안한 방법이 좋은 결과로 이어지지 않을 때 포기하지 않고 끊임없이 탐색해야 얻을 수 있는 결과라는 생각이 들었다.
다음으로 mix up은 두 개의 random sample에서 선형보간(linear interpolation)을 적용하여 학습하는 방법이다. 아래의 그림처럼 사용자가 임의로 설정한 에 따라서 sample image와 label을 가중합 하여 하는 방법이다. 가령 를 0.5로 할당할 경우 image를 반반 섞고 label 또한 앵무새 0.5, 도마뱀 0.5로 섞어서 학습을 진행한다. 이는 label smoothing과 동일하게 label이 0과 1로 극단값으로 치닫는 현상을 막게해주어 calibration에 도움이 된다.
전체 concept을 보면 위와 같다. 먼저 데이터를 Query set과 Support set 두 가지로 분리하며 각 data set은 중복되는 sample 없이 독립적으로 존재한다. Support set은 개별 class 들의 중심점()을 구하는데 사용된다. Query set의 sample들은 classification model을 통과해 새로운 space로 mapping()되며 앞서 구해진 class의 중심점에 가까워지도록 classification model을 학습한다.
N개의 class들의 중심점()은 classification model을 통과한 support sample들의 평균을 통해서 구해지며 notation은 다음과 같다.
먼저 prototypical loss는 embedding된 query sample 와 class 중심점 와의 거리를 softmax를 취하는 것으로 구성 된다. 즉 형태는 일반적인 softmax output 형태와 동일하나 값을 채우는 방식에서 distance가 들어간다는 차이점이 있다. 수식은 다음과 같다.
만약 강아지 이미지를 넣었을 때 강아지 class에 높은 확률을 얻기 위해서(다른 말로 1에 가까운 값을 갖음)는 위 수식에서 와 간의 distance 값이 0에 가까워야한다. Distance가 0에 가깝다는 것은 해당 sample이 강아지 class 중심점과 매우 가까운 위치에 mapping 되어있다는 것을 뜻한다.
Support set의 1000개 sample들을 통해 1부터 10까지 class의 중심점()을 구한다. Classification model은 1에 해당하는 sample이 들어올 경우 에 가깝게 mapping하도록 학습하며 나머지 class의 sample에 대해서도 동일하게 적용된다.
학습이 어느정도 진행되었다고 가정하자. Classification model을 통과한 class 1의 sample들은 자신의 중심점인 과의 거리는 가깝게 mapping이 될 것이고 나머지 (n > 1 & n < 10)과의 거리는 멀도록 mapping이 된다. 즉 모델의 출력 관점에서 바라보면 class와의 거리가 가까울 수록 softmax의 output값은 1에 가까울 것이고 거리가 먼 class일 수록 0에 가깝게 출력 되기에 class 1의 sample들은 output의 첫 번째 element에 대한 확률 값이 가장 높게 측정될 것이다.
위에서 서술한 방식으로 model이 학습이 되었다면, inference 시에는 distance를 어떻게 구하고 예측값을 도출할 수 있을까? 먼저 각 class의 중심점은 training samples를 이용하여 계산한다. 이후 test sample들을 통과시켜 mapping 된 와 가장 가까운 거리의 label로 예측을 진행한다. Notation은 다음과 같다.
따라서 mapping된 가 ground-truth center와 비슷한 위치에 mapping 되지 않을수록 오분류 될 가능성이 커지게 된다. 정리하면 test sample과 class centroid vector간의 distance에 비례하여 모델의 성능이 결정되기에 distance가 곧 model의 calibrated confidence를 표현한다고 볼 수 있다.
그렇다면 실제로 distance와 confidence와의 관계는 어떠한지 알아보자. 위 그림의 x축은 개별 sample들의 class 중심점과의 거리를 뜻하며, y축은 그 때의 test accuracy의 평균을 뜻한다. 예를 들어 CIFAR-100에서 ground-truth class 중심점과의 거리가 5인 sample들의 test accuracy 평균은 약 0.95로 해석한다. Legend의 는 test sample 의 ground-truth 중심점과의 거리를 뜻하며는 예측된 sample들을 평균낸 중심점과의 거리를 뜻한다. 도표에서 확인할 수 있듯이 ground-truth 중심점과의 거리가 멀어질수록 모델의 성능이 낮아지며, 예측된 class 중심점을 사용한 것도 그러한 경향을 보인다. 이를 통해 distance 기반의 방법은 calibrated confidence가 높다고 판단할 수 있다. 다만 예측된 class 중심점을 사용할 경우 ground-truth를 사용한 것 만큼 정확하진 않다.
저자들은 confidence를 측정하는 모델을 confidence model로 칭한다. Confidence model은 로 표현되며 sample의 mapping값 을 받아서 를 출력하는 모델이다. 이 가 크다면 해당 sample은 낮은 confidence (특정 class의 sample로 판단하기엔 어려움)를 갖는다고 이해하면 된다.
Confidence model은 classification model에서 오분류된 sample만을 사용하여 학습한다. 오분류된 sample만을 활용하는 이유는 일반적인 class imbalance 상황을 생각하면 이해하기 수월한데, classification model을 학습하게 되면 ground-truth와의 distance가 작은 sample (정분류)들이 다수, distance가 큰 sample (오분류)들은 소수가 된다. 따라서 confidence model은 소수의 오분류 sample들에 focus하여 학습하기 어려운데 confidence model의 학습 목표가 낮은 confidence sample에 대해 큰 를 부여하는 것임을 생각하면 이는 큰 단점으로 작용한다. 따라서 오분류된 sample만을 사용하는 것이다. 원문은 다음과 같다.
If all data is used, training of would be dominated by the small distances of the correctly classified samples which would make it harder for capture the larger distances for the minor mis-classified samples.
오분류된 sample들의 mapping 값 와 를 parameter로 삼는 gaussian distribution에서 sample 를 추출한다. 여기서 는 confidence model의 output임을 잊지말자.
Sampling 된 값 가 정분류가 되도록 confidence model을 update한다. (초기는 작은 값을 갖지만 update가 반복될수록 는 커진다.)
즉 오분류된 sample이면서 ground-truth class 중심점과의 거리가 멀수록 는 큰 값을 갖게 된다. 그림으로 표현하면 아래와 같다.
쉽게 설명하면 다음과 같다. 그림 (a)를 보면 sample a,b는 오분류 되었으며 c는 올바르게 분류된 상태이다. 앞에서 말했듯, confidence model은 오분류된 sample만을 사용하기에 sample a,b를 사용하여 update가 진행된다. 우선 sample a만 살펴보자. 우리는 평균 에 상응하는 를 confidence model을 통해서 도출할 수 있다. 초기의 sigma는 작은 값을 갖게 되기 때문에 해당 정규분포(그림의 점선으로 된 원)에서 sample 를 추출하더라도 decision boundary(그림 중앙의 점선)를 넘어갈 수 없다.
따라서 confidence model은 sample 가 decision boundary를 넘어서 올바르게 분류되도록 기존보다 더 큰를 출력 시키도록 학습된다. 그림(b)를 보면 confidence model이 학습된 이후 큰 값의 를 출력시켜 sample 가 decision boundary를 넘어 올바르게 분류되는 모습을 보여준다. 그림에서 알 수 있듯이 ground-truth class 중심점보다 거리가 먼 sample b가 a에 비해 더 큰 를 갖게 되는 것을 알 수 있다.
이를 수식으로 표현하면 다음과 같다. 앞서 말했듯, 와 를 파라미터로 삼는 gaussian distribution에서 sample 를 추출한다.
이후 sample 와 ground-truth class의 중심점을 이용하여 prototypical loss를 최적화 한다.
최적화 과정에서 는 고정된 parameter기 때문에 와 중심전 간의 거리가 멀수록 큰 값의 를 출력하도록가 update된다.
그런데 한가지 의문점이 있다. 위 예시에서 확인할 수 있듯 정규분포에서 sampling을 진행하기 때문에 가 운좋게 decision boundary를 넘어갈 수도 있지만 그렇지 않은 sample이 추출될 확률이 더 높지 않을까? 그렇다면 단순히 를 키우기만 하는게 맞는건가?
놀랍게도 저자들은 이러한 성질을 이용해서 confidence를 추정한다. 앞서 말한 방식으로 학습이 완료된 는 inference시 다음과 같은 식에 의해서 confidence를 출력한다.
먼저 각각의 test sample 에 대해서 mapping값 와 예측된 값으로 도출된 중심점 을 구하고, confidence model을 통해서 를 구한다. 이후 해당 정규분포에서 를 U번 sampling하여 각각의 prototypical loss의 평균으로 test sample의 confidence를 측정한다.
만약 앞에서 설명한대로 특정 sample이 오분류되고 ground-truth class centroid vector와 거리가 멀게 mapping이 되어 있어 가 큰 값을 갖는다면 U번의 prototypical loss의 차이가 클 것이며, 이를 평균낸 값은 모든 class에 대해서 비슷한 값을 갖게 된다. 즉 오분류된 sample일 수록(=중심점과의 거리가 멀수록) 특정 class에 대한 confidence가 낮게 측정된다.
위 그림을 통해서 쉽게 알아보자. 만약 오분류되고, ground-truth centroid vector와 거리가 먼 sample들의 경우 높은 sigma 값을 갖는다고 설명했다. 따라서 해당 정규분포를 따르는 sampling 값 는 추출할 때마다 매번 다른 softmax output 값을 갖게 될 것이다(위 그림의 좌측). 따라서 이를 평균내면 각 class 별 output 확률값은 낮은 confidence (특정 class에 속할 확률이 낮음)를 갖게 된다.
반면 정분류되고, ground-truth centroid vector와 거리가 가까운 sample들의 경우 낮은 sigma 값을 갖기에 sampling된 들 또한 큰 차이가 존재하지 않는다. 따라서 평균을 내더라도 특정 class에 높은 confidence를 갖게 된다.
논문의 concept 및 idea 위주로 정리하여 자세한 수식이나 내용에 오류가 있을 수 있습니다.