# noise_svd_tend_visualization.py (7.4 KB)
  1. import sys, os, getopt
  2. from PIL import Image
  3. from ipfml import processing, utils
  4. import ipfml.iqa.fr as fr_iqa
  5. from modules.utils import config as cfg
  6. from modules.utils import data_type as dt
  7. from modules import noise
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. plt.style.use('ggplot')
  11. noise_list = cfg.noise_labels
  12. generated_folder = cfg.generated_folder
  13. filename_ext = cfg.filename_ext
  14. metric_choices = cfg.metric_choices_labels
  15. normalization_choices = cfg.normalization_choices
  16. pictures_folder = cfg.pictures_output_folder
  17. step_picture = 10
  18. error_data_choices = ['mae', 'mse', 'ssim', 'psnr']
  19. def get_error_distance(p_error, y_true, y_test):
  20. noise_method = None
  21. function_name = p_error
  22. try:
  23. error_method = getattr(fr_iqa, function_name)
  24. except AttributeError:
  25. raise NotImplementedError("Error method `{}` not implement `{}`".format(fr_iqa.__name__, function_name))
  26. return error_method(y_true, y_test)
  27. def main():
  28. # default values
  29. p_step = 1
  30. p_color = 0
  31. p_norm = 0
  32. p_ylim = (0, 1)
  33. max_value_svd = 0
  34. min_value_svd = sys.maxsize
  35. if len(sys.argv) <= 1:
  36. print('python noise_svd_tend_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error mae')
  37. sys.exit(2)
  38. try:
  39. opts, args = getopt.getopt(sys.argv[1:], "h:p:m:m:n:i:s:c:n:y:e", ["help=", "prefix=", "metric=", "mode=", "n=", "interval=", "step=", "color=", "norm=", "ylim=", "error="])
  40. except getopt.GetoptError:
  41. # print help information and exit:
  42. print('python noise_svd_tend_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error mae')
  43. sys.exit(2)
  44. for o, a in opts:
  45. if o == "-h":
  46. print('python noise_svd_tend_visualization.py --prefix generated/prefix/noise --metric lab --mode svdn --n 300 --interval "0, 200" --step 30 --color 1 --norm 1 --ylim "0, 1" --error MAE')
  47. sys.exit()
  48. elif o in ("-p", "--prefix"):
  49. p_path = a
  50. elif o in ("-m", "--mode"):
  51. p_mode = a
  52. if not p_mode in normalization_choices:
  53. assert False, "Unknown normalization choice, %s" % normalization_choices
  54. elif o in ("-m", "--metric"):
  55. p_metric = a
  56. if not p_metric in metric_choices:
  57. assert False, "Unknown metric choice, %s" % metric_choices
  58. elif o in ("-n", "--n"):
  59. p_n = int(a)
  60. elif o in ("-n", "--norm"):
  61. p_norm = int(a)
  62. elif o in ("-c", "--color"):
  63. p_color = int(a)
  64. elif o in ("-i", "--interval"):
  65. p_interval = list(map(int, a.split(',')))
  66. elif o in ("-s", "--step"):
  67. p_step = int(a)
  68. elif o in ("-y", "--ylim"):
  69. p_ylim = list(map(float, a.split(',')))
  70. elif o in ("-e", "--error"):
  71. p_error = a
  72. if p_error not in error_data_choices:
  73. assert False, "Unknow error choice to display %s" % error_data_choices
  74. else:
  75. assert False, "unhandled option"
  76. p_prefix = p_path.split('/')[1].replace('_', '')
  77. noise_name = p_path.split('/')[2]
  78. if p_color:
  79. file_path = os.path.join(p_path, p_prefix + "_" + noise_name + "_color_{}." + filename_ext)
  80. else:
  81. file_path = os.path.join(p_path, p_prefix + "_" + noise_name + "_{}." + filename_ext)
  82. begin, end = p_interval
  83. all_svd_data = []
  84. svd_data = []
  85. image_indices = []
  86. noise_indices = range(1, p_n)[::-1]
  87. # get all data from images
  88. for i in noise_indices:
  89. if i % step_picture == 0:
  90. image_path = file_path.format(str(i))
  91. img = Image.open(image_path)
  92. svd_values = dt.get_svd_data(p_metric, img)
  93. if p_norm:
  94. svd_values = svd_values[begin:end]
  95. all_svd_data.append(svd_values)
  96. # update min max values
  97. min_value = svd_values.min()
  98. max_value = svd_values.max()
  99. if min_value < min_value_svd:
  100. min_value_svd = min_value
  101. if max_value > max_value_svd:
  102. max_value_svd = max_value
  103. print('%.2f%%' % ((p_n - i + 1) / p_n * 100))
  104. sys.stdout.write("\033[F")
  105. previous_data = []
  106. error_data = [0.]
  107. for id, data in enumerate(all_svd_data):
  108. current_id = (p_n - ((id + 1) * 10))
  109. if current_id % p_step == 0:
  110. current_data = data
  111. if p_mode == 'svdn':
  112. current_data = utils.normalize_arr(current_data)
  113. if p_mode == 'svdne':
  114. current_data = utils.normalize_arr_with_range(current_data, min_value_svd, max_value_svd)
  115. svd_data.append(current_data)
  116. image_indices.append(current_id)
  117. # use of whole image data for computation of ssim or psnr
  118. if p_error == 'ssim' or p_error == 'psnr':
  119. image_path = file_path.format(str(current_id))
  120. current_data = np.asarray(Image.open(image_path))
  121. if len(previous_data) > 0:
  122. current_error = get_error_distance(p_error, previous_data, current_data)
  123. error_data.append(current_error)
  124. if len(previous_data) == 0:
  125. previous_data = current_data
  126. # display all data using matplotlib (configure plt)
  127. gridsize = (3, 2)
  128. # fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(30, 22))
  129. fig = plt.figure(figsize=(30, 22))
  130. ax1 = plt.subplot2grid(gridsize, (0, 0), colspan=2, rowspan=2)
  131. ax2 = plt.subplot2grid(gridsize, (2, 0), colspan=2)
  132. ax1.set_title(p_prefix + ', ' + noise_name + ' noise, interval information ['+ str(begin) +', '+ str(end) +'], ' + p_metric + ' metric, step ' + str(p_step) + ' normalization ' + p_mode)
  133. ax1.set_label('Importance of noise [1, 999]')
  134. ax1.set_xlabel('Vector features')
  135. for id, data in enumerate(svd_data):
  136. p_label = p_prefix + str(image_indices[id]) + " | " + p_error + ": " + str(error_data[id])
  137. ax1.plot(data, label=p_label)
  138. ax1.legend(bbox_to_anchor=(0.8, 1), loc=2, borderaxespad=0.2, fontsize=12)
  139. if not p_norm:
  140. ax1.set_xlim(begin, end)
  141. # adapt ylim
  142. y_begin, y_end = p_ylim
  143. ax1.set_ylim(y_begin, y_end)
  144. output_filename = p_prefix + "_" + noise_name + "_1_to_" + str(p_n) + "_B" + str(begin) + "_E" + str(end) + "_" + p_metric + "_S" + str(p_step) + "_norm" + str(p_norm )+ "_" + p_mode + "_" + p_error
  145. if p_color:
  146. output_filename = output_filename + '_color'
  147. ax2.set_title(p_error + " information for : " + p_prefix + ', ' + noise_name + ' noise, interval information ['+ str(begin) +', '+ str(end) +'], ' + p_metric + ' metric, step ' + str(p_step) + ', normalization ' + p_mode)
  148. ax2.set_ylabel(p_error + ' error')
  149. ax2.set_xlabel('Number of samples per pixels')
  150. ax2.set_xticks(range(len(image_indices)))
  151. ax2.set_xticklabels(image_indices)
  152. ax2.plot(error_data)
  153. print("Generation of output figure... %s" % output_filename)
  154. output_path = os.path.join(pictures_folder, output_filename)
  155. if not os.path.exists(pictures_folder):
  156. os.makedirs(pictures_folder)
  157. fig.savefig(output_path, dpi=(200))
  158. if __name__== "__main__":
  159. main()