display_scenes_zones_shifted.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Sep 14 21:02:42 2018
  5. @author: jbuisine
  6. """
  7. from __future__ import print_function
  8. import sys, os, argparse
  9. import numpy as np
  10. import random
  11. import time
  12. import json
  13. from PIL import Image
  14. from ipfml import processing, metrics, utils
  15. from skimage import color
  16. import matplotlib.pyplot as plt
  17. from modules.utils import config as cfg
  18. config_filename = cfg.config_filename
  19. zone_folder = cfg.zone_folder
  20. min_max_filename = cfg.min_max_filename_extension
  21. # define all scenes values
  22. scenes_list = cfg.scenes_names
  23. scenes_indices = cfg.scenes_indices
  24. choices = cfg.normalization_choices
  25. path = cfg.dataset_path
  26. zones = cfg.zones_indices
  27. seuil_expe_filename = cfg.seuil_expe_filename
  28. metric_choices = cfg.metric_choices_labels
  29. max_nb_bits = 8
  30. def display_data_scenes(p_scene, p_bits, p_shifted):
  31. """
  32. @brief Method which generates all .csv files from scenes photos
  33. @param p_scene, scene we want to show values
  34. @param nb_bits, number of bits expected
  35. @param p_shifted, number of bits expected to be shifted
  36. @return nothing
  37. """
  38. scenes = os.listdir(path)
  39. # remove min max file from scenes folder
  40. scenes = [s for s in scenes if min_max_filename not in s]
  41. # go ahead each scenes
  42. for id_scene, folder_scene in enumerate(scenes):
  43. if p_scene == folder_scene:
  44. print(folder_scene)
  45. scene_path = os.path.join(path, folder_scene)
  46. config_file_path = os.path.join(scene_path, config_filename)
  47. with open(config_file_path, "r") as config_file:
  48. last_image_name = config_file.readline().strip()
  49. prefix_image_name = config_file.readline().strip()
  50. start_index_image = config_file.readline().strip()
  51. end_index_image = config_file.readline().strip()
  52. step_counter = int(config_file.readline().strip())
  53. # construct each zones folder name
  54. zones_folder = []
  55. # get zones list info
  56. for index in zones:
  57. index_str = str(index)
  58. if len(index_str) < 2:
  59. index_str = "0" + index_str
  60. current_zone = "zone"+index_str
  61. zones_folder.append(current_zone)
  62. zones_images_data = []
  63. threshold_info = []
  64. for id_zone, zone_folder in enumerate(zones_folder):
  65. zone_path = os.path.join(scene_path, zone_folder)
  66. current_counter_index = int(start_index_image)
  67. end_counter_index = int(end_index_image)
  68. # get threshold information
  69. path_seuil = os.path.join(zone_path, seuil_expe_filename)
  70. # open treshold path and get this information
  71. with open(path_seuil, "r") as seuil_file:
  72. seuil_learned = int(seuil_file.readline().strip())
  73. threshold_image_found = False
  74. while(current_counter_index <= end_counter_index and not threshold_image_found):
  75. if seuil_learned < int(current_counter_index):
  76. current_counter_index_str = str(current_counter_index)
  77. while len(start_index_image) > len(current_counter_index_str):
  78. current_counter_index_str = "0" + current_counter_index_str
  79. threshold_image_found = True
  80. threshold_image_zone = current_counter_index_str
  81. threshold_info.append(threshold_image_zone)
  82. current_counter_index += step_counter
  83. # all indexes of picture to plot
  84. images_indexes = [start_index_image, threshold_image_zone, end_index_image]
  85. images_data = []
  86. print(images_indexes)
  87. for index in images_indexes:
  88. img_path = os.path.join(scene_path, prefix_image_name + index + ".png")
  89. current_img = Image.open(img_path)
  90. img_blocks = processing.divide_in_blocks(current_img, (200, 200))
  91. # getting expected block id
  92. block = img_blocks[id_zone]
  93. # get data from mode
  94. # Here you can add the way you compute data
  95. low_bits_block = processing.rgb_to_LAB_L_bits(block, (p_shifted + 1, p_shifted + p_bits + 1))
  96. data = metrics.get_SVD_s(low_bits_block)
  97. ##################
  98. # Data mode part #
  99. ##################
  100. # modify data depending mode
  101. data = utils.normalize_arr(data)
  102. images_data.append(data)
  103. zones_images_data.append(images_data)
  104. fig=plt.figure(figsize=(8, 8))
  105. fig.suptitle('Lab SVD ' + str(p_bits) + ' bits shifted by ' + str(p_shifted) + " for " + p_scene + " scene", fontsize=20)
  106. for id, data in enumerate(zones_images_data):
  107. fig.add_subplot(4, 4, (id + 1))
  108. plt.plot(data[0], label='Noisy_' + start_index_image)
  109. plt.plot(data[1], label='Threshold_' + threshold_info[id])
  110. plt.plot(data[2], label='Reference_' + end_index_image)
  111. plt.ylabel('Lab SVD ' + str(p_bits) + ' bits shifted by ' + str(p_shifted) + ', ZONE_' + str(id + 1), fontsize=14)
  112. plt.xlabel('Vector features', fontsize=16)
  113. plt.legend(bbox_to_anchor=(0.5, 1), loc=2, borderaxespad=0.2, fontsize=14)
  114. plt.ylim(0, 0.1)
  115. plt.show()
  116. def main():
  117. parser = argparse.ArgumentParser(description="Display curves of shifted bits influence of L canal on specific scene by zone")
  118. parser.add_argument('--scene', type=str, help='scene index to use', choices=scenes_indices)
  119. parser.add_argument('--bits', type=str, help='Number of bits to used')
  120. parser.add_argument('--shifted', type=str, help='Number of bits shifted')
  121. args = parser.parse_args()
  122. p_scene = scenes_list[scenes_indices.index(args.scene)]
  123. p_bits = args.bits
  124. p_shifted = args.shifted
  125. if p_bits + p_shifted > max_nb_bits:
  126. assert False, "Invalid parameters, cannot have bits greater than 8 after shift move"
  127. display_data_scenes(p_scene, p_bits, p_shifted)
  128. if __name__== "__main__":
  129. main()