반응형

딥러닝에서 클래스 불균형을 다루는 방법


현실 데이터에는 클래스 불균형 (class imbalance) 문제가 자주 있다. 어떤 데이터에서 각 클래스 (주로 범주형 반응 변수) 가 갖고 있는 데이터의 양에 차이가 큰 경우, 클래스 불균형이 있다고 말한다. 예를 들어, 병원에서 질병이 있는 사람과 질병이 없는 사람의 데이터를 수집했다고 하자. 일반적으로 질병이 있는 사람이 질병이 없는 사람에 비해 적다. 비단 병원 데이터뿐 아니라 대부분의 "현실 데이터" 에 클래스 불균형 문제가 있다. 


클래스 균형이 필요한가?


왜 데이터가 클래스 균형을 이루어야할까? 그리고 언제 클래스 균형이 필요할까? 핵심은 다음과 같다. 클래스 균형 클래스 균형은 소수의 클래스에 특별히 더 큰 관심이 있는 경우에 필요하다. 


예를 들어 현재 재정 상황 및 집의 특성 등을 토대로 집을 사야할지 말아야할지를 예측하는 모델을 만들고 싶다고 하자. 사지말라고 예측하는 것과 사라고 예측하는 것은 그 무게가 다르다. 집을 사라고 예측하는 것은 훨씬 더 큰 리스크를 수반한다. 잘못된 투자는 큰 손실로 이루어질 수 있기 때문이다. 따라서 '집을 사라' 라고 예측하는 것에 대해서는 더 큰 정확도를 가져야한다. 하지만 데이터가 '집을 사지마라' 클래스에 몰려있는 경우, '집을 사지마라' 예측에 있어서는 높은 정확도를 가질 수 있어도 '집을 사라' 라고 예측하는 것에 관해서는 예측 성능이 좋지 않게 된다. 따라서 클래스 불균형이 있는 경우, 클래스에 따라 정확도가 달라지게 된다. 이를 해결하기 위해서는 따라서 '집을 사라' 클래스에는 더욱 큰 비중 (weight) 를 두고 정확한 예측을 할 수 있도록 만들어야한다.   


만약 소수 클래스에 관심이 없다면 어떻게 할까? 예를 들어, 이미지 분류 문제를 예로 들어보자. 그리고 오직, 전체 예측의 정확도 (accuracy) 에만 관심이 있다고 하자. 이 경우에는 굳이 클래스 균형을 맞출 필요가 없다. 왜냐하면 트레이닝 데이터에 만에 데이터를 위주로 학습하면, 모델의 정확도가 높아질 것이기 때문이다. 따라서 이런 경우에는 소수 클래스를 무시하더라도 전체 성능에 큰 영향을 주지 않기 때문에, 클래스 균형을 맞추는 것이 굳이 필요하지 않다고 할 수 있다. 


클래스 균형이 필요한 상황과 불필요한 상황을 예로 들어 설명했다. 다음으로는 딥러닝에서 클래스 균형을 맞추기 위한 두 가지 테크닉을 소개한다. 


(1) Weight balancing


Weight balancing 은 training set 의 각 데이터에서 loss 를 계산할 때 특정 클래스의 데이터에 더 큰 loss 값을 갖도록 하는 방법이다. 예를 들어, 이전 예에서 집을 사라는 클래스에 관해서는 더 큰 정확도가 필요하므로, 트레이닝 할 때, 집을 사라는 클래스의 데이터에 관해서는 loss 가 더 크도록 만드는 것이다. 이를 구현하는 한 가지 간단한 방법은 원하는 클래스의 데이터의 loss 에는 특정 값을 곱하고, 이를 통해 딥러닝 모델을 트레이닝하는 것이다.  


예를 들어, "집을 사라" 클래스에는 75 %의 가중치를 두고, "집을 사지마라" 클래스에는 25 %의 가중치를 둘 수 있다. 이를 python keras 를 통해 구현하면 아래와 같다. class_weight 라는 dictionary 를 만들고, keas model 의 class_weight parameter 로 넣어주면 된다. 


