Browse Source

Add description information

Jérôme BUISINE 6 months ago
parent
commit
13e29e2a61
2 changed files with 5 additions and 4 deletions
  1. 2 2
      README.md
  2. 3 2
      train_lstm_weighted.py

+ 2 - 2
README.md

@@ -73,13 +73,13 @@ List of expected parameter by reconstruction method:
 
 **__Example:__**
 ```bash
-python generate/generate_dataset.py --output data/output_data_filename --features "svd_reconstruction, ipca_reconstruction, fast_ica_reconstruction" --renderer "maxwell" --scenes "A, D, G, H" --params "100, 200 :: 50, 10 :: 50" --nb_zones 10 --random 1
+python generate/generate_dataset_sequence_file.py --output data/output_data_filename --folder <generated_data_folder> --features "svd_reconstruction, ipca_reconstruction, fast_ica_reconstruction" --params "100, 200 :: 50, 10 :: 50" --sequence 5 --size "100, 100" --selected_zones <zones_files.csv>
 ```
 
 
 Then, train model using your custom dataset:
 ```bash
-python train_model.py --data data/custom_dataset --output output_model_name
+python train_lstm_model.py --train data/custom_dataset.train --test data/custom_dataset.test --chanels "1,3,3" --epochs 30 --batch_size 64 --seq_norm 1 --output output_model_name
 ```
 
 ### Predict image using model

+ 3 - 2
train_lstm_weighted.py

@@ -85,6 +85,7 @@ def build_input(df, seq_norm, p_chanels):
                 else:
                     img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                 
+                # normalization of images
                 seq_elems.append(np.array(img, 'float32') / 255.)
 
             #seq_arr.append(np.array(seq_elems).flatten())
@@ -160,10 +161,10 @@ def create_model(_input_shape):
     model.add(Dropout(0.5))
 
     model.add(Flatten())
-    model.add(Dense(512, activation='relu'))
+    model.add(Dense(128, activation='relu'))
     model.add(BatchNormalization())
     model.add(Dropout(0.5))
-    model.add(Dense(128, activation='relu'))
+    model.add(Dense(32, activation='relu'))
     model.add(BatchNormalization())
     model.add(Dropout(0.5))
     model.add(Dense(1, activation='sigmoid'))