predict_noisy_image_svd.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. from sklearn.externals import joblib
  2. import numpy as np
  3. from ipfml import processing, utils
  4. from PIL import Image
  5. import sys, os, argparse, json
  6. from keras.models import model_from_json
  7. from modules.utils import config as cfg
  8. from modules.utils import data as dt
# Module-level aliases for shared configuration values (paths, filename
# conventions and CLI choice lists) defined in modules.utils.config.
path = cfg.dataset_path
min_max_ext = cfg.min_max_filename_extension
metric_choices = cfg.metric_choices_labels
normalization_choices = cfg.normalization_choices
custom_min_max_folder = cfg.min_max_custom_folder
  14. def main():
  15. # getting all params
  16. parser = argparse.ArgumentParser(description="Script which detects if an image is noisy or not using specific model")
  17. parser.add_argument('--image', type=str, help='Image path')
  18. parser.add_argument('--interval', type=str, help='Interval value to keep from svd', default='"0, 200"')
  19. parser.add_argument('--model', type=str, help='.joblib or .json file (sklearn or keras model)')
  20. parser.add_argument('--mode', type=str, help='Kind of normalization level wished', choices=normalization_choices)
  21. parser.add_argument('--metric', type=str, help='Metric data choice', choices=metric_choices)
  22. parser.add_argument('--custom', type=str, help='Name of custom min max file if use of renormalization of data', default=False)
  23. args = parser.parse_args()
  24. p_img_file = args.image
  25. p_model_file = args.model
  26. p_interval = list(map(int, args.interval.split(',')))
  27. p_mode = args.mode
  28. p_metric = args.metric
  29. p_custom = args.custom
  30. if '.joblib' in p_model_file:
  31. kind_model = 'sklearn'
  32. if '.json' in p_model_file:
  33. kind_model = 'keras'
  34. if 'corr' in p_model_file:
  35. corr_model = True
  36. indices_corr_path = os.path.join(cfg.correlation_indices_folder, p_model_file.split('/')[1].replace('.json', '').replace('.joblib', '') + '.csv')
  37. with open(indices_corr_path, 'r') as f:
  38. data_corr_indices = [int(x) for x in f.readline().split(';') if x != '']
  39. else:
  40. corr_model = False
  41. if kind_model == 'sklearn':
  42. # load of model file
  43. model = joblib.load(p_model_file)
  44. if kind_model == 'keras':
  45. with open(p_model_file, 'r') as f:
  46. json_model = json.load(f)
  47. model = model_from_json(json_model)
  48. model.load_weights(p_model_file.replace('.json', '.h5'))
  49. model.compile(loss='binary_crossentropy',
  50. optimizer='adam',
  51. metrics=['accuracy'])
  52. # load image
  53. img = Image.open(p_img_file)
  54. data = dt.get_svd_data(p_metric, img)
  55. # get interval values
  56. begin, end = p_interval
  57. # check if custom min max file is used
  58. if p_custom:
  59. if corr_model:
  60. test_data = data[data_corr_indices]
  61. else:
  62. test_data = data[begin:end]
  63. if p_mode == 'svdne':
  64. # set min_max_filename if custom use
  65. min_max_file_path = custom_min_max_folder + '/' + p_custom
  66. # need to read min_max_file
  67. file_path = os.path.join(os.path.dirname(__file__), min_max_file_path)
  68. with open(file_path, 'r') as f:
  69. min_val = float(f.readline().replace('\n', ''))
  70. max_val = float(f.readline().replace('\n', ''))
  71. test_data = utils.normalize_arr_with_range(test_data, min_val, max_val)
  72. if p_mode == 'svdn':
  73. test_data = utils.normalize_arr(test_data)
  74. else:
  75. # check mode to normalize data
  76. if p_mode == 'svdne':
  77. # set min_max_filename if custom use
  78. min_max_file_path = path + '/' + p_metric + min_max_ext
  79. # need to read min_max_file
  80. file_path = os.path.join(os.path.dirname(__file__), min_max_file_path)
  81. with open(file_path, 'r') as f:
  82. min_val = float(f.readline().replace('\n', ''))
  83. max_val = float(f.readline().replace('\n', ''))
  84. l_values = utils.normalize_arr_with_range(data, min_val, max_val)
  85. elif p_mode == 'svdn':
  86. l_values = utils.normalize_arr(data)
  87. else:
  88. l_values = data
  89. if corr_model:
  90. test_data = data[data_corr_indices]
  91. else:
  92. test_data = data[begin:end]
  93. # get prediction of model
  94. if kind_model == 'sklearn':
  95. prediction = model.predict([test_data])[0]
  96. if kind_model == 'keras':
  97. test_data = np.asarray(test_data).reshape(1, len(test_data), 1)
  98. prediction = model.predict_classes([test_data])[0][0]
  99. # output expected from others scripts
  100. print(prediction)
# Standard script entry point: run main() only when executed directly,
# not when imported by another module.
if __name__== "__main__":
    main()