display_lab_bits_shifted_scene.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # main imports
  2. import sys, os, argparse
  3. import numpy as np
  4. import random
  5. import time
  6. import json
  7. # image processing imports
  8. from PIL import Image
  9. from skimage import color
  10. import matplotlib.pyplot as plt
  11. from ipfml.processing import compression, transform
  12. # modules and config imports
  13. sys.path.insert(0, '') # trick to enable import of main folder module
  14. import custom_config as cfg
  15. from modules.utils import data as dt
  16. # variables and parameters
  17. zone_folder = cfg.zone_folder
  18. min_max_filename = cfg.min_max_filename_extension
  19. # define all scenes values
  20. scenes_list = cfg.scenes_names
  21. scenes_indices = cfg.scenes_indices
  22. path = cfg.dataset_path
  23. zones = cfg.zones_indices
  24. seuil_expe_filename = cfg.seuil_expe_filename
  25. max_nb_bits = 8
  26. def display_data_scenes(nb_bits, p_scene):
  27. """
  28. @brief Method display shifted values for specific scene
  29. @param nb_bits, number of bits expected
  30. @param p_scene, scene we want to show values
  31. @return nothing
  32. """
  33. scenes = os.listdir(path)
  34. # remove min max file from scenes folder
  35. scenes = [s for s in scenes if min_max_filename not in s]
  36. # go ahead each scenes
  37. for folder_scene in scenes:
  38. if p_scene == folder_scene:
  39. print(folder_scene)
  40. scene_path = os.path.join(path, folder_scene)
  41. # construct each zones folder name
  42. zones_folder = []
  43. # get zones list info
  44. for index in zones:
  45. index_str = str(index)
  46. if len(index_str) < 2:
  47. index_str = "0" + index_str
  48. current_zone = "zone"+index_str
  49. zones_folder.append(current_zone)
  50. threshold_info = []
  51. for zone_folder in zones_folder:
  52. zone_path = os.path.join(scene_path, zone_folder)
  53. # get threshold information
  54. path_seuil = os.path.join(zone_path, seuil_expe_filename)
  55. # open treshold path and get this information
  56. with open(path_seuil, "r") as seuil_file:
  57. seuil_learned = int(seuil_file.readline().strip())
  58. threshold_info.append(seuil_learned)
  59. # compute mean threshold values
  60. mean_threshold = sum(threshold_info) / float(len(threshold_info))
  61. print(mean_threshold, "mean threshold found")
  62. threshold_image_found = False
  63. # get all images of folder
  64. scene_images = sorted([os.path.join(scene_path, img) for img in os.listdir(scene_path) if cfg.scene_image_extension in img])
  65. start_image_path = scene_images[0]
  66. end_image_path = scene_images[-1]
  67. start_quality_image = dt.get_scene_image_quality(scene_images[0])
  68. end_quality_image = dt.get_scene_image_quality(scene_images[-1])
  69. # for each images
  70. for img_path in scene_images:
  71. current_quality_image = dt.get_scene_image_quality(img_path)
  72. if mean_threshold < int(current_quality_image) and not threshold_image_found:
  73. threshold_image_found = True
  74. threshold_image_path = img_path
  75. threshold_image = dt.get_scene_image_quality(img_path)
  76. # all indexes of picture to plot
  77. images_path = [start_image_path, threshold_image_path, end_image_path]
  78. low_bits_svd_values = []
  79. for i in range(0, max_nb_bits - nb_bits + 1):
  80. low_bits_svd_values.append([])
  81. for img_path in images_path:
  82. current_img = Image.open(img_path)
  83. block_used = np.array(current_img)
  84. low_bits_block = transform.rgb_to_LAB_L_bits(block_used, (i + 1, i + nb_bits + 1))
  85. low_bits_svd = compression.get_SVD_s(low_bits_block)
  86. low_bits_svd = [b / low_bits_svd[0] for b in low_bits_svd]
  87. low_bits_svd_values[i].append(low_bits_svd)
  88. fig=plt.figure(figsize=(8, 8))
  89. fig.suptitle("Lab SVD " + str(nb_bits) + " bits values shifted for " + p_scene + " scene", fontsize=20)
  90. for id, data in enumerate(low_bits_svd_values):
  91. fig.add_subplot(3, 3, (id + 1))
  92. plt.plot(data[0], label='Noisy_' + start_quality_image)
  93. plt.plot(data[1], label='Threshold_' + threshold_image)
  94. plt.plot(data[2], label='Reference_' + end_quality_image)
  95. plt.ylabel('Lab SVD ' + str(nb_bits) + ' bits values shifted by ' + str(id), fontsize=14)
  96. plt.xlabel('Vector features', fontsize=16)
  97. plt.legend(bbox_to_anchor=(0.5, 1), loc=2, borderaxespad=0.2, fontsize=14)
  98. plt.ylim(0, 0.1)
  99. plt.show()
  100. def main():
  101. parser = argparse.ArgumentParser(description="Display curves of shifted bits influence of L canal on specific scene")
  102. parser.add_argument('--bits', type=str, help='Number of bits to display')
  103. parser.add_argument('--scene', type=str, help="scene index to use", choices=scenes_indices)
  104. args = parser.parse_args()
  105. p_bits = args.bits
  106. p_scene = scenes_list[scenes_indices.index(args.scene)]
  107. display_data_scenes(p_bits, p_scene)
  108. if __name__== "__main__":
  109. main()