# display_reconstructed_image_from_simulation.py 6.0 KB

  1. # main imports
  2. import numpy as np
  3. import pandas as pd
  4. import math
  5. import time
  6. import os, sys, argparse
  7. # image processing imports
  8. import matplotlib.pyplot as plt
  9. from PIL import Image
  10. # modules imports
  11. sys.path.insert(0, '') # trick to enable import of main folder module
  12. import custom_config as cfg
  13. from data_attributes import get_image_features
  14. # other variables
  15. learned_zones_folder = cfg.learned_zones_folder
  16. models_name = cfg.models_names_list
  17. # utils information
  18. zone_width, zone_height = (200, 200)
  19. scene_width, scene_height = (800, 800)
  20. nb_x_parts = math.floor(scene_width / zone_width)
  21. def reconstruct_image(folder_path, model_name, p_limit):
  22. """
  23. @brief Method used to display simulation given .csv files
  24. @param folder_path, folder which contains all .csv files obtained during simulation
  25. @param model_name, current name of model
  26. @return nothing
  27. """
  28. for name in models_name:
  29. if name in model_name:
  30. data_filename = model_name
  31. learned_zones_folder_path = os.path.join(learned_zones_folder, data_filename)
  32. data_files = [x for x in os.listdir(folder_path) if '.png' not in x]
  33. scene_names = [f.split('_')[3] for f in data_files]
  34. # compute zone start index
  35. zones_coordinates = []
  36. for index, zone_index in enumerate(cfg.zones_indices):
  37. x_zone = (zone_index % nb_x_parts) * zone_width
  38. y_zone = (math.floor(zone_index / nb_x_parts)) * zone_height
  39. zones_coordinates.append((x_zone, y_zone))
  40. print(zones_coordinates)
  41. for id, f in enumerate(data_files):
  42. scene_name = scene_names[id]
  43. path_file = os.path.join(folder_path, f)
  44. # TODO : check if necessary to keep information about zone learned when displaying data
  45. scenes_zones_used_file_path = os.path.join(learned_zones_folder_path, scene_name + '.csv')
  46. zones_used = []
  47. if os.path.exists(scenes_zones_used_file_path):
  48. with open(scenes_zones_used_file_path, 'r') as f:
  49. zones_used = [int(x) for x in f.readline().split(';') if x != '']
  50. # 1. find estimated threshold for each zone scene using `data_files` and p_limit
  51. model_thresholds = []
  52. df = pd.read_csv(path_file, header=None, sep=";")
  53. for index, row in df.iterrows():
  54. row = np.asarray(row)
  55. #threshold = row[2]
  56. start_index = row[3]
  57. step_value = row[4]
  58. rendering_predictions = row[5:]
  59. nb_generated_image = 0
  60. nb_not_noisy_prediction = 0
  61. for prediction in rendering_predictions:
  62. if int(prediction) == 0:
  63. nb_not_noisy_prediction += 1
  64. else:
  65. nb_not_noisy_prediction = 0
  66. # exit loop if limit is targeted
  67. if nb_not_noisy_prediction >= p_limit:
  68. break
  69. nb_generated_image += 1
  70. current_threshold = start_index + step_value * nb_generated_image
  71. model_thresholds.append(current_threshold)
  72. # 2. find images for each zone which are attached to this estimated threshold by the model
  73. zone_images_index = []
  74. for est_threshold in model_thresholds:
  75. str_index = str(est_threshold)
  76. while len(str_index) < 5:
  77. str_index = "0" + str_index
  78. zone_images_index.append(str_index)
  79. scene_folder = os.path.join(cfg.dataset_path, scene_name)
  80. scenes_images = [img for img in os.listdir(scene_folder) if cfg.scene_image_extension in img]
  81. scenes_images = sorted(scenes_images)
  82. images_zones = []
  83. line_images_zones = []
  84. # get image using threshold by zone
  85. for id, zone_index in enumerate(zone_images_index):
  86. filtered_images = [img for img in scenes_images if zone_index in img]
  87. if len(filtered_images) > 0:
  88. image_name = filtered_images[0]
  89. else:
  90. image_name = scenes_images[-1]
  91. #print(image_name)
  92. image_path = os.path.join(scene_folder, image_name)
  93. selected_image = Image.open(image_path)
  94. x_zone, y_zone = zones_coordinates[id]
  95. zone_image = np.array(selected_image)[y_zone:y_zone+zone_height, x_zone:x_zone+zone_width]
  96. line_images_zones.append(zone_image)
  97. if int(id + 1) % int(scene_width / zone_width) == 0:
  98. images_zones.append(np.concatenate(line_images_zones, axis=1))
  99. print(len(line_images_zones))
  100. line_images_zones = []
  101. # 3. reconstructed the image using these zones
  102. reconstructed_image = np.concatenate(images_zones, axis=0)
  103. # 4. Save the image with generated name based on scene, model and `p_limit`
  104. reconstructed_pil_img = Image.fromarray(reconstructed_image)
  105. output_path = os.path.join(folder_path, scene_names[id] + '_reconstruction_limit_' + str(p_limit) + '.png')
  106. reconstructed_pil_img.save(output_path)
  107. def main():
  108. parser = argparse.ArgumentParser(description="Display simulations curves from simulation data")
  109. parser.add_argument('--folder', type=str, help='Folder which contains simulations data for scenes')
  110. parser.add_argument('--model', type=str, help='Name of the model used for simulations')
  111. parser.add_argument('--limit', type=int, help='Detection limit to target to stop rendering (number of times model tells image has not more noise)')
  112. args = parser.parse_args()
  113. p_folder = args.folder
  114. p_limit = args.limit
  115. if args.model:
  116. p_model = args.model
  117. else:
  118. # find p_model from folder if model arg not given (folder path need to have model name)
  119. if p_folder.split('/')[-1]:
  120. p_model = p_folder.split('/')[-1]
  121. else:
  122. p_model = p_folder.split('/')[-2]
  123. print(p_model)
  124. reconstruct_image(p_folder, p_model, p_limit)
  125. print(p_folder)
  126. if __name__== "__main__":
  127. main()