# Transformation.py
import os

from ipfml import utils
from ipfml.filters import convolution, kernels
from ipfml.processing import reconstruction
  5. # Transformation class to store transformation method of image and get usefull information
  6. class Transformation():
  7. def __init__(self, _transformation, _param):
  8. self.transformation = _transformation
  9. self.param = _param
  10. def getTransformedImage(self, img):
  11. if self.transformation == 'svd_reconstruction':
  12. begin, end = list(map(int, self.param.split(',')))
  13. data = reconstruction.svd(img, [begin, end])
  14. if self.transformation == 'ipca_reconstruction':
  15. n_components, batch_size = list(map(int, self.param.split(',')))
  16. data = reconstruction.ipca(img, n_components, batch_size)
  17. if self.transformation == 'fast_ica_reconstruction':
  18. n_components = self.param
  19. data = reconstruction.fast_ica(img, n_components)
  20. if self.transformation == 'diff_filter':
  21. w_size, h_size = list(map(int, self.param.split(',')))
  22. # bilateral with window of size (`w_size`, `h_size`)
  23. data = convolution.convolution2D(img, kernels.bilateral_diff, (w_size, h_size))
  24. if self.transformation == 'static':
  25. # static content, we keep input as it is
  26. data = img
  27. return data
  28. def getTransformationPath(self):
  29. path = self.transformation
  30. if self.transformation == 'svd_reconstruction':
  31. begin, end = list(map(int, self.param.split(',')))
  32. path = os.path.join(path, str(begin) + '_' + str(end))
  33. if self.transformation == 'ipca_reconstruction':
  34. n_components, batch_size = list(map(int, self.param.split(',')))
  35. path = os.path.join(path, 'N' + str(n_components) + '_' + str(batch_size))
  36. if self.transformation == 'fast_ica_reconstruction':
  37. n_components = self.param
  38. path = os.path.join(path, 'N' + str(n_components))
  39. if self.transformation == 'diff_filter':
  40. w_size, h_size = list(map(int, self.param.split(',')))
  41. path = os.path.join(path, 'W_' + str(w_size)) + '_' + str(h_size)
  42. if self.transformation == 'static':
  43. # param contains image name to find for each scene
  44. path = self.param
  45. return path
  46. def getName(self):
  47. return self.transformation
  48. def getParam(self):
  49. return self.param
  50. def __str__( self ):
  51. return self.transformation + ' transformation with parameter : ' + self.param