瀏覽代碼

Update of way of generating data and train model

Jérôme BUISINE 5 年之前
父節點
當前提交
a6d566fc00
共有 2 個文件被更改,包括 8 次插入12 次删除
  1. 5 9
      generate/generate_dataset.py
  2. 3 3
      train_model.py

+ 5 - 9
generate/generate_dataset.py

@@ -29,7 +29,7 @@ min_max_filename        = cfg.min_max_filename_extension
 
 # define all scenes values
 scenes_list             = cfg.scenes_names
-scenes_indexes          = cfg.scenes_indices
+scenes_indices          = cfg.scenes_indices
 dataset_path            = cfg.dataset_path
 zones                   = cfg.zones_indices
 seuil_expe_filename     = cfg.seuil_expe_filename
@@ -39,7 +39,7 @@ output_data_folder      = cfg.output_data_folder
 
 generic_output_file_svd = '_random.csv'
 
-def generate_data_model(_scenes_list, _filename, _transformations, _scenes, _nb_zones = 4, _random=0):
+def generate_data_model(_filename, _transformations, _scenes_list, _nb_zones = 4, _random=0):
 
     output_train_filename = _filename + ".train"
     output_test_filename = _filename + ".val"
@@ -187,7 +187,7 @@ def generate_data_model(_scenes_list, _filename, _transformations, _scenes, _nb_
                     
                     line = line + '\n'
 
-                    if id_zone < _nb_zones and folder_scene in _scenes:
+                    if id_zone < _nb_zones:
                         train_file_data.append(line)
                     else:
                         test_file_data.append(line)
@@ -225,7 +225,6 @@ def main():
                                   default="100, 100")
     parser.add_argument('--scenes', type=str, help='List of scenes to use for training data')
     parser.add_argument('--nb_zones', type=int, help='Number of zones to use for training data set', choices=list(range(1, 17)))
-    parser.add_argument('--renderer', type=str, help='Renderer choice in order to limit scenes used', choices=cfg.renderer_choices, default='all')
     parser.add_argument('--random', type=int, help='Data will be randomly filled or not', choices=[0, 1])
 
     args = parser.parse_args()
@@ -252,10 +251,7 @@ def main():
     if transformations[0].getName() == 'static':
         raise ValueError("The first transformation in list cannot be static")
 
-    # list all possibles choices of renderer
-    scenes_list = dt.get_renderer_scenes_names(p_renderer)
-    scenes_indices = dt.get_renderer_scenes_indices(p_renderer)
-
+    # Update: not use of renderer scenes list
     # getting scenes from indexes user selection
     scenes_selected = []
 
@@ -264,7 +260,7 @@ def main():
         scenes_selected.append(scenes_list[index])
 
     # create database using img folder (generate first time only)
-    generate_data_model(scenes_list, p_filename, transformations, scenes_selected, p_nb_zones, p_random)
+    generate_data_model(p_filename, transformations, scenes_selected, p_nb_zones, p_random)
 
 if __name__== "__main__":
     main()

+ 3 - 3
train_model.py

@@ -169,11 +169,11 @@ def main():
         print("Restart from epoch ", last_epoch)
 
     # concatenate train and validation data (`validation_split` param will do the separation into keras model)
-    y_data = y_dataset_train.values + y_dataset_val.values
-    x_data = x_data_train + y_data_train
+    y_data = np.concatenate([y_dataset_train.values, y_dataset_val.values])
+    x_data = np.concatenate([x_data_train, x_data_val])
 
     # validation split parameter will use the last `%` data, so here, data will really validate our model
-    model.fit(x_data_train, y_dataset_train.values, validation_split=validation_split, initial_epoch=initial_epoch, epochs=p_epochs, batch_size=p_batch_size, callbacks=callbacks_list)
+    model.fit(x_data, y_data, validation_split=validation_split, initial_epoch=initial_epoch, epochs=p_epochs, batch_size=p_batch_size, callbacks=callbacks_list)
 
     score = model.evaluate(x_data_val, y_dataset_val, batch_size=p_batch_size)