transformations.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # main imports
  2. import os
  3. import numpy as np
  4. import random
  5. # image processing imports
  6. from ipfml.processing import transform, compression
  7. from ipfml.processing import reconstruction
  8. from ipfml.filters import convolution, kernels
  9. from ipfml import utils
  10. import cv2
  11. from skimage.restoration import denoise_nl_means, estimate_sigma
  12. from PIL import Image
  13. def remove_pixel(img, limit):
  14. width, height = img.shape
  15. output = np.zeros((width, height))
  16. for i in range(width):
  17. for j in range(height):
  18. if img[i,j] <= limit:
  19. output[i,j] = img[i,j]
  20. return output
  21. def get_random_value(distribution):
  22. rand = random.uniform(0, 1)
  23. prob_sum = 0.
  24. for id, prob in enumerate(distribution):
  25. prob_sum += prob
  26. if prob_sum >= rand:
  27. return id
  28. return len(distribution) - 1
  29. def distribution_from_data(data):
  30. occurences = np.array([data.count(x) for x in set(data)])
  31. max_occurences = sum(occurences)
  32. return occurences / max_occurences
  33. def fill_image_with_rand_value(img, func, value_to_replace):
  34. width, height = img.shape
  35. output = np.zeros((width, height))
  36. for i in range(width):
  37. for j in range(height):
  38. if img[i,j] == value_to_replace:
  39. output[i, j] = func()
  40. else:
  41. output[i, j] = img[i, j]
  42. return output
  43. def augmented_data_image(block, output_folder, prefix_image_name):
  44. rotations = [0, 90, 180, 270]
  45. img_flip_labels = ['original', 'horizontal', 'vertical', 'both']
  46. horizontal_img = block.transpose(Image.FLIP_LEFT_RIGHT)
  47. vertical_img = block.transpose(Image.FLIP_TOP_BOTTOM)
  48. both_img = block.transpose(Image.TRANSPOSE)
  49. flip_images = [block, horizontal_img, vertical_img, both_img]
  50. # rotate and flip image to increase dataset size
  51. for id, flip in enumerate(flip_images):
  52. for rotation in rotations:
  53. rotated_output_img = flip.rotate(rotation)
  54. output_reconstructed_filename = prefix_image_name + post_image_name_separator
  55. output_reconstructed_filename = output_reconstructed_filename + img_flip_labels[id] + '_' + str(rotation) + '.png'
  56. output_reconstructed_path = os.path.join(output_folder, output_reconstructed_filename)
  57. if not os.path.exists(output_reconstructed_path):
  58. rotated_output_img.save(output_reconstructed_path)
  59. def _compute_relative_error(ref_sv, k_sv):
  60. ref = np.sqrt(np.sum(np.square(ref_sv)))
  61. k = np.sqrt(np.sum(np.square(k_sv)))
  62. return k / ref
  63. def _find_n_components(block, e=0.1):
  64. s = transform.get_LAB_L_SVD_s(block)
  65. errors = []
  66. found = False
  67. k_components = None
  68. for i in range(len(s)):
  69. #Ak = reconstruction.svd(img, [0, i])
  70. #error = compute_relative_error_matrix(A, Ak)
  71. error = _compute_relative_error(s, s[i:])
  72. errors.append(error)
  73. if error < e and not found:
  74. k_components = (i + 1)
  75. found = True
  76. return (k_components, errors)
  77. # Transformation class to store transformation method of image and get usefull information
  78. class Transformation():
  79. def __init__(self, _transformation, _param, _size):
  80. self.transformation = _transformation
  81. self.param = _param
  82. self.size = _size
  83. def getTransformedImage(self, img):
  84. if self.transformation == 'svd_reconstruction':
  85. begin, end = list(map(int, self.param.split(',')))
  86. h, w = list(map(int, self.size.split(',')))
  87. img_reconstructed = reconstruction.svd(img, [begin, end])
  88. data_array = np.array(img_reconstructed, 'uint8')
  89. img_array = Image.fromarray(data_array)
  90. img_array.thumbnail((h, w))
  91. data = np.array(img_array)
  92. if self.transformation == 'svd_reconstruction':
  93. begin, end = list(map(int, self.param.split(',')))
  94. h, w = list(map(int, self.size.split(',')))
  95. img_reconstructed = reconstruction.svd(img, [begin, end])
  96. data_array = np.array(img_reconstructed, 'uint8')
  97. img_array = Image.fromarray(data_array)
  98. img_array.thumbnail((h, w))
  99. data = np.array(img_array)
  100. if self.transformation == 'ipca_reconstruction':
  101. n_components, batch_size = list(map(int, self.param.split(',')))
  102. h, w = list(map(int, self.size.split(',')))
  103. img_reconstructed = reconstruction.ipca(img, n_components, batch_size)
  104. data_array = np.array(img_reconstructed, 'uint8')
  105. img_array = Image.fromarray(data_array)
  106. img_array.thumbnail((h, w))
  107. data = np.array(img_array)
  108. if self.transformation == 'fast_ica_reconstruction':
  109. n_components = self.param
  110. h, w = list(map(int, self.size.split(',')))
  111. img_reconstructed = reconstruction.fast_ica(img, n_components)
  112. data_array = np.array(img_reconstructed, 'uint8')
  113. img_array = Image.fromarray(data_array)
  114. img_array.thumbnail((h, w))
  115. data = np.array(img_array)
  116. if self.transformation == 'gini_map':
  117. # kernel size
  118. k_w, k_h = list(map(int, self.param.split(',')))
  119. h, w = list(map(int, self.size.split(',')))
  120. lab_img = transform.get_LAB_L(img)
  121. img_mask = convolution.convolution2D(lab_img, kernels.gini, (k_w, k_h))
  122. # renormalize data
  123. data_array = np.array(img_mask * 255, 'uint8')
  124. img_array = Image.fromarray(data_array)
  125. img_array.thumbnail((h, w))
  126. data = np.array(img_array)
  127. if self.transformation == 'sobel_based_filter':
  128. k_size, p_limit = list(map(int, self.param.split(',')))
  129. h, w = list(map(int, self.size.split(',')))
  130. lab_img = transform.get_LAB_L(img)
  131. weight, height = lab_img.shape
  132. sobelx = cv2.Sobel(lab_img, cv2.CV_64F, 1, 0, ksize=k_size)
  133. sobely = cv2.Sobel(lab_img, cv2.CV_64F, 0, 1,ksize=k_size)
  134. sobel_mag = np.array(np.hypot(sobelx, sobely), 'uint8') # magnitude
  135. sobel_mag_limit = remove_pixel(sobel_mag, p_limit)
  136. # use distribution value of pixel to fill `0` values
  137. sobel_mag_limit_without_0 = [x for x in sobel_mag_limit.reshape((weight*height)) if x != 0]
  138. distribution = distribution_from_data(sobel_mag_limit_without_0)
  139. min_value = int(min(sobel_mag_limit_without_0))
  140. l = lambda: get_random_value(distribution) + min_value
  141. img_reconstructed = fill_image_with_rand_value(sobel_mag_limit, l, 0)
  142. img_reconstructed_norm = utils.normalize_2D_arr(img_reconstructed)
  143. img_reconstructed_norm = np.array(img_reconstructed_norm*255, 'uint8')
  144. sobel_reconstructed = Image.fromarray(img_reconstructed_norm)
  145. sobel_reconstructed.thumbnail((h, w))
  146. data = np.array(sobel_reconstructed)
  147. if self.transformation == 'nl_mean_noise_mask':
  148. patch_size, patch_distance = list(map(int, self.param.split(',')))
  149. h, w = list(map(int, self.size.split(',')))
  150. img = np.array(img)
  151. sigma_est = np.mean(estimate_sigma(img, multichannel=True))
  152. patch_kw = dict(patch_size=patch_size, # 5x5 patches
  153. patch_distance=patch_distance, # 13x13 search area
  154. multichannel=True)
  155. # slow algorithm
  156. denoise = denoise_nl_means(img, h=0.8 * sigma_est, sigma=sigma_est,
  157. fast_mode=False,
  158. **patch_kw)
  159. denoise = np.array(denoise, 'uint8')
  160. noise_mask = np.abs(denoise - img)
  161. data_array = np.array(noise_mask, 'uint8')
  162. img_array = Image.fromarray(data_array)
  163. img_array.thumbnail((h, w))
  164. data = np.array(img_array)
  165. if self.transformation == 'static':
  166. # static content, we keep input as it is
  167. data = img
  168. return data
  169. def getTransformationPath(self):
  170. path = self.transformation
  171. if self.transformation == 'svd_reconstruction':
  172. begin, end = list(map(int, self.param.split(',')))
  173. w, h = list(map(int, self.size.split(',')))
  174. path = os.path.join(path, str(begin) + '_' + str(end) + '_S_' + str(w) + '_' + str(h))
  175. if self.transformation == 'gini_map':
  176. k_w, k_h = list(map(int, self.param.split(',')))
  177. w, h = list(map(int, self.size.split(',')))
  178. path = os.path.join(path, str(k_w) + '_' + str(k_h) + '_S_' + str(w) + '_' + str(h))
  179. if self.transformation == 'ipca_reconstruction':
  180. n_components, batch_size = list(map(int, self.param.split(',')))
  181. w, h = list(map(int, self.size.split(',')))
  182. path = os.path.join(path, 'N' + str(n_components) + '_' + str(batch_size) + '_S_' + str(w) + '_' + str(h))
  183. if self.transformation == 'fast_ica_reconstruction':
  184. n_components = self.param
  185. w, h = list(map(int, self.size.split(',')))
  186. path = os.path.join(path, 'N' + str(n_components) + '_S_' + str(w) + '_' + str(h))
  187. if self.transformation == 'min_diff_filter':
  188. w_size, h_size, stride = list(map(int, self.param.split(',')))
  189. w, h = list(map(int, self.size.split(',')))
  190. path = os.path.join(path, 'W_' + str(w_size) + '_' + str(h_size) + '_Stride_' + str(stride) + '_S_' + str(w) + '_' + str(h))
  191. if self.transformation == 'sobel_based_filter':
  192. k_size, p_limit = list(map(int, self.param.split(',')))
  193. h, w = list(map(int, self.size.split(',')))
  194. path = os.path.join(path, 'K_' + str(k_size) + '_L' + str(p_limit) + '_S_' + str(w) + '_' + str(h))
  195. if self.transformation == 'nl_mean_noise_mask':
  196. patch_size, patch_distance = list(map(int, self.param.split(',')))
  197. h, w = list(map(int, self.size.split(',')))
  198. path = os.path.join(path, 'S' + str(patch_size) + '_D' + str(patch_distance) + '_S_' + str(w) + '_' + str(h))
  199. if self.transformation == 'static':
  200. # param contains image name to find for each scene
  201. path = self.param
  202. return path
  203. def getName(self):
  204. return self.transformation
  205. def getParam(self):
  206. return self.param
  207. def __str__( self ):
  208. return self.transformation + ' transformation with parameter : ' + self.param