Parcourir la source

Merge branch 'release/v0.2.5'

Jérôme BUISINE il y a 4 ans
Parent commit: 41783c22c9
4 fichiers modifiés avec 26 ajouts et 5 suppressions
  1. .gitignore (+2 −1)
  2. custom_config.py (+2 −1)
  3. generate/generate_dataset.py (+5 −1)
  4. train_model.py (+17 −2)

+ 2 - 1
.gitignore

@@ -15,4 +15,5 @@ saved_models
 learned_zones
 dataset
 models_info
-results
+results
+models_backup

+ 2 - 1
custom_config.py

@@ -9,6 +9,7 @@ context_vars = vars()
 
 ## noisy_folder                    = 'noisy'
 ## not_noisy_folder                = 'notNoisy'
+backup_model_folder             = 'models_backup'
 
 # file or extensions
 
@@ -24,4 +25,4 @@ keras_epochs                    = 50
 ## keras_batch                     = 32
 ## val_dataset_size                = 0.2
 
-## keras_img_size                  = (200, 200)
+keras_img_size                  = (96, 96)

+ 5 - 1
generate/generate_dataset.py

@@ -220,6 +220,9 @@ def main():
                                     help="list of specific param for each metric choice (See README.md for further information in 3D mode)", 
                                     default='100, 200 :: 50, 25',
                                     required=True)
+    parser.add_argument('--size', type=str, 
+                                  help="Size of input images",
+                                  default="100, 100")
     parser.add_argument('--scenes', type=str, help='List of scenes to use for training data')
     parser.add_argument('--nb_zones', type=int, help='Number of zones to use for training data set', choices=list(range(1, 17)))
     parser.add_argument('--renderer', type=str, help='Renderer choice in order to limit scenes used', choices=cfg.renderer_choices, default='all')
@@ -231,6 +234,7 @@ def main():
     p_features  = list(map(str.strip, args.features.split(',')))
     p_params   = list(map(str.strip, args.params.split('::')))
     p_scenes   = args.scenes.split(',')
+    p_size     = args.size # not necessary to split here
     p_nb_zones = args.nb_zones
     p_renderer = args.renderer
     p_random   = args.random
@@ -243,7 +247,7 @@ def main():
         if feature not in features_choices:
             raise ValueError("Unknown metric, please select a correct metric : ", features_choices)
 
-        transformations.append(Transformation(feature, p_params[id]))
+        transformations.append(Transformation(feature, p_params[id], p_size))
 
     if transformations[0].getName() == 'static':
         raise ValueError("The first transformation in list cannot be static")

+ 17 - 2
train_model.py

@@ -6,7 +6,10 @@ import json
 
 # model imports
 import cnn_models as models
+import tensorflow as tf
+import keras
 from keras import backend as K
+from keras.callbacks import ModelCheckpoint
 from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
 
 # image processing imports
@@ -21,6 +24,11 @@ import custom_config as cfg
 
 def main():
 
+    # default keras configuration
+    config = tf.ConfigProto( device_count = {'GPU': 1 , 'CPU': 8}) 
+    sess = tf.Session(config=config) 
+    keras.backend.set_session(sess)
+
     parser = argparse.ArgumentParser(description="Train Keras model and save it into .json file")
 
     parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
@@ -28,7 +36,7 @@ def main():
     parser.add_argument('--tl', type=int, help='use or not of transfer learning (`VGG network`)', default=0, choices=[0, 1])
     parser.add_argument('--batch_size', type=int, help='batch size used as model input', default=cfg.keras_batch)
     parser.add_argument('--epochs', type=int, help='number of epochs used for training model', default=cfg.keras_epochs)
-    parser.add_argument('--val_size', type=int, help='percent of validation data during training process', default=cfg.val_dataset_size)
+    parser.add_argument('--val_size', type=float, help='percent of validation data during training process', default=cfg.val_dataset_size)
 
     args = parser.parse_args()
 
@@ -136,10 +144,17 @@ def main():
     # 2. Getting model
     #######################
 
+    if not os.path.exists(cfg.backup_model_folder):
+        os.makedirs(cfg.backup_model_folder)
+
+    filepath = os.path.join(cfg.backup_model_folder, p_output + "-{epoch:02d}.hdf5")
+    checkpoint = ModelCheckpoint(filepath, monitor='val_auc', verbose=1, save_best_only=True, mode='max')
+    callbacks_list = [checkpoint]
+
     model = models.get_model(n_channels, input_shape, p_tl)
     model.summary()
  
-    model.fit(x_data_train, y_dataset_train.values, validation_split=p_val_size, epochs=p_epochs, batch_size=p_batch_size)
+    model.fit(x_data_train, y_dataset_train.values, validation_split=p_val_size, epochs=p_epochs, batch_size=p_batch_size, callbacks=callbacks_list)
 
     score = model.evaluate(x_data_test, y_dataset_test, batch_size=p_batch_size)