Transformation.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # main imports
  2. import os
  3. import numpy as np
  4. # image processing imports
  5. from ipfml.processing import transform
  6. from ipfml.processing import reconstruction
  7. from ipfml.filters import convolution, kernels
  8. from ipfml import utils
  9. from PIL import Image
  10. # Transformation class: stores the transformation method applied to an image and exposes useful information about it
  class Transformation():
      """Describe one image transformation and apply it on demand.

      Attributes:
          transformation: name of the transformation; the names handled here
              are 'svd_reconstruction', 'ipca_reconstruction',
              'fast_ica_reconstruction', 'min_diff_filter' and 'static'.
          param: transformation parameter(s). For every non-static
              transformation this is a comma-separated string of integers
              (a single integer for 'fast_ica_reconstruction'). For 'static'
              it is returned unchanged by `getTransformationPath`.
          size: two comma-separated integers (e.g. '100,100') bounding the
              thumbnail produced by non-static transformations. It is not
              read for 'static'.
      """

      def __init__(self, _transformation, _param, _size):
          self.transformation = _transformation
          self.param = _param
          self.size = _size

      @staticmethod
      def _parse_ints(text):
          """Return the comma-separated integers contained in `text` as a list."""
          return [int(value) for value in text.split(',')]

      @staticmethod
      def _to_thumbnail(values, size):
          """Cast `values` to uint8, shrink it with PIL and return a numpy array.

          `size` is handed to `Image.thumbnail` as-is, so its first element is
          the bound PIL applies to the width and its second the bound applied
          to the height.
          """
          picture = Image.fromarray(np.array(values, 'uint8'))
          # thumbnail() works in place and returns None
          picture.thumbnail(size)
          return np.array(picture)

      def getTransformedImage(self, img):
          """Apply the configured transformation to `img`.

          Args:
              img: input image, forwarded to the ipfml helpers (or returned
                  untouched for the 'static' transformation).

          Returns:
              The transformed image as a numpy array, or `img` itself when
              the transformation is 'static'.

          Raises:
              ValueError: if the transformation name is not recognised, or if
                  `param` / `size` do not hold the expected number of integers.
          """
          kind = self.transformation

          if kind == 'static':
              # static content, we keep input as it is
              return img

          if kind == 'svd_reconstruction':
              begin, end = self._parse_ints(self.param)
              bounds = tuple(self._parse_ints(self.size))
              restored = reconstruction.svd(img, [begin, end])
          elif kind == 'ipca_reconstruction':
              n_components, batch_size = self._parse_ints(self.param)
              bounds = tuple(self._parse_ints(self.size))
              restored = reconstruction.ipca(img, n_components, batch_size)
          elif kind == 'fast_ica_reconstruction':
              # converted to int like the parameters of every other branch
              n_components = int(self.param)
              bounds = tuple(self._parse_ints(self.size))
              restored = reconstruction.fast_ica(img, n_components)
          elif kind == 'min_diff_filter':
              w_size, h_size, stride = self._parse_ints(self.param)
              bounds = tuple(self._parse_ints(self.size))
              # bilateral with window of size (`w_size`, `h_size`)
              lab_img = transform.get_LAB_L(img)
              filtered = convolution.convolution2D(
                  lab_img, kernels.min_bilateral_diff, (w_size, h_size), stride)
              # the filter output is scaled by 255 before the uint8 cast
              restored = filtered * 255
          else:
              # previously surfaced as an UnboundLocalError on `return data`
              raise ValueError(
                  'Unknown transformation: {!r}'.format(self.transformation))

          # `bounds` must hold exactly two values, as the original unpacking required
          first, second = bounds
          return self._to_thumbnail(restored, (first, second))

      def getTransformationPath(self):
          """Return the relative path identifying this transformation.

          For 'static' the path is `param` itself. For the other recognised
          transformations it is `<name>/<parameters>_S_<w>_<h>`, where `w` and
          `h` are the first and second values of `size`. An unrecognised
          transformation name is returned unchanged.
          """
          name = self.transformation

          if name == 'static':
              # param contains image name to find for each scene
              return self.param

          if name == 'svd_reconstruction':
              begin, end = self._parse_ints(self.param)
              leaf = str(begin) + '_' + str(end)
          elif name == 'ipca_reconstruction':
              n_components, batch_size = self._parse_ints(self.param)
              leaf = 'N' + str(n_components) + '_' + str(batch_size)
          elif name == 'fast_ica_reconstruction':
              # kept verbatim (no int conversion) so existing paths do not change
              leaf = 'N' + str(self.param)
          elif name == 'min_diff_filter':
              w_size, h_size, stride = self._parse_ints(self.param)
              leaf = 'W_' + str(w_size) + '_' + str(h_size) + '_Stride_' + str(stride)
          else:
              return name

          w, h = self._parse_ints(self.size)
          return os.path.join(name, leaf) + '_S_' + str(w) + '_' + str(h)

      def getName(self):
          """Return the transformation name."""
          return self.transformation

      def getParam(self):
          """Return the raw transformation parameter."""
          return self.param

      def __str__(self):
          # str() keeps this usable when param was supplied as a non-string value
          return self.transformation + ' transformation with parameter : ' + str(self.param)