ソースを参照

Merge branch 'release/v0.0.3'

Jérôme BUISINE 4 年 前
コミット
7a90d64b01
2 ファイル変更8 行追加22 行削除
  1. 7 5
      classes/Transformation.py
  2. 1 17
      utils/data.py

+ 7 - 5
classes/Transformation.py

@@ -1,6 +1,6 @@
 import os
 
-from ipfml.processing import svd_reconstruction, fast_ica_reconstruction, ipca_reconstruction
+from ipfml.processing import reconstruction
 
 # Transformation class to store transformation method of image and get usefull information
 class Transformation():
@@ -13,15 +13,15 @@ class Transformation():
 
         if self.transformation == 'svd_reconstruction':
             begin, end = list(map(int, self.param.split(',')))
-            data = svd_reconstruction(img, [begin, end])
+            data = reconstruction.svd(img, [begin, end])
 
         if self.transformation == 'ipca_reconstruction':
             n_components, batch_size = list(map(int, self.param.split(',')))
-            data = ipca_reconstruction(img, n_components, batch_size)
+            data = reconstruction.ipca(img, n_components, batch_size)
 
         if self.transformation == 'fast_ica_reconstruction':
             n_components = self.param
-            data = fast_ica_reconstruction(img, n_components)
+            data = reconstruction.fast_ica(img, n_components)
 
         if self.transformation == 'static':
             # static content, we keep input as it is
@@ -47,7 +47,9 @@ class Transformation():
 
         if self.transformation == 'static':
             # param contains the whole path of image
-            path = os.path.join(self.param, self.transformation)
+            last_element = self.param.split('/')[-1] 
+            output_path = self.param.replace(last_element, '')
+            path = os.path.join(output_path, self.transformation, last_element)
 
         return path
 

+ 1 - 17
utils/data.py

@@ -1,19 +1,4 @@
-from ipfml import processing, metrics, utils
 from modules.utils.config import *
-from transformation_functions import svd_reconstruction
-
-from PIL import Image
-from skimage import color
-from sklearn.decomposition import FastICA
-from sklearn.decomposition import IncrementalPCA
-from sklearn.decomposition import TruncatedSVD
-from numpy.linalg import svd as lin_svd
-
-from scipy.signal import medfilt2d, wiener, cwt
-import pywt
-
-import numpy as np
-
 
 _scenes_names_prefix   = '_scenes_names'
 _scenes_indices_prefix = '_scenes_indices'
@@ -40,5 +25,4 @@ def get_renderer_scenes_names(renderer_name):
     if renderer_name == 'all':
         return scenes_names
     else:
-        return context_vars[renderer_name + _scenes_names_prefix]
-
+        return context_vars[renderer_name + _scenes_names_prefix]