Add model backup every epoch

Jérôme BUISINE, 4 years ago
Parent commit 1ce9187713
3 files changed with 21 additions and 4 deletions
  1. .gitignore (+2 -1)
  2. custom_config.py (+2 -1)
  3. train_model.py (+17 -2)

+ 2 - 1
.gitignore

@@ -15,4 +15,5 @@ saved_models
 learned_zones
 dataset
 models_info
-results
+results
+models_backup

+ 2 - 1
custom_config.py

@@ -9,6 +9,7 @@ context_vars = vars()
 
 ## noisy_folder                    = 'noisy'
 ## not_noisy_folder                = 'notNoisy'
+backup_model_folder             = 'models_backup'
 
 # file or extensions
 
@@ -24,4 +25,4 @@ keras_epochs                    = 50
 ## keras_batch                     = 32
 ## val_dataset_size                = 0.2
 
-## keras_img_size                  = (200, 200)
+keras_img_size                  = (96, 96)
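
For context, a minimal sketch of how a consumer of custom_config.py might read the re-enabled keras_img_size value (the input_shape derivation is an assumption for illustration, not code from this repository):

    import custom_config as cfg

    # (96, 96) after this commit; previously commented out at (200, 200)
    img_width, img_height = cfg.keras_img_size
    input_shape = (img_width, img_height, 3)  # assuming 3-channel RGB input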

+ 17 - 2
train_model.py

@@ -6,7 +6,10 @@ import json
 
 # model imports
 import cnn_models as models
+import tensorflow as tf
+import keras
 from keras import backend as K
+from keras.callbacks import ModelCheckpoint
 from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
 
 # image processing imports
@@ -21,6 +24,11 @@ import custom_config as cfg
 
 def main():
 
+    # default Keras session configuration (TensorFlow 1.x API):
+    # expose at most 1 GPU and 8 CPU devices to the session
+    config = tf.ConfigProto(device_count={'GPU': 1, 'CPU': 8})
+    sess = tf.Session(config=config)
+    keras.backend.set_session(sess)
+
     parser = argparse.ArgumentParser(description="Train Keras model and save it into .json file")
 
     parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
@@ -28,7 +36,7 @@ def main():
     parser.add_argument('--tl', type=int, help='use or not of transfer learning (`VGG network`)', default=0, choices=[0, 1])
     parser.add_argument('--batch_size', type=int, help='batch size used as model input', default=cfg.keras_batch)
     parser.add_argument('--epochs', type=int, help='number of epochs used for training model', default=cfg.keras_epochs)
-    parser.add_argument('--val_size', type=int, help='percent of validation data during training process', default=cfg.val_dataset_size)
+    parser.add_argument('--val_size', type=float, help='percent of validation data during training process', default=cfg.val_dataset_size)
 
     args = parser.parse_args()
 
@@ -136,10 +144,17 @@ def main():
     # 2. Getting model
     #######################
 
+    if not os.path.exists(cfg.backup_model_folder):
+        os.makedirs(cfg.backup_model_folder)
+
+    filepath = os.path.join(cfg.backup_model_folder, "{0}-{{epoch:02d}}.hdf5".format(p_output))
+    checkpoint = ModelCheckpoint(filepath, monitor='val_auc', verbose=1, save_best_only=True, mode='max')
+    callbacks_list = [checkpoint]
+
     model = models.get_model(n_channels, input_shape, p_tl)
     model.summary()
  
-    model.fit(x_data_train, y_dataset_train.values, validation_split=p_val_size, epochs=p_epochs, batch_size=p_batch_size)
+    model.fit(x_data_train, y_dataset_train.values, validation_split=p_val_size, epochs=p_epochs, batch_size=p_batch_size, callbacks=callbacks_list)
 
     score = model.evaluate(x_data_test, y_dataset_test, batch_size=p_batch_size)
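
For reference, a minimal sketch of how the checkpoint pattern above expands at training time (the "my_model" name is a hypothetical --output value; monitoring 'val_auc' assumes an AUC metric is compiled into the model):

    import os
    from keras.callbacks import ModelCheckpoint

    backup_folder = 'models_backup'   # cfg.backup_model_folder
    p_output = 'my_model'             # hypothetical --output value

    # "my_model-{epoch:02d}.hdf5" -> models_backup/my_model-01.hdf5, ...
    # the doubled braces keep {epoch:02d} literal in the formatted string,
    # so Keras substitutes the epoch number itself at save time
    filepath = os.path.join(backup_folder, "{0}-{{epoch:02d}}.hdf5".format(p_output))

    # save_best_only with mode='max' only writes epochs that improve val_auc,
    # so the backup folder keeps the best models rather than every epoch
    checkpoint = ModelCheckpoint(filepath, monitor='val_auc', verbose=1,
                                 save_best_only=True, mode='max')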