瀏覽代碼

Merge branch 'release/v0.3.0'

Jérôme BUISINE 4 年之前
父節點
當前提交
b3ae6e1ac1
共有 3 個文件被更改,包括 3 次插入5 次删除
  1. 1 1
      custom_config.py
  2. 1 3
      prediction_model.py
  3. 1 1
      train_model.py

+ 1 - 1
custom_config.py

@@ -13,7 +13,7 @@ backup_model_folder             = 'models_backup'
 
 # file or extensions
 
-perf_prediction_model_path      = 'predications_models_results.csv'
+perf_prediction_model_path      = 'predictions_models_results.csv'
 ## post_image_name_separator       = '___'
 
 # variables

+ 1 - 3
prediction_model.py

@@ -99,9 +99,7 @@ def main():
         model.load_weights(p_model_file.replace('.json', '.h5'))
 
         model.compile(loss='binary_crossentropy',
-                    optimizer='rmsprop',
-                    features=['accuracy'])
-
+                    optimizer='rmsprop')
 
     # Get results obtained from model
     y_data_prediction = model.predict(x_data)

+ 1 - 1
train_model.py

@@ -152,7 +152,7 @@ def main():
         os.makedirs(model_backup_folder)
 
     # add of callback models
-    filepath = os.path.join(cfg.backup_model_folder, p_output, p_output + "__{epoch:02d}.hdf5")
+    filepath = os.path.join(cfg.backup_model_folder, p_output, p_output + "-{val_auc:02f}__{epoch:02d}.hdf5")
     checkpoint = ModelCheckpoint(filepath, monitor='val_auc', verbose=1, save_best_only=True, mode='max')
     callbacks_list = [checkpoint]