Parcourir la source

Sequence file updated

Jérôme BUISINE il y a 3 ans
Parent
commit
9fa64f550f
3 fichiers modifiés avec 34 ajouts et 11 suppressions
  1. 0 3
      .gitmodules
  2. 10 8
      generate/generate_dataset_sequence_file.py
  3. 24 0
      generate/transformations.py

+ 0 - 3
.gitmodules

@@ -1,3 +0,0 @@
-[submodule "modules"]
-	path = modules
-	url = https://github.com/prise-3d/Thesis-CommonModules.git

+ 10 - 8
generate/generate_dataset_sequence_file.py

@@ -23,12 +23,12 @@ from transformations import Transformation
 
 def generate_data_model(_filename, _transformations, _dataset_folder, _selected_zones, _sequence):
 
-    output_train_filename = os.path.join(output_data_folder, _filename, _filename + ".train")
-    output_test_filename = os.path.join(output_data_folder, _filename, _filename + ".test")
+    output_train_filename = os.path.join(cfg.output_data_folder, _filename, _filename + ".train")
+    output_test_filename = os.path.join(cfg.output_data_folder, _filename, _filename + ".test")
 
     # create path if not exists
-    if not os.path.exists(os.path.join(output_data_folder, _filename)):
-        os.makedirs(os.path.join(output_data_folder, _filename))
+    if not os.path.exists(os.path.join(cfg.output_data_folder, _filename)):
+        os.makedirs(os.path.join(cfg.output_data_folder, _filename))
 
     train_file = open(output_train_filename, 'w')
     test_file = open(output_test_filename, 'w')
@@ -81,12 +81,14 @@ def generate_data_model(_filename, _transformations, _dataset_folder, _selected_
 
                     # get image path to manage
                     # {sceneName}/static/img.png
-                    transform_image_path = os.path.join(scene_path, transformation.getName(), image_name) 
-                    static_transform_image = Image.open(transform_image_path)
+                    
+                    # TODO : check this part
+                    #transform_image_path = os.path.join(scene_path, transformation.getName(), image_name) 
+                    #static_transform_image = Image.open(transform_image_path)
 
-                    static_transform_image_block = divide_in_blocks(static_transform_image, cfg.sub_image_size)[id_zone]
+                    #static_transform_image_block = divide_in_blocks(static_transform_image, cfg.sub_image_size)[id_zone]
 
-                    dt.augmented_data_image(static_transform_image_block, image_folder_path, image_prefix_name)
+                    #transformation.augmented_data_image(static_transform_image_block, image_folder_path, image_prefix_name)
 
                 else:
                     metric_interval_path = os.path.join(zone_path, transformation.getTransformationPath())

+ 24 - 0
generate/transformations.py

@@ -1,6 +1,7 @@
 # main imports
 import os
 import numpy as np
+import random
 
 # image processing imports
 from ipfml.processing import transform, compression
@@ -66,6 +67,29 @@ def fill_image_with_rand_value(img, func, value_to_replace):
                 
     return output
 
+def augmented_data_image(block, output_folder, prefix_image_name):
+
+    rotations = [0, 90, 180, 270]
+    img_flip_labels = ['original', 'horizontal', 'vertical', 'both']
+
+    horizontal_img = block.transpose(Image.FLIP_LEFT_RIGHT)
+    vertical_img = block.transpose(Image.FLIP_TOP_BOTTOM)
+    both_img = block.transpose(Image.TRANSPOSE)
+
+    flip_images = [block, horizontal_img, vertical_img, both_img]
+
+    # rotate and flip image to increase dataset size
+    for id, flip in enumerate(flip_images):
+        for rotation in rotations:
+            rotated_output_img = flip.rotate(rotation)
+
+            output_reconstructed_filename = prefix_image_name + post_image_name_separator
+            output_reconstructed_filename = output_reconstructed_filename + img_flip_labels[id] + '_' + str(rotation) + '.png'
+            output_reconstructed_path = os.path.join(output_folder, output_reconstructed_filename)
+
+            if not os.path.exists(output_reconstructed_path):
+                rotated_output_img.save(output_reconstructed_path)
+                
 def _compute_relative_error(ref_sv, k_sv):
     ref = np.sqrt(np.sum(np.square(ref_sv)))
     k = np.sqrt(np.sum(np.square(k_sv)))