Parcourir la source

specify chanels images input for window in LSTM

Jérôme BUISINE il y a 3 ans
Parent
commit
1f55a401f3
1 fichiers modifiés avec 13 ajouts et 5 suppressions
  1. 13 5
      train_lstm_weighted.py

+ 13 - 5
train_lstm_weighted.py

@@ -5,6 +5,7 @@ import pandas as pd
 import os
 import ctypes
 from PIL import Image
+import cv2
 
 from keras import backend as K
 import matplotlib.pyplot as plt
@@ -50,7 +51,7 @@ def write_progress(progress):
     sys.stdout.write("\033[F")
 
 
-def build_input(df, seq_norm):
+def build_input(df, seq_norm, p_chanels):
     """Convert dataframe to numpy array input with timesteps as float array
     
     Arguments:
@@ -76,9 +77,14 @@ def build_input(df, seq_norm):
             seq_elems = []
 
             # for each element in sequence data
-            for img_path in column:
-                img = Image.open(img_path)
+            for i, img_path in enumerate(column):
+
                 # seq_elems.append(np.array(img).flatten())
+                if p_chanels[i] > 1:
+                    img = cv2.imread(img_path)
+                else:
+                    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
+                
                 seq_elems.append(np.array(img) / 255.)
 
             #seq_arr.append(np.array(seq_elems).flatten())
@@ -178,6 +184,7 @@ def main():
     parser.add_argument('--train', type=str, help='input train dataset', required=True)
     parser.add_argument('--test', type=str, help='input test dataset', required=True)
     parser.add_argument('--output', type=str, help='output model name', required=True)
+    parser.add_argument('--chanels', type=str, help="given number of ordered chanels (example: '1,3,3') for each element of window", required=True)
     parser.add_argument('--epochs', type=int, help='number of expected epochs', default=30)
     parser.add_argument('--batch_size', type=int, help='expected batch size for training model', default=64)
     parser.add_argument('--seq_norm', type=int, help='normalization sequence by features', choices=[0, 1], default=0)
@@ -187,6 +194,7 @@ def main():
     p_train        = args.train
     p_test         = args.test
     p_output       = args.output
+    p_chanels     = list(map(int, args.chanels.split(',')))
     p_epochs       = args.epochs
     p_batch_size   = args.batch_size
     p_seq_norm     = bool(args.seq_norm)
@@ -236,11 +244,11 @@ def main():
 
     # split dataset into X_train, y_train, X_test, y_test
     X_train_all = final_df_train.loc[:, 1:].apply(lambda x: x.astype(str).str.split('::'))
-    X_train_all = build_input(X_train_all, p_seq_norm)
+    X_train_all = build_input(X_train_all, p_seq_norm, p_chanels)
     y_train_all = final_df_train.loc[:, 0].astype('int')
 
     X_test = final_df_test.loc[:, 1:].apply(lambda x: x.astype(str).str.split('::'))
-    X_test = build_input(X_test, p_seq_norm)
+    X_test = build_input(X_test, p_seq_norm, p_chanels)
     y_test = final_df_test.loc[:, 0].astype('int')
 
     input_shape = (X_train_all.shape[1], X_train_all.shape[2], X_train_all.shape[3], X_train_all.shape[4])