Parcourir la source

update dropout for model and saves of model results

Jérôme BUISINE il y a 3 ans
Parent
commit
f076c19296
1 fichiers modifiés avec 21 ajouts et 18 suppressions
  1. 21 18
      train_lstm_weighted.py

+ 21 - 18
train_lstm_weighted.py

@@ -103,28 +103,28 @@ def create_model(_input_shape):
 
     model.add(ConvLSTM2D(filters=100, kernel_size=(3, 3),
                    input_shape=_input_shape,
-                   dropout=0.4,
-                   recurrent_dropout=0.4,
+                   dropout=0.5,
+                   recurrent_dropout=0.5,
                    padding='same', return_sequences=True))
     model.add(BatchNormalization())
 
     model.add(ConvLSTM2D(filters=50, kernel_size=(3, 3),
-                    dropout=0.4,
-                    recurrent_dropout=0.4,
+                    dropout=0.5,
+                    recurrent_dropout=0.5,
                     padding='same', return_sequences=True))
     model.add(BatchNormalization())
-    model.add(Dropout(0.4))
+    model.add(Dropout(0.5))
 
     model.add(Conv3D(filters=20, kernel_size=(3, 3, 3),
                 activation='sigmoid',
                 padding='same', data_format='channels_last'))
-    model.add(Dropout(0.4))
+    model.add(Dropout(0.5))
 
     model.add(Flatten())
     model.add(Dense(512, activation='sigmoid'))
-    model.add(Dropout(0.4))
+    model.add(Dropout(0.5))
     model.add(Dense(128, activation='sigmoid'))
-    model.add(Dropout(0.4))
+    model.add(Dropout(0.5))
     model.add(Dense(1, activation='sigmoid'))
     model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])
 
@@ -249,16 +249,6 @@ def main():
     print('All ACC:', acc_all)
     print('All AUC:', auc_all)
 
-
-    # save model results
-    if not os.path.exists(cfg.output_results_folder):
-        os.makedirs(cfg.output_results_folder)
-
-    results_filename = os.path.join(cfg.output_results_folder, cfg.results_filename)
-
-    with open(results_filename, 'a') as f:
-        f.write(p_output + ';' + str(acc_train) + ';' + str(auc_train) + ';' + str(acc_test) + ';' + str(auc_test) + '\n')
-
     # save acc metric information
     plt.plot(history.history['accuracy'])
     plt.plot(history.history['val_accuracy'])
@@ -276,5 +266,18 @@ def main():
 
     dump(model, os.path.join(cfg.output_models, p_output + '.joblib'))
 
+    # save model results
+    if not os.path.exists(cfg.output_results_folder):
+        os.makedirs(cfg.output_results_folder)
+    
+    results_filename_path = os.path.join(cfg.output_results_folder, cfg.results_filename)
+
+    if not os.path.exists(results_filename_path):
+        with open(results_filename_path, 'w') as f:
+            f.write(cfg.perf_train_header_file)
+
+    with open(results_filename_path, 'a') as f:
+        f.write(p_output + ';' + str(acc_train) + ';' + str(auc_train) + ';' + str(acc_test) + ';' + str(auc_test) + '\n')
+
 if __name__ == "__main__":
     main()