Parcourir la source

Use of the chanels parameter to properly fit image inputs to the model

Jérôme BUISINE il y a 3 ans
Parent
commit
e1296990b8
1 fichier modifié avec 15 ajouts et 32 suppressions
  1. 15 32
      train_model.py

+ 15 - 32
train_model.py

@@ -56,7 +56,7 @@ def main():
     parser.add_argument('--tl', type=int, help='use or not of transfer learning (`VGG network`)', default=0, choices=[0, 1])
     parser.add_argument('--batch_size', type=int, help='batch size used as model input', default=64)
     parser.add_argument('--epochs', type=int, help='number of epochs used for training model', default=30)
-    parser.add_argument('--chanels', type=int, help="given number of chanels if necessary", default=0)
+    parser.add_argument('--chanels', type=int, help="given number of ordered chanels for each input images (example: '1,3,3')", required=True)
     parser.add_argument('--size', type=str, help="Size of input images", default="100, 100")
     parser.add_argument('--val_size', type=float, help='percent of validation data during training process', default=0.3)
 
@@ -68,7 +68,7 @@ def main():
     p_tl          = args.tl
     p_batch_size  = args.batch_size
     p_epochs      = args.epochs
-    p_chanels     = args.chanels
+    p_chanels     = list(map(int, args.chanels.split(',')))
     p_size        = args.size.split(',')
     p_val_size    = args.val_size
 
@@ -94,10 +94,7 @@ def main():
     print("--Reading all images data...")
 
     # getting number of chanel
-    if p_chanels == 0:
-        n_chanels = len(dataset_train[1][1].split('::'))
-    else:
-        n_chanels = p_chanels
+    n_chanels = sum(p_chanels)
 
     print("-- Number of chanels : ", n_chanels)
     img_width, img_height = [ int(s) for s in p_size ]
@@ -145,44 +142,30 @@ def main():
 
     final_df_train = dataset_train
     final_df_test = dataset_test
-    
-    def load_multiple_greyscale(x):
-        # update progress
-        global n_counter
-        n_counter += 1
-        write_progress(n_counter / float(total_samples))
-        return [cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in x.split('::')]
 
-    def load_greyscale(x):
+    def load_images(x):
         # update progress
         global n_counter
         n_counter += 1
         write_progress(n_counter / float(total_samples))
-        return cv2.imread(x, cv2.IMREAD_GRAYSCALE)
 
-    def load_rgb(x):
-        # update progress
-        global n_counter
-        n_counter += 1
-        write_progress(n_counter / float(total_samples))
-        return cv2.imread(x)
+        images = []
+        for i, path in enumerate(x.split('::')):
+            if p_chanels[i] > 1:
+                img = cv2.imread(path)
+            else:
+                img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
+            images.append(img)
+
+        return images
 
 
     print('---- Loading dataset.... ----')
     print('-----------------------------\n')
 
     # check if specific number of chanels is used
-    if p_chanels == 0:
-        # `::` is the separator used for getting each img path
-        if n_chanels > 1:
-            final_df_train[1] = final_df_train[1].apply(lambda x: load_multiple_greyscale(x))
-            final_df_test[1] = final_df_test[1].apply(lambda x: load_multiple_greyscale(x))
-        else:
-            final_df_train[1] = final_df_train[1].apply(lambda x: load_greyscale(x))
-            final_df_test[1] = final_df_test[1].apply(lambda x: load_greyscale(x))
-    else:
-        final_df_train[1] = final_df_train[1].apply(lambda x: load_rgb(x))
-        final_df_test[1] = final_df_test[1].apply(lambda x: load_rgb(x))
+    final_df_train[1] = final_df_train[1].apply(lambda x: load_images(x))
+    final_df_test[1] = final_df_test[1].apply(lambda x: load_images(x))
 
     # reshape array data
     final_df_train[1] = final_df_train[1].apply(lambda x: np.array(x).reshape(input_shape))