Parcourir la source

Merge branch 'release/v0.3.0'

Jérôme BUISINE il y a 4 ans
Parent
commit
16b6e85601
3 fichiers modifiés avec 3 ajouts et 5 suppressions
  1. 1 1
      custom_config.py
  2. 1 3
      prediction_model.py
  3. 1 1
      train_model.py

+ 1 - 1
custom_config.py

@@ -13,7 +13,7 @@ backup_model_folder             = 'models_backup'
 
 # file or extensions
 
-perf_prediction_model_path      = 'predications_models_results.csv'
+perf_prediction_model_path      = 'predictions_models_results.csv'
 ## post_image_name_separator       = '___'
 
 # variables

+ 1 - 3
prediction_model.py

@@ -99,9 +99,7 @@ def main():
         model.load_weights(p_model_file.replace('.json', '.h5'))
 
         model.compile(loss='binary_crossentropy',
-                    optimizer='rmsprop',
-                    features=['accuracy'])
-
+                    optimizer='rmsprop')
 
     # Get results obtained from model
     y_data_prediction = model.predict(x_data)

+ 1 - 1
train_model.py

@@ -152,7 +152,7 @@ def main():
         os.makedirs(model_backup_folder)
 
     # add of callback models
-    filepath = os.path.join(cfg.backup_model_folder, p_output, p_output + "__{epoch:02d}.hdf5")
+    filepath = os.path.join(cfg.backup_model_folder, p_output, p_output + "-{val_auc:02f}__{epoch:02d}.hdf5")
     checkpoint = ModelCheckpoint(filepath, monitor='val_auc', verbose=1, save_best_only=True, mode='max')
     callbacks_list = [checkpoint]