import keras

class_weight = {"buy": 0.75,
                "don't buy": 0.25}

model.fit(X_train, Y_train, epochs=10, batch_size=32, class_weight=class_weight)


물론 이 값을 예를 든 값이며, 분야와 최종 성능을 고려해 가중치 비율의 최적 세팅을 찾으면 된다. 다른 한 가지 방법은 클래스의 비율에 따라 가중치를 두는 방법인데, 예를 들어, 클래스의 비율이 1:9 라면 가중치를 9:1로 줌으로써 적은 샘플 수를 가진 클래스를 전체 loss 에 동일하게 기여하도록 할 수 있다. 


Weight balancing 에 사용할 수 있는 다른 방법은 Focal loss 를 사용하는 것이다. Focal loss 의 메인 아이디어는 다음과 같다. 다중 클래스 분류 문제에서, A, B, C 3개의 클래스가 존재한다고 하자. A 클래스는 상대적으로 분류하기 쉽고, B, C 클래스는 쉽다고 하자. 총 100번의 epoch 에서 단지 10번의 epoch 만에 validation set 에 대해 99 % 의 정확도를 얻었다. 그럼에도 불구하고 나머지 90 epoch 에 대해 A 클래스는 계속 loss 의 계산에 기여한다. 만약 상대적으로 분류하기 쉬운 A 클래스의 데이터 대신, B, C 클래스의 데이터에 더욱 집중을 해서 loss 를 계산을 하면 전체적인 정확도를 더 높일 수 있지 않을까? 예를 들어 batch size 가 64 라고 하면, 64 개의 sample 을 본 후, loss 를 계산해서 backpropagation 을 통해 weight 를 업데이트 하게 되는데 이 때, 이 loss 의 계산에 현재까지의 클래스 별 정확도를 고려한 weight 를 줌으로서 전반적인 모델의 정확도를 높이고자 하는 것이다. 



Focal loss 는 어떤 batch 의 트레이닝 데이터에 같은 weight 를 주지 않고, 분류 성능이 높은 클래스에 대해서는 down-weighting 을 한다. 이 때, gamma (위 그림) 를 주어, 이  down-weighting 의 정도를 결정한다. 이 방법은 분류가 힘든 데이터에 대한 트레닝을 강조하는 효과가 있다. Focal loss 는 Keras 에서 아래와 같은 custom loss function 을 정의하고 loss parameter 에 넣어줌으로써 구현할 수 있다. 

import keras
from keras import backend as K
import tensorflow as tf

# Define our custom loss function
def focal_loss(y_true, y_pred):
    gamma = 2.0, alpha = 0.25
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
    return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))-K.sum((1-alpha) * K.pow( pt_0, gamma) * K.log(1. - pt_0))

# Compile our model
adam = Adam(lr=0.0001)
model.compile(loss=[focal_loss], metrics=["accuracy"], optimizer=adam) 


(2) Over and under sampling


클래스 불균형을 다루기 위한 다른 방법은 바로 샘플링을 이용하는 것이다.


Under and and Over Sampling


예를 들어, 위 그림에서 파란색 데이터가 주황색 데이터에비해 양이 현저히 적다. 이 경우 두 가지 방법 - Undersampling, Oversampling 으로 샘플링을 할 수 있다. 


Undersampling 은 Majority class (파란색 데이터) 의 일부만을 선택하고, Minority class (주황색 데이터) 는 최대한 많은 데이터를 사용하는 방법이다. 이 때 Undersampling 된 파란색 데이터가 원본 데이터와 비교해 대표성이 있어야한다. Oversampling 은 Minority class 의 복사본을 만들어, Majority class 의 수만큼 데이터를 만들어주는 것이다. 똑같은 데이터를 그대로 복사하는 것이기 때문에 새로운 데이터는 기존 데이터와 같은 성질을 갖게된다.  


Reference

https://towardsdatascience.com/handling-imbalanced-datasets-in-deep-learning-f48407a0e758

https://towardsdatascience.com/methods-for-dealing-with-imbalanced-data-5b761be45a18

반응형