반응형

Model Calibration


예측모형 (predicted model) 을 어떻게 평가할 수 있을까? 가장 직관적이면서 많이 쓰이는 평가 방법은 정확도 (accuracy), 즉, 예측한 것중 몇퍼센트나 맞았는가에 관한 지표일 것이다. 하지만 좋은 모델이란 정확해야할 뿐만아니라 잘 보정 (calibration) 되어야할 필요가 있다. 


calibration 을 평가하기 위해 사용되는 calibration plot 은 예측된 확률과, 실제 확률의 관계를 보여준다. 이를 통해 모델의 예측이 얼마나 "현실적인지" 를 측정하게 된다. 예를 들어, 이미지를 인풋으로 받아 개와 고양이를 분류하는 모델을 생각해보자. 어떤 이미지에 대해 0.8 의 확률로 개라고 반환했다면, "정말 이 이미지가 개일 확률이 0.8 인가?" 에 대한 답을 주는 것이 calibration plot 이다. 정확도란 (일반적으로 이진 분류에 관하여) 50 % 를 cutoff 로 사용하여, 예측을 A와 B 클래스로 나누어, 실제 값이랑 맞는지를 확인하는 것이지만, calibration 은 보다 면밀하게 모델의 결과값을 검증하는 과정이라고 볼 수 있다. 


Calibration plot 


만약 데이터의 실제 정답이 알려져 있다면, Calibration 을 평가하기 위해 Calibration plot 을 많이 그리게 되는데, 일반적인 방법은 다음과 같다. 


1. 모델의 예측값을 기준으로 [0,10%], (10,20%], (20,30%], … (90,100%] 에 맞게 데이터를 분할한다 (이를 binning 이라고도 한다).  

2. 각 카테고리에서 예측하고자 한 클래스의 비율 (event rate) 를 계산한다 (예제의 경우 개의 비율을 계산한다).

3. calibration plot을 그린다 : 각 카테고리에서의 중앙 값 (5 %, 15 %, 20 % ...) 를 x 로 놓고, event rate 를 y 로 놓고 그린 그림이다. 

4. calibration plot 의 선이 일직선 (45◦)임을 확인한다. 


Example


R을 통해 Calibration 을 실제로 해보자. 


diabetes.csv


실습 데이터는 Pima diabetes 데이터셋을 이용해보겠다. Pima diabetes 데이터셋은 사람들의 임상정보와 당뇨병 여부에 관한 정보를 갖고 있는 데이터셋이다. 이 때 당뇨병 여부를 예측하는 모형을 로지스틱 회귀분석 및 랜덤포레스트을 이용해 구축하고, 이 두 모델의 정확도 및 Calibration 을 평가해보자. 


데이터 로드 및 train/test split 

  • 이 데이터셋의 경우, missing value 가 많은 것이 특징이다. 
  • 아래 코드는 평균으로 missing 을 채워넣는 mean imputation 을 수행하고 train/test 를 50:50으로 나누는 코드이다. 
suppressPackageStartupMessages(library(tidyverse))
library(data.table) 
data <- readr::read_csv("../PimaDiabetes/diabetes.csv") 
data$Outcome <- factor(data$Outcome)

## Imputation
fix_missing <- function(x, missing_value) { 
  x[x == missing_value] <- NA 
  x 
} 
cols <- colnames(data)[1:8]
data[, cols] <- lapply(data[, cols], fix_missing, 0)

impute_mean <- function(x) {
    x[is.na(x)] <- mean(x, na.rm = TRUE)
    return(x)
}
data[, cols] <- lapply(data[, cols], impute_mean)
data %>% head

##  Train/Test Split
set.seed(123)
smp_size <- floor(0.5 * nrow(data))
train_ind <- sample(seq_len(nrow(data)), size = smp_size)
train <- data[trai n_ind, ]
test <- data[-train_ind, ]


로지스틱 회귀분석 모형 구축 및 test set 에 대한 예측

  • Pregnancies + Glucose + BloodPressure + Insulin + BMI + DiabetesPedigreeFunction + Age 를 통해 Outcome 을 예측하는 모형을 만든다. 
lrmodel <- glm(data = train, Outcome ~ Pregnancies + Glucose + BloodPressure + Insulin + BMI + DiabetesPedigreeFunction + Age, family = binomial("logit"))

x = predict(lrmodel, newdata = test)
p = (1 / (1+exp(-x)))
test <- test %>% mutate(lrmodel = p)

랜덤 포레스트 모형 구축 및 test set 에 대한 예측 

  • 랜덤 포레스트의 hyperparameter 인 mtry 와 ntree 는 적절한 값을 선택한다. 
library(randomForest)

rfmodel = randomForest(Outcome ~ Pregnancies + Glucose + BloodPressure + Insulin + BMI + DiabetesPedigreeFunction + Age
                      , data = train, mtry = floor(sqrt(7)), ntree = 500, importance = T)

p = predict(rfmodel, newdata = test, type = "prob")[, 2]
test <- test %>% mutate(rfmodel = p)

cutoff 정하기

  • 모형은 0~1사이의 확률을 의미하는 값을 내보내는데, 여기에 threshold 를 적용해서 0 또는 1로 변환한다. 
test <- test %>% mutate(
  rfclass = if_else(rfmodel >= 0.5, 1, 0),
  lrclass = if_else(lrmodel >= 0.5, 1, 0)
)
test$rfclass <- factor(test$rfclass)
test$lrclass <- factor(test$lrclass)


로지스틱 회귀분석 정확도 

  • 정확도는 최종 예측값 (0 또는 1) 을 기준으로, 예측한 값중 실제 정답으로 맞춘 비율을 의미하는 값이다. 
  • 로지스틱 회귀분석의 경우, 78.9 % 의 정확도를 보여준다. 

library(caret) confusionMatrix(test$lrclass, test$Outcome)

Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 227  58
         1  23  76
                                          
               Accuracy : 0.7891          
                 95% CI : (0.7448, 0.8288)
    No Information Rate : 0.651           
    P-Value [Acc > NIR] : 2.494e-09       
                                          
                  Kappa : 0.5058          
                                          
 Mcnemar's Test P-Value : 0.0001582       
                                          
            Sensitivity : 0.9080          
            Specificity : 0.5672          
         Pos Pred Value : 0.7965          
         Neg Pred Value : 0.7677          
             Prevalence : 0.6510          
         Detection Rate : 0.5911          
   Detection Prevalence : 0.7422          
      Balanced Accuracy : 0.7376          
                                          
       'Positive' Class : 0  


로지스틱 회귀분석 Calibration plot

  • Calibration plot 을 그릴 수 있는 방법은 여러가지가 있지만, caret 패키지의 calibration 함수를 통해 쉽게 그려볼 수 있다. 

calibration 함수는 아래의 calibration plot 을 그릴 수 있는 정보를 dataframe으로 만들어 반환해준다. 
  • 모델의 예측값을 기준으로 [0,10%], (10,20%], (20,30%], … (90,100%] 에 맞게 데이터를 분할한다 (이를 binning 이라고도 한다).  
  • calibration plot을 그린다 : 각 카테고리에서의 중앙 값 (5 %, 15 %, 20 % ...) 를 x 로 놓고, event rate 를 y 로 놓고 그린 그림이다. 
library(caret)

cal_plot_data_lr = calibration(Outcome ~ lrmodel, 
  data = test, cuts = seq(0, 1, by=0.1), class = 1)$data 

ggplot() + xlab("Bin Midpoint") +
  geom_line(data = cal_plot_data_lr, aes(midpoint, Percent),
            color = "#F8766D") +
  geom_point(data = cal_plot_data_lr, aes(midpoint, Percent),
            color = "#F8766D", size = 3) +
  geom_line(aes(c(0, 100), c(0, 100)), linetype = 2, 
            color = 'grey50')

랜덤포레스트 정확도

  • 랜덤포레스트의 경우 로지스틱 회귀분석보다 조금 작은 0.77 % 의 정확도를 보인다. 
confusionMatrix(test$rfclass, test$Outcome)
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 213  51
         1  37  83
                                          
               Accuracy : 0.7708          
                 95% CI : (0.7255, 0.8119)
    No Information Rate : 0.651           
    P-Value [Acc > NIR] : 2.421e-07       
                                          
                  Kappa : 0.4831          
                                          
 Mcnemar's Test P-Value : 0.1658          
                                          
            Sensitivity : 0.8520          
            Specificity : 0.6194          
         Pos Pred Value : 0.8068          
         Neg Pred Value : 0.6917          
             Prevalence : 0.6510          
         Detection Rate : 0.5547          
   Detection Prevalence : 0.6875          
      Balanced Accuracy : 0.7357          
                                          
       'Positive' Class : 0               
                                          


랜덤포레스트 Calibration plot

  • 랜덤포레스트에서도 같은 방법으로 calibration plot 을 그릴 수 있다. 
  • 정확도는 랜덤포레스트에서 약간 작았지만, Calibration 은 더 좋은 모습을 보인다.
  • 하지만 train/test 의 비율, hyperparameter 구성에 따라 calibration 이 달라지니, 다양한 세팅에서 검증해볼 필요가 있다. 
cal_plot_data_rf = calibration(Outcome ~ rfmodel, 
  data = test, class = 1)$data

ggplot() + xlab("Bin Midpoint") +
  geom_line(data = cal_plot_data_rf, aes(midpoint, Percent),
            color = "#F8766D") +
  geom_point(data = cal_plot_data_rf, aes(midpoint, Percent),
            color = "#F8766D", size = 3) +
  geom_line(aes(c(0, 100), c(0, 100)), linetype = 2, 
            color = 'grey50')


https://medium.com/optima-blog/model-calibration-4d710a76c54

http://appliedpredictivemodeling.com/blog?offset=1532965627474


반응형