Keras 모델 저장하고 불러오기


/* by 3months. 2017.7.19 */


keras를 통해 MLP, CNN 등의 딥러닝 모델을 만들고, 이를 학습시켜서 모델의 weights를 생성하고 나면 이를 저장하고 싶을 때가 있습니다. 특히 weights 같은 경우는 파일 형태로 저장해놓는 것이 유용한데, 파이썬 커널을 내리는 순간 애써 만든 weights 가 모두 메모리에서 날라가 버리기 때문입니다. keras에서는 모델과 weights의 재사용을 위해 이를 파일형태로 저장하는 라이브러리를 제공하며, 이를 통해 모델과 weights를 파일 형태로 저장하고 불러올 수가 있습니다.


Keras에서 만든 모델을 저장할 때는 다음과 같은 룰을 따릅니다.

  • 모델은 JSON 파일 또는 YAML 파일로 저장한다.
  • Weight는 H5 파일로 저장한다.


아래는 모델을 저장하고 불러오는 실습 코드입니다.



라이브러리 임포트

from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
import pandas as pd


데이터 로드

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"
names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']

dataframe = pd.read_csv(url, names=names)
array = dataframe.values

X = array[:,0:8]
Y = array[:,8]


모델 생성

# create model
model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))

# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])


모델을 JSON 파일 형식으로 만들어 저장하기

model_json = model.to_json()
with open("model.json", "w") as json_file : 
    json_file.write(model_json)


Weight를 h5 파일 포맷으로 만들어 저장하기

model.save_weights("model.h5")
print("Saved model to disk")


저장된 JSON 파일로 부터 모델 로드하기

from keras.models import model_from_json json_file = open("model.json", "r") loaded_model_json = json_file.read() json_file.close() loaded_model = model_from_json(loaded_model_json)


로드한 모델에 Weight 로드하기

loaded_model.load_weights("model.h5") print("Loaded model from disk")


모델 컴파일 후 Evaluation

loaded_model.compile(loss="binary_crossentropy", optimizer="rmsprop", metrics=['accuracy'])

# model evaluation
score = loaded_model.evaluate(X,Y,verbose=0)

print("%s : %.2f%%" % (loaded_model.metrics_names[1], score[1]*100))


  • 어른아이 2019.09.09 17:31

    감사합니다!!!

  • 2020.03.24 16:10

    model.save_weights("model.h5")


    여기에서 int로 index를 파일이름에 넣는 방법이 있을까요?

    • f-string 2020.03.30 16:22

      f-string을 이용하셔서 f'model_{index}.h5' 처럼 넣으시면 될것 같습니다.