display_bits_shifted_scene.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Sep 14 21:02:42 2018
  5. @author: jbuisine
  6. """
  7. from __future__ import print_function
  8. import sys, os, argparse
  9. import numpy as np
  10. import random
  11. import time
  12. import json
  13. from PIL import Image
  14. from ipfml import processing
  15. from ipfml import metrics
  16. from skimage import color
  17. import matplotlib.pyplot as plt
  18. from modules.utils import config as cfg
  19. config_filename = cfg.config_filename
  20. zone_folder = cfg.zone_folder
  21. min_max_filename = cfg.min_max_filename_extension
  22. # define all scenes values
  23. scenes_list = cfg.scenes_names
  24. scenes_indices = cfg.scenes_indices
  25. choices = cfg.normalization_choices
  26. path = cfg.dataset_path
  27. zones = cfg.zones_indices
  28. seuil_expe_filename = cfg.seuil_expe_filename
  29. metric_choices = cfg.metric_choices_labels
  30. max_nb_bits = 8
  31. def display_data_scenes(nb_bits, p_scene):
  32. """
  33. @brief Method display shifted values for specific scene
  34. @param nb_bits, number of bits expected
  35. @param p_scene, scene we want to show values
  36. @return nothing
  37. """
  38. scenes = os.listdir(path)
  39. # remove min max file from scenes folder
  40. scenes = [s for s in scenes if min_max_filename not in s]
  41. # go ahead each scenes
  42. for id_scene, folder_scene in enumerate(scenes):
  43. if p_scene == folder_scene:
  44. print(folder_scene)
  45. scene_path = os.path.join(path, folder_scene)
  46. config_file_path = os.path.join(scene_path, config_filename)
  47. with open(config_file_path, "r") as config_file:
  48. last_image_name = config_file.readline().strip()
  49. prefix_image_name = config_file.readline().strip()
  50. start_index_image = config_file.readline().strip()
  51. end_index_image = config_file.readline().strip()
  52. step_counter = int(config_file.readline().strip())
  53. # construct each zones folder name
  54. zones_folder = []
  55. # get zones list info
  56. for index in zones:
  57. index_str = str(index)
  58. if len(index_str) < 2:
  59. index_str = "0" + index_str
  60. current_zone = "zone"+index_str
  61. zones_folder.append(current_zone)
  62. zones_images_data = []
  63. threshold_info = []
  64. for id_zone, zone_folder in enumerate(zones_folder):
  65. zone_path = os.path.join(scene_path, zone_folder)
  66. current_counter_index = int(start_index_image)
  67. end_counter_index = int(end_index_image)
  68. # get threshold information
  69. path_seuil = os.path.join(zone_path, seuil_expe_filename)
  70. # open treshold path and get this information
  71. with open(path_seuil, "r") as seuil_file:
  72. seuil_learned = int(seuil_file.readline().strip())
  73. threshold_info.append(seuil_learned)
  74. # compute mean threshold values
  75. mean_threshold = sum(threshold_info) / float(len(threshold_info))
  76. print(mean_threshold, "mean threshold found")
  77. threshold_image_found = False
  78. # find appropriate mean threshold picture
  79. while(current_counter_index <= end_counter_index and not threshold_image_found):
  80. if mean_threshold < int(current_counter_index):
  81. current_counter_index_str = str(current_counter_index)
  82. while len(start_index_image) > len(current_counter_index_str):
  83. current_counter_index_str = "0" + current_counter_index_str
  84. threshold_image_found = True
  85. threshold_image_zone = current_counter_index_str
  86. current_counter_index += step_counter
  87. # all indexes of picture to plot
  88. images_indexes = [start_index_image, threshold_image_zone, end_index_image]
  89. images_data = []
  90. print(images_indexes)
  91. low_bits_svd_values = []
  92. for i in range(0, max_nb_bits - nb_bits + 1):
  93. low_bits_svd_values.append([])
  94. for index in images_indexes:
  95. img_path = os.path.join(scene_path, prefix_image_name + index + ".png")
  96. current_img = Image.open(img_path)
  97. block_used = np.array(current_img)
  98. low_bits_block = processing.rgb_to_LAB_L_bits(block_used, (i + 1, i + nb_bits + 1))
  99. low_bits_svd = metrics.get_SVD_s(low_bits_block)
  100. low_bits_svd = [b / low_bits_svd[0] for b in low_bits_svd]
  101. low_bits_svd_values[i].append(low_bits_svd)
  102. fig=plt.figure(figsize=(8, 8))
  103. fig.suptitle("Lab SVD " + str(nb_bits) + " bits values shifted for " + p_scene + " scene", fontsize=20)
  104. for id, data in enumerate(low_bits_svd_values):
  105. fig.add_subplot(3, 3, (id + 1))
  106. plt.plot(data[0], label='Noisy_' + start_index_image)
  107. plt.plot(data[1], label='Threshold_' + threshold_image_zone)
  108. plt.plot(data[2], label='Reference_' + end_index_image)
  109. plt.ylabel('Lab SVD ' + str(nb_bits) + ' bits values shifted by ' + str(id), fontsize=14)
  110. plt.xlabel('Vector features', fontsize=16)
  111. plt.legend(bbox_to_anchor=(0.5, 1), loc=2, borderaxespad=0.2, fontsize=14)
  112. plt.ylim(0, 0.1)
  113. plt.show()
  114. def main():
  115. parser = argparse.ArgumentParser(description="Display curves of shifted bits influence of L canal on specific scene")
  116. parser.add_argument('--bits', type=str, help='Number of bits to display')
  117. parser.add_argument('--scene', type=str, help="scene index to use", choices=scenes_indices)
  118. args = parser.parse_args()
  119. p_bits = args.bits
  120. p_scene = scenes_list[scenes_indices.index(args.scene)]
  121. display_data_scenes(p_bits, p_scene)
  122. if __name__== "__main__":
  123. main()