Parcourir la source

Reconstruction updates

Jérôme BUISINE il y a 4 ans
Parent
commit
4eed3b9bb1
1 fichiers modifiés avec 12 ajouts et 3 suppressions
  1. 12 3
      classes/Transformation.py

+ 12 - 3
classes/Transformation.py

@@ -23,14 +23,20 @@ class Transformation():
 
         if self.transformation == 'svd_reconstruction':
             begin, end = list(map(int, self.param.split(',')))
+            h, w = list(map(int, self.size.split(',')))
+            img.thumbnail((h, w))
             data = reconstruction.svd(img, [begin, end])
 
         if self.transformation == 'ipca_reconstruction':
             n_components, batch_size = list(map(int, self.param.split(',')))
+            h, w = list(map(int, self.size.split(',')))
+            img.thumbnail((h, w))
             data = reconstruction.ipca(img, n_components, batch_size)
 
         if self.transformation == 'fast_ica_reconstruction':
             n_components = self.param
+            h, w = list(map(int, self.size.split(',')))
+            img.thumbnail((h, w))
             data = reconstruction.fast_ica(img, n_components)
 
         if self.transformation == 'min_diff_filter':
@@ -59,15 +65,18 @@ class Transformation():
 
         if self.transformation == 'svd_reconstruction':
             begin, end = list(map(int, self.param.split(',')))
-            path = os.path.join(path, str(begin) + '_' + str(end))
+            w, h = list(map(int, self.size.split(',')))
+            path = os.path.join(path, str(begin) + '_' + str(end)) + '_S_' + str(w) + '_' + str(h)
 
         if self.transformation == 'ipca_reconstruction':
             n_components, batch_size = list(map(int, self.param.split(',')))
-            path = os.path.join(path, 'N' + str(n_components) + '_' + str(batch_size))
+            w, h = list(map(int, self.size.split(',')))
+            path = os.path.join(path, 'N' + str(n_components) + '_' + str(batch_size)) + '_S_' + str(w) + '_' + str(h)
 
         if self.transformation == 'fast_ica_reconstruction':
             n_components = self.param
-            path = os.path.join(path, 'N' + str(n_components))
+            w, h = list(map(int, self.size.split(',')))
+            path = os.path.join(path, 'N' + str(n_components)) + '_S_' + str(w) + '_' + str(h)
 
         if self.transformation == 'min_diff_filter':
             w_size, h_size = list(map(int, self.param.split(',')))