display_scenes_zones_shifted.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # main imports
  2. import sys, os, argparse
  3. import numpy as np
  4. import random
  5. import time
  6. import json
  7. # image processing imports
  8. from PIL import Image
  9. from skimage import color
  10. import matplotlib.pyplot as plt
  11. from ipfml.processing import segmentation, transform, compression
  12. from ipfml import utils
  13. # modules and config imports
  14. sys.path.insert(0, '') # trick to enable import of main folder module
  15. import custom_config as cfg
  16. from modules.utils import data as dt
  17. # variables and parameters
  18. zone_folder = cfg.zone_folder
  19. min_max_filename = cfg.min_max_filename_extension
  20. # define all scenes values
  21. scenes_list = cfg.scenes_names
  22. scenes_indices = cfg.scenes_indices
  23. path = cfg.dataset_path
  24. zones = cfg.zones_indices
  25. seuil_expe_filename = cfg.seuil_expe_filename
  26. max_nb_bits = 8
  27. def display_data_scenes(p_scene, p_bits, p_shifted):
  28. """
  29. @brief Method which generates all .csv files from scenes photos
  30. @param p_scene, scene we want to show values
  31. @param nb_bits, number of bits expected
  32. @param p_shifted, number of bits expected to be shifted
  33. @return nothing
  34. """
  35. scenes = os.listdir(path)
  36. # remove min max file from scenes folder
  37. scenes = [s for s in scenes if min_max_filename not in s]
  38. # go ahead each scenes
  39. for folder_scene in scenes:
  40. if p_scene == folder_scene:
  41. print(folder_scene)
  42. scene_path = os.path.join(path, folder_scene)
  43. # construct each zones folder name
  44. zones_folder = []
  45. # get zones list info
  46. for index in zones:
  47. index_str = str(index)
  48. if len(index_str) < 2:
  49. index_str = "0" + index_str
  50. current_zone = "zone"+index_str
  51. zones_folder.append(current_zone)
  52. zones_images_data = []
  53. threshold_info = []
  54. # get all images of folder
  55. scene_images = sorted([os.path.join(scene_path, img) for img in os.listdir(scene_path) if cfg.scene_image_extension in img])
  56. start_image_path = scene_images[0]
  57. end_image_path = scene_images[-1]
  58. start_quality_image = dt.get_scene_image_quality(scene_images[0])
  59. end_quality_image = dt.get_scene_image_quality(scene_images[-1])
  60. for id_zone, zone_folder in enumerate(zones_folder):
  61. zone_path = os.path.join(scene_path, zone_folder)
  62. # get threshold information
  63. path_seuil = os.path.join(zone_path, seuil_expe_filename)
  64. # open treshold path and get this information
  65. with open(path_seuil, "r") as seuil_file:
  66. threshold_learned = int(seuil_file.readline().strip())
  67. threshold_image_found = False
  68. # for each images
  69. for img_path in scene_images:
  70. current_quality_image = dt.get_scene_image_quality(img_path)
  71. if threshold_learned < int(current_quality_image) and not threshold_image_found:
  72. threshold_image_found = True
  73. threshold_image_path = img_path
  74. threshold_image = dt.get_scene_image_postfix(img_path)
  75. threshold_info.append(threshold_image)
  76. # all indexes of picture to plot
  77. images_path = [start_image_path, threshold_image_path, end_image_path]
  78. images_data = []
  79. for img_path in images_path:
  80. current_img = Image.open(img_path)
  81. img_blocks = segmentation.divide_in_blocks(current_img, (200, 200))
  82. # getting expected block id
  83. block = img_blocks[id_zone]
  84. # get data from mode
  85. # Here you can add the way you compute data
  86. low_bits_block = transform.rgb_to_LAB_L_bits(block, (p_shifted + 1, p_shifted + p_bits + 1))
  87. data = compression.get_SVD_s(low_bits_block)
  88. ##################
  89. # Data mode part #
  90. ##################
  91. # modify data depending mode
  92. data = utils.normalize_arr(data)
  93. images_data.append(data)
  94. zones_images_data.append(images_data)
  95. fig=plt.figure(figsize=(8, 8))
  96. fig.suptitle('Lab SVD ' + str(p_bits) + ' bits shifted by ' + str(p_shifted) + " for " + p_scene + " scene", fontsize=20)
  97. for id, data in enumerate(zones_images_data):
  98. fig.add_subplot(4, 4, (id + 1))
  99. plt.plot(data[0], label='Noisy_' + start_quality_image)
  100. plt.plot(data[1], label='Threshold_' + threshold_info[id])
  101. plt.plot(data[2], label='Reference_' + end_quality_image)
  102. plt.ylabel('Lab SVD ' + str(p_bits) + ' bits shifted by ' + str(p_shifted) + ', ZONE_' + str(id + 1), fontsize=14)
  103. plt.xlabel('Vector features', fontsize=16)
  104. plt.legend(bbox_to_anchor=(0.5, 1), loc=2, borderaxespad=0.2, fontsize=14)
  105. plt.ylim(0, 0.1)
  106. plt.show()
  107. def main():
  108. parser = argparse.ArgumentParser(description="Display curves of shifted bits influence of L canal on specific scene by zone")
  109. parser.add_argument('--scene', type=str, help='scene index to use', choices=scenes_indices)
  110. parser.add_argument('--bits', type=str, help='Number of bits to used')
  111. parser.add_argument('--shifted', type=str, help='Number of bits shifted')
  112. args = parser.parse_args()
  113. p_scene = scenes_list[scenes_indices.index(args.scene)]
  114. p_bits = args.bits
  115. p_shifted = args.shifted
  116. if p_bits + p_shifted > max_nb_bits:
  117. assert False, "Invalid parameters, cannot have bits greater than 8 after shift move"
  118. display_data_scenes(p_scene, p_bits, p_shifted)
  119. if __name__== "__main__":
  120. main()