# predict_noisy_image_svd.py

import getopt
import os
import sys

import numpy as np
from ipfml import processing
from PIL import Image

try:
    from sklearn.externals import joblib
except ImportError:
    # sklearn.externals.joblib was removed in scikit-learn 0.23; the
    # standalone joblib package it vendored is the drop-in replacement.
    import joblib

from modules.utils import config as cfg
from modules.utils import data as dt
  8. path = cfg.dataset_path
  9. min_max_ext = cfg.min_max_filename_extension
  10. metric_choices = cfg.metric_choices_labels
  11. normalization_choices = cfg.normalization_choices
  12. custom_min_max_folder = cfg.min_max_custom_folder
  13. def main():
  14. p_custom = False
  15. if len(sys.argv) <= 1:
  16. print('Run with default parameters...')
  17. print('python predict_noisy_image_svd.py --image path/to/xxxx --interval "0,20" --model path/to/xxxx.joblib --metric lab --mode ["svdn", "svdne"] --custom min_max_file')
  18. sys.exit(2)
  19. try:
  20. opts, args = getopt.getopt(sys.argv[1:], "hi:t:m:m:o:c", ["help=", "image=", "interval=", "model=", "metric=", "mode=", "custom="])
  21. except getopt.GetoptError:
  22. # print help information and exit
  23. print('python predict_noisy_image_svd_lab.py --image path/to/xxxx --interval "xx,xx" --model path/to/xxxx.joblib --metric lab --mode ["svdn", "svdne"] --custom min_max_file')
  24. sys.exit(2)
  25. for o, a in opts:
  26. if o == "-h":
  27. print('python predict_noisy_image_svd_lab.py --image path/to/xxxx --interval "xx,xx" --model path/to/xxxx.joblib --metric lab --mode ["svdn", "svdne"] --custom min_max_file')
  28. sys.exit()
  29. elif o in ("-i", "--image"):
  30. p_img_file = os.path.join(os.path.dirname(__file__), a)
  31. elif o in ("-t", "--interval"):
  32. p_interval = list(map(int, a.split(',')))
  33. elif o in ("-m", "--model"):
  34. p_model_file = os.path.join(os.path.dirname(__file__), a)
  35. elif o in ("-m", "--metric"):
  36. p_metric = a
  37. if not p_metric in metric_choices:
  38. assert False, "Unknow metric choice"
  39. elif o in ("-o", "--mode"):
  40. p_mode = a
  41. if not p_mode in normalization_choices:
  42. assert False, "Mode of normalization not recognized"
  43. elif o in ("-c", "--custom"):
  44. p_custom = a
  45. else:
  46. assert False, "unhandled option"
  47. # load of model file
  48. model = joblib.load(p_model_file)
  49. # load image
  50. img = Image.open(p_img_file)
  51. data = dt.get_svd_data(p_metric, img)
  52. # get interval values
  53. begin, end = p_interval
  54. # check if custom min max file is used
  55. if p_custom:
  56. test_data = data[begin:end]
  57. if p_mode == 'svdne':
  58. # set min_max_filename if custom use
  59. min_max_file_path = custom_min_max_folder + '/' + p_custom
  60. # need to read min_max_file
  61. file_path = os.path.join(os.path.dirname(__file__), min_max_file_path)
  62. with open(file_path, 'r') as f:
  63. min_val = float(f.readline().replace('\n', ''))
  64. max_val = float(f.readline().replace('\n', ''))
  65. test_data = processing.normalize_arr_with_range(test_data, min_val, max_val)
  66. if p_mode == 'svdn':
  67. test_data = processing.normalize_arr(test_data)
  68. else:
  69. # check mode to normalize data
  70. if p_mode == 'svdne':
  71. # set min_max_filename if custom use
  72. min_max_file_path = path + '/' + p_metric + min_max_ext
  73. # need to read min_max_file
  74. file_path = os.path.join(os.path.dirname(__file__), min_max_file_path)
  75. with open(file_path, 'r') as f:
  76. min_val = float(f.readline().replace('\n', ''))
  77. max_val = float(f.readline().replace('\n', ''))
  78. l_values = processing.normalize_arr_with_range(data, min_val, max_val)
  79. elif p_mode == 'svdn':
  80. l_values = processing.normalize_arr(data)
  81. else:
  82. l_values = data
  83. test_data = l_values[begin:end]
  84. # get prediction of model
  85. prediction = model.predict([test_data])[0]
  86. # output expected from others scripts
  87. print(prediction)
  88. if __name__== "__main__":
  89. main()