Parcourir la source

update way of saving model using Keras API

Jérôme BUISINE il y a 3 ans
Parent
commit
56b71d2491
2 fichiers modifiés avec 3 ajouts et 7 suppressions
  1. 1 5
      prediction/estimate_thresholds_lstm.py
  2. 2 2
      train_lstm_weighted.py

+ 1 - 5
prediction/estimate_thresholds_lstm.py

@@ -91,11 +91,7 @@ def main():
     # 2. load model and compile it
 
     # TODO : check kind of model
-    model = joblib.load(p_model)
-    model.compile(loss='binary_crossentropy',
-                  optimizer='rmsprop',
-                  metrics=['accuracy'])
-    # model = load_model(p_model)
+    model = load_model(p_model)
     # model.compile(loss='binary_crossentropy',
     #               optimizer='rmsprop',
     #               metrics=['accuracy'])

+ 2 - 2
train_lstm_weighted.py

@@ -349,11 +349,11 @@ def main():
     model_history = os.path.join(cfg.output_results_folder, p_output + '.png')
     plt.savefig(model_history)
 
-    # save model using joblib
+    # save model using keras API
     if not os.path.exists(cfg.output_models):
         os.makedirs(cfg.output_models)
 
-    dump(model, os.path.join(cfg.output_models, p_output + '.joblib'))
+    model.save(os.path.join(cfg.output_models, p_output + '.h5'))
 
     # save model results
     if not os.path.exists(cfg.output_results_folder):