Keras 모델 저장하고 불러오기
/* by 3months. 2017.7.19 */
keras를 통해 MLP, CNN 등의 딥러닝 모델을 만들고, 이를 학습시켜서 모델의 weights를 생성하고 나면 이를 저장하고 싶을 때가 있습니다. 특히 weights 같은 경우는 파일 형태로 저장해놓는 것이 유용한데, 파이썬 커널을 내리는 순간 애써 만든 weights 가 모두 메모리에서 날라가 버리기 때문입니다. keras에서는 모델과 weights의 재사용을 위해 이를 파일형태로 저장하는 라이브러리를 제공하며, 이를 통해 모델과 weights를 파일 형태로 저장하고 불러올 수가 있습니다.
Keras에서 만든 모델을 저장할 때는 다음과 같은 룰을 따릅니다.
- 모델은 JSON 파일 또는 YAML 파일로 저장한다.
- Weight는 H5 파일로 저장한다.
아래는 모델을 저장하고 불러오는 실습 코드입니다.
라이브러리 임포트
from keras.models import Sequential from keras.layers import Dense from keras.callbacks import ModelCheckpoint import matplotlib.pyplot as plt import numpy import pandas as pd
데이터 로드
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data" names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class'] dataframe = pd.read_csv(url, names=names) array = dataframe.values X = array[:,0:8] Y = array[:,8]
모델 생성
# create model model = Sequential() model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu')) model.add(Dense(8, kernel_initializer='uniform', activation='relu')) model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid')) # Compile model model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
모델을 JSON 파일 형식으로 만들어 저장하기
model_json = model.to_json() with open("model.json", "w") as json_file : json_file.write(model_json)
Weight를 h5 파일 포맷으로 만들어 저장하기
model.save_weights("model.h5") print("Saved model to disk")
저장된 JSON 파일로 부터 모델 로드하기
from keras.models import model_from_json json_file = open("model.json", "r") loaded_model_json = json_file.read() json_file.close() loaded_model = model_from_json(loaded_model_json)
로드한 모델에 Weight 로드하기
loaded_model.load_weights("model.h5") print("Loaded model from disk")
모델 컴파일 후 Evaluation
loaded_model.compile(loss="binary_crossentropy", optimizer="rmsprop", metrics=['accuracy']) # model evaluation score = loaded_model.evaluate(X,Y,verbose=0) print("%s : %.2f%%" % (loaded_model.metrics_names[1], score[1]*100))
'Tools > Keras' 카테고리의 다른 글
Keras - CNN ImageDataGenerator 활용하기 (11) | 2017.10.28 |
---|---|
Keras - Keras를 통한 LSTM의 구현 (24) | 2017.08.29 |
Keras와 Tensorflow 사용할 때 유용한 아나콘다 가상환경 (0) | 2017.07.01 |
Keras - Backend 설정하기 (Theano, Tensorflow) (2) | 2017.07.01 |
Keras - MNIST 데이터로 CNN(Convolutional Neural Network) Training (0) | 2017.01.22 |