# estimate_thresholds_lstm.py
  1. # main imports
  2. import numpy as np
  3. import pandas as pd
  4. import sys, os, argparse
  5. # image processing
  6. from PIL import Image
  7. from ipfml import utils
  8. from ipfml.processing import transform, segmentation
  9. import matplotlib.pyplot as plt
  10. # model imports
  11. import joblib
  12. from keras.models import load_model
  13. from keras import backend as K
  14. # modules and config imports
  15. sys.path.insert(0, '') # trick to enable import of main folder module
  16. import custom_config as cfg
  17. import modules.utils.data as dt
  18. from modules.classes.Transformation import Transformation
  19. def write_progress(progress):
  20. barWidth = 180
  21. output_str = "["
  22. pos = barWidth * progress
  23. for i in range(barWidth):
  24. if i < pos:
  25. output_str = output_str + "="
  26. elif i == pos:
  27. output_str = output_str + ">"
  28. else:
  29. output_str = output_str + " "
  30. output_str = output_str + "] " + str(int(progress * 100.0)) + " %\r"
  31. print(output_str)
  32. sys.stdout.write("\033[F")
  33. def main():
  34. parser = argparse.ArgumentParser(description="Read and compute entropy data file")
  35. parser.add_argument('--model', type=str, help='model .h5 file')
  36. parser.add_argument('--folder', type=str,
  37. help='folder where scene dataset is available',
  38. required=True)
  39. parser.add_argument('--features', type=str,
  40. help="list of features choice in order to compute data",
  41. default='svd_reconstruction, ipca_reconstruction',
  42. required=True)
  43. parser.add_argument('--params', type=str,
  44. help="list of specific param for each feature choice (See README.md for further information in 3D mode)",
  45. default='100, 200 :: 50, 25',
  46. required=True)
  47. parser.add_argument('--size', type=str,
  48. help="specific size of image",
  49. default='100, 100',
  50. required=True)
  51. parser.add_argument('--sequence', type=int, help='sequence size expected', required=True, default=1)
  52. parser.add_argument('--n_stop', type=int, help='number of detection to make sure to stop', default=1)
  53. parser.add_argument('--save', type=str, help='filename where to save input data')
  54. parser.add_argument('--label', type=str, help='label to use when saving thresholds')
  55. args = parser.parse_args()
  56. p_model = args.model
  57. p_folder = args.folder
  58. p_features = list(map(str.strip, args.features.split(',')))
  59. p_params = list(map(str.strip, args.params.split('::')))
  60. p_size = args.size
  61. p_sequence = args.sequence
  62. p_n_stop = args.n_stop
  63. p_save = args.save
  64. p_label = args.label
  65. # 1. Load expected transformations
  66. # list of transformations
  67. transformations = []
  68. for id, feature in enumerate(p_features):
  69. if feature not in cfg.features_choices_labels or feature == 'static':
  70. raise ValueError("Unknown feature, please select a correct feature (`static` excluded) : ", cfg.features_choices_labels)
  71. transformations.append(Transformation(feature, p_params[id], p_size))
  72. # 2. load model and compile it
  73. # TODO : check kind of model
  74. model = joblib.load(p_model)
  75. model.compile(loss='binary_crossentropy',
  76. optimizer='rmsprop',
  77. metrics=['accuracy'])
  78. # model = load_model(p_model)
  79. # model.compile(loss='binary_crossentropy',
  80. # optimizer='rmsprop',
  81. # metrics=['accuracy'])
  82. estimated_thresholds = []
  83. n_estimated_thresholds = []
  84. sequence_list_zones = []
  85. scene_path = p_folder
  86. if not os.path.exists(scene_path):
  87. print('Unvalid scene path:', scene_path)
  88. exit(0)
  89. # 3. retrieve human_thresholds
  90. # construct zones folder
  91. zones_indices = np.arange(16)
  92. zones_list = []
  93. for index in zones_indices:
  94. index_str = str(index)
  95. while len(index_str) < 2:
  96. index_str = "0" + index_str
  97. zones_list.append(cfg.zone_folder + index_str)
  98. # 4. get estimated thresholds using model and specific method
  99. images_path = sorted([os.path.join(scene_path, img) for img in os.listdir(scene_path) if cfg.scene_image_extension in img])
  100. number_of_images = len(images_path)
  101. image_indices = [ dt.get_scene_image_quality(img_path) for img_path in images_path ]
  102. image_counter = 0
  103. # append empty list
  104. for _ in zones_list:
  105. estimated_thresholds.append(None)
  106. n_estimated_thresholds.append(0)
  107. sequence_list_zones.append([])
  108. for img_i, img_path in enumerate(images_path):
  109. blocks = segmentation.divide_in_blocks(Image.open(img_path), (200, 200))
  110. for index, block in enumerate(blocks):
  111. sequence_list = sequence_list_zones[index]
  112. if estimated_thresholds[index] is None:
  113. transformed_list = []
  114. # compute data here
  115. for transformation in transformations:
  116. transformed = transformation.getTransformedImage(block)
  117. transformed_list.append(transformed)
  118. data = np.array(transformed_list)
  119. sequence_list.append(data)
  120. if len(sequence_list) >= p_sequence:
  121. # compute input size
  122. # n_chanels, _, _ = data.shape
  123. input_data = np.array(sequence_list)
  124. input_data = np.expand_dims(input_data, axis=0)
  125. prob = model.predict(np.array(input_data))[0]
  126. #print(index, ':', image_indices[img_i], '=>', prediction)
  127. # if prob is now near to label `0` then image is not longer noisy
  128. if prob < 0.5:
  129. n_estimated_thresholds[index] += 1
  130. # if same number of detection is attempted
  131. if n_estimated_thresholds[index] >= p_n_stop:
  132. estimated_thresholds[index] = image_indices[img_i]
  133. else:
  134. n_estimated_thresholds[index] = 0
  135. # remove first image
  136. del sequence_list[0]
  137. # write progress bar
  138. write_progress((image_counter + 1) / number_of_images)
  139. image_counter = image_counter + 1
  140. # default label
  141. for i, _ in enumerate(zones_list):
  142. if estimated_thresholds[i] == None:
  143. estimated_thresholds[i] = image_indices[-1]
  144. # 6. save estimated thresholds into specific file
  145. print('\nEstimated thresholds', estimated_thresholds)
  146. if p_save is not None:
  147. with open(p_save, 'a') as f:
  148. f.write(p_label + ';')
  149. for t in estimated_thresholds:
  150. f.write(str(t) + ';')
  151. f.write('\n')
  152. if __name__== "__main__":
  153. main()