Parcourir la source

Merge branch 'release/v0.2.4'

Jérôme BUISINE il y a 4 ans
Parent
commit
c75ab51d31
1 fichiers modifiés avec 19 ajouts et 3 suppressions
  1. 19 3
      classes/Transformation.py

+ 19 - 3
classes/Transformation.py

@@ -1,16 +1,23 @@
+# main imports
 import os
+import numpy as np
 
+# image processing imports
 from ipfml.processing import transform
 from ipfml.processing import reconstruction
 from ipfml.filters import convolution, kernels
 from ipfml import utils
 
+from PIL import Image
+
+
 # Transformation class to store transformation method of image and get usefull information
 class Transformation():
 
-    def __init__(self, _transformation, _param):
+    def __init__(self, _transformation, _param, _size):
         self.transformation = _transformation
         self.param = _param
+        self.size = _size
 
     def getTransformedImage(self, img):
 
@@ -28,9 +35,17 @@ class Transformation():
 
         if self.transformation == 'min_diff_filter':
             w_size, h_size = list(map(int, self.param.split(',')))
+            h, w = list(map(int, self.size.split(',')))
+
             # bilateral with window of size (`w_size`, `h_size`)
             lab_img = transform.get_LAB_L(img)
-            data = convolution.convolution2D(lab_img, kernels.min_bilateral_diff, (w_size, h_size))
+
+            lab_img = Image.fromarray(lab_img)
+            lab_img.thumbnail((h, w))
+
+            diff_img = convolution.convolution2D(lab_img, kernels.min_bilateral_diff, (w_size, h_size))
+
+            data = np.array(diff_img*255, 'uint8')
             
         if self.transformation == 'static':
             # static content, we keep input as it is
@@ -56,7 +71,8 @@ class Transformation():
 
         if self.transformation == 'min_diff_filter':
             w_size, h_size = list(map(int, self.param.split(',')))
-            path = os.path.join(path, 'W_' + str(w_size)) + '_' + str(h_size)
+            w, h = list(map(int, self.size.split(',')))
+            path = os.path.join(path, 'W_' + str(w_size)) + '_' + str(h_size) + '_S_' + str(w) + '_' + str(h)
 
         if self.transformation == 'static':
             # param contains image name to find for each scene