Parcourir la source

Updates to the renormalization scripts

Jérôme BUISINE il y a 5 ans
Parent
commit
76553e971f
2 fichiers modifiés avec 48 ajouts et 28 suppressions
  1. 23 7
      generate_data_model_random.py
  2. 25 21
      predict_noisy_image_svd.py

+ 23 - 7
generate_data_model_random.py

@@ -28,7 +28,7 @@ min_max_filename        = cfg.min_max_filename_extension
 all_scenes_list         = cfg.scenes_names
 all_scenes_indices      = cfg.scenes_indices
 
-choices                 = cfg.normalization_choices
+normalization_choices   = cfg.normalization_choices
 path                    = cfg.dataset_path
 zones                   = cfg.zones_indices
 seuil_expe_filename     = cfg.seuil_expe_filename
@@ -76,7 +76,7 @@ def construct_new_line(path_seuil, interval, line, choice, norm):
 
     return line
 
-def get_min_max_value_interval(_scenes_list, _filename, _interval, _choice, _metric):
+def get_min_max_value_interval(_scenes_list, _filename, _interval, _choice, _metric, _custom):
 
     global min_value_interval, max_value_interval
 
@@ -106,7 +106,13 @@ def get_min_max_value_interval(_scenes_list, _filename, _interval, _choice, _met
             for id_zone, zone_folder in enumerate(zones_folder):
 
                 zone_path = os.path.join(scene_path, zone_folder)
-                data_filename = _metric + "_" + _choice + generic_output_file_svd
+
+                # when a custom normalization is requested, use the raw SVD values that are not yet normalized
+                if _custom:
+                    data_filename = _metric + "_svd"+ generic_output_file_svd
+                else:
+                    data_filename = _metric + "_" + _choice + generic_output_file_svd
+
                 data_file_path = os.path.join(zone_path, data_filename)
 
                 # getting number of line and read randomly lines
@@ -136,7 +142,7 @@ def get_min_max_value_interval(_scenes_list, _filename, _interval, _choice, _met
                     counter += 1
 
 
-def generate_data_model(_scenes_list, _filename, _interval, _choice, _metric, _scenes, _nb_zones = 4, _percent = 1, _random=0, _step=1, _norm = False):
+def generate_data_model(_scenes_list, _filename, _interval, _choice, _metric, _scenes, _nb_zones = 4, _percent = 1, _random=0, _step=1, _custom = False):
 
     output_train_filename = _filename + ".train"
     output_test_filename = _filename + ".test"
@@ -178,7 +184,13 @@ def generate_data_model(_scenes_list, _filename, _interval, _choice, _metric, _s
 
             for id_zone, zone_folder in enumerate(zones_folder):
                 zone_path = os.path.join(scene_path, zone_folder)
-                data_filename = _metric + "_" + _choice + generic_output_file_svd
+
+                # when a custom normalization is requested, use the raw SVD values that are not yet normalized
+                if _custom:
+                    data_filename = _metric + "_svd"+ generic_output_file_svd
+                else:
+                    data_filename = _metric + "_" + _choice + generic_output_file_svd
+
                 data_file_path = os.path.join(zone_path, data_filename)
 
                 # getting number of line and read randomly lines
@@ -201,7 +213,7 @@ def generate_data_model(_scenes_list, _filename, _interval, _choice, _metric, _s
                     image_index = int(data.split(';')[0])
 
                     if image_index % _step == 0:
-                        line = construct_new_line(path_seuil, _interval, data, _choice, _norm)
+                        line = construct_new_line(path_seuil, _interval, data, _choice, _custom)
 
                         if id_zone < _nb_zones and folder_scene in _scenes and percent <= _percent:
                             train_file_data.append(line)
@@ -251,6 +263,10 @@ def main():
             p_interval = list(map(int, a.split(',')))
         elif o in ("-k", "--kind"):
             p_kind = a
+
+            if p_kind not in normalization_choices:
+                assert False, "Invalid normalization choice, %s" % normalization_choices
+
         elif o in ("-m", "--metric"):
             p_metric = a
         elif o in ("-s", "--scenes"):
@@ -288,7 +304,7 @@ def main():
 
     # find min max value if necessary to renormalize data
     if p_custom:
-        get_min_max_value_interval(scenes_list, p_filename, p_interval, p_kind, p_metric)
+        get_min_max_value_interval(scenes_list, p_filename, p_interval, p_kind, p_metric, p_custom)
 
         # write new file to save
         if not os.path.exists(custom_min_max_folder):

+ 25 - 21
predict_noisy_image_svd.py

@@ -68,30 +68,11 @@ def main():
     # get interval values
     begin, end = p_interval
 
-    # check mode to normalize data
-    if p_mode == 'svdne':
-
-        # set min_max_filename if custom use
-        min_max_file_path = path + '/' + p_metric + min_max_ext
-
-        # need to read min_max_file
-        file_path = os.path.join(os.path.dirname(__file__), min_max_file_path)
-        with open(file_path, 'r') as f:
-            min_val = float(f.readline().replace('\n', ''))
-            max_val = float(f.readline().replace('\n', ''))
-
-        l_values = processing.normalize_arr_with_range(data, min_val, max_val)
-
-    elif p_mode == 'svdn':
-        l_values = processing.normalize_arr(data)
-    else:
-        l_values = data
-
-    test_data = l_values[begin:end]
-
     # check if custom min max file is used
     if p_custom:
 
+        test_data = data[begin:end]
+
         if p_mode == 'svdne':
 
             # set min_max_filename if custom use
@@ -108,6 +89,29 @@ def main():
         if p_mode == 'svdn':
             test_data = processing.normalize_arr(test_data)
 
+    else:
+
+        # check mode to normalize data
+        if p_mode == 'svdne':
+
+            # set min_max_filename if custom use
+            min_max_file_path = path + '/' + p_metric + min_max_ext
+
+            # need to read min_max_file
+            file_path = os.path.join(os.path.dirname(__file__), min_max_file_path)
+            with open(file_path, 'r') as f:
+                min_val = float(f.readline().replace('\n', ''))
+                max_val = float(f.readline().replace('\n', ''))
+
+            l_values = processing.normalize_arr_with_range(data, min_val, max_val)
+
+        elif p_mode == 'svdn':
+            l_values = processing.normalize_arr(data)
+        else:
+            l_values = data
+
+        test_data = l_values[begin:end]
+
 
     # get prediction of model
     prediction = model.predict([test_data])[0]