# predict_noisy_image_svd.py
import argparse
import json
import os
import sys

import numpy as np
from PIL import Image
from sklearn.externals import joblib

from ipfml import processing, utils

from modules.utils import config as cfg
from modules.utils import data as dt
  8. path = cfg.dataset_path
  9. min_max_ext = cfg.min_max_filename_extension
  10. metric_choices = cfg.metric_choices_labels
  11. normalization_choices = cfg.normalization_choices
  12. custom_min_max_folder = cfg.min_max_custom_folder
  13. def main():
  14. # getting all params
  15. parser = argparse.ArgumentParser(description="Script which detects if an image is noisy or not using specific model")
  16. parser.add_argument('--image', type=str, help='Image path')
  17. parser.add_argument('--interval', type=str, help='Interval value to keep from svd', default='"0, 200"')
  18. parser.add_argument('--model', type=str, help='.joblib or .json file (sklearn or keras model)')
  19. parser.add_argument('--mode', type=str, help='Kind of normalization level wished', choices=normalization_choices)
  20. parser.add_argument('--metric', type=str, help='Metric data choice', choices=metric_choices)
  21. parser.add_argument('--custom', type=str, help='Name of custom min max file if use of renormalization of data', default=False)
  22. args = parser.parse_args()
  23. p_img_file = args.image
  24. p_model_file = args.model
  25. p_interval = list(map(int, args.interval.split(',')))
  26. p_mode = args.mode
  27. p_metric = args.metric
  28. p_custom = args.custom
  29. if '.joblib' in p_model_file:
  30. kind_model = 'sklearn'
  31. if '.json' in p_model_file:
  32. kind_model = 'keras'
  33. if 'corr' in p_model_file:
  34. corr_model = True
  35. indices_corr_path = os.path.join(cfg.correlation_indices_folder, p_model_file.split('.')[0] + '.csv')
  36. with open(indices_corr_path, 'r') as f:
  37. data_corr_indices = [int(x) for x in f.readline().split(';') if x != '']
  38. else:
  39. corr_model = False
  40. if kind_model == 'sklearn':
  41. # load of model file
  42. model = joblib.load(p_model_file)
  43. if kind_model == 'keras':
  44. with open(p_model_file, 'r') as f:
  45. json_model = json.load(f)
  46. model = model_from_json(json_model)
  47. model.load_weights(p_model_file.replace('.json', '.h5'))
  48. model.compile(loss='binary_crossentropy',
  49. optimizer='adam',
  50. metrics=['accuracy'])
  51. # load image
  52. img = Image.open(p_img_file)
  53. data = dt.get_svd_data(p_metric, img)
  54. # get interval values
  55. begin, end = p_interval
  56. # check if custom min max file is used
  57. if p_custom:
  58. if corr_model:
  59. test_data = data[data_corr_indices]
  60. else:
  61. test_data = data[begin:end]
  62. if p_mode == 'svdne':
  63. # set min_max_filename if custom use
  64. min_max_file_path = custom_min_max_folder + '/' + p_custom
  65. # need to read min_max_file
  66. file_path = os.path.join(os.path.dirname(__file__), min_max_file_path)
  67. with open(file_path, 'r') as f:
  68. min_val = float(f.readline().replace('\n', ''))
  69. max_val = float(f.readline().replace('\n', ''))
  70. test_data = utils.normalize_arr_with_range(test_data, min_val, max_val)
  71. if p_mode == 'svdn':
  72. test_data = utils.normalize_arr(test_data)
  73. else:
  74. # check mode to normalize data
  75. if p_mode == 'svdne':
  76. # set min_max_filename if custom use
  77. min_max_file_path = path + '/' + p_metric + min_max_ext
  78. # need to read min_max_file
  79. file_path = os.path.join(os.path.dirname(__file__), min_max_file_path)
  80. with open(file_path, 'r') as f:
  81. min_val = float(f.readline().replace('\n', ''))
  82. max_val = float(f.readline().replace('\n', ''))
  83. l_values = utils.normalize_arr_with_range(data, min_val, max_val)
  84. elif p_mode == 'svdn':
  85. l_values = utils.normalize_arr(data)
  86. else:
  87. l_values = data
  88. if corr_model:
  89. test_data = data[data_corr_indices]
  90. else:
  91. test_data = data[begin:end]
  92. # get prediction of model
  93. if kind_model == 'sklearn':
  94. prediction = model.predict([test_data])[0]
  95. if kind_model == 'keras':
  96. test_data = test_data.reshape(len(test_data), 1)
  97. prediction = model.predict_classes([test_data])[0]
  98. # output expected from others scripts
  99. print(prediction)
  100. if __name__== "__main__":
  101. main()