write_result_keras.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # main imports
  2. import numpy as np
  3. import pandas as pd
  4. import json
  5. import os, sys, argparse, subprocess
  6. # model imports
  7. from keras.models import model_from_json
  8. from sklearn.model_selection import train_test_split
  9. from sklearn.preprocessing import MinMaxScaler
  10. from joblib import dump, load
  11. # image processing imports
  12. from PIL import Image
  13. import ipfml.iqa.fr as fr
  14. from ipfml import metrics
  15. # modules and config imports
  16. sys.path.insert(0, '') # trick to enable import of main folder module
  17. import custom_config as cfg
# File name suffix of the mean image built from `n` samples; write_result
# builds the full name as `<scene>_<n>` followed by this suffix.
n_samples_image_name_postfix = "_samples_mean.png"
# File name suffix of the reference mean image (generated with 1000 samples);
# write_result builds the full name as `<scene>` followed by this suffix.
reference_image_name_postfix = "_1000_samples_mean.png"
  20. def write_result(_scene_name, _data_file, _model_path, _n, _reconstructed_path, _iqa):
  21. # prepare data to get score information
  22. dataset=np.loadtxt(_data_file, delimiter=';')
  23. y = dataset[:,0]
  24. X = dataset[:,1:]
  25. y=np.reshape(y, (-1,1))
  26. scaler = MinMaxScaler()
  27. scaler.fit(X)
  28. scaler.fit(y)
  29. xscale=scaler.transform(X)
  30. yscale=scaler.transform(y)
  31. _, X_test, _, y_test = train_test_split(xscale, yscale)
  32. # prepare image path to compare
  33. n_samples_image_path = os.path.join(cfg.reconstructed_folder, _scene_name + '_' + _n + n_samples_image_name_postfix)
  34. reference_image_path = os.path.join(cfg.reconstructed_folder, _scene_name + reference_image_name_postfix)
  35. if not os.path.exists(n_samples_image_path):
  36. # call sub process to create 'n' samples img
  37. print("Creation of 'n' samples image : ", n_samples_image_path)
  38. subprocess.run(["python", "reconstruct/reconstruct_scene_mean.py", "--scene", _scene_name, "--n", _n, "--image_name", n_samples_image_path.split('/')[-1]])
  39. if not os.path.exists(reference_image_path):
  40. # call sub process to create 'reference' img
  41. print("Creation of reference image : ", reference_image_path)
  42. subprocess.run(["python", "reconstruct/reconstruct_scene_mean.py", "--scene", _scene_name, "--n", str(1000), "--image_name", reference_image_path.split('/')[-1]])
  43. # load the trained model
  44. with open(_model_path, 'r') as f:
  45. json_model = json.load(f)
  46. model = model_from_json(json_model)
  47. model.load_weights(_model_path.replace('.json', '.h5'))
  48. model.compile(loss='binary_crossentropy',
  49. optimizer='adam',
  50. metrics=['accuracy'])
  51. # get coefficient of determination score on test set
  52. y_predicted = model.predict(X_test)
  53. len_shape, _ = y_predicted.shape
  54. y_predicted = y_predicted.reshape(len_shape)
  55. coeff = metrics.coefficient_of_determination(y_test, y_predicted)
  56. # Get data information
  57. reference_image = Image.open(reference_image_path)
  58. reconstructed_image = Image.open(_reconstructed_path)
  59. n_samples_image = Image.open(n_samples_image_path)
  60. # Load expected IQA comparison
  61. try:
  62. fr_iqa = getattr(fr, _iqa)
  63. except AttributeError:
  64. raise NotImplementedError("FR IQA `{}` not implement `{}`".format(fr.__name__, _iqa))
  65. mse_ref_reconstructed_samples = fr_iqa(reference_image, reconstructed_image)
  66. mse_reconstructed_n_samples = fr_iqa(n_samples_image, reconstructed_image)
  67. model_name = _model_path.split('/')[-1].replace('.json', '')
  68. if not os.path.exists(cfg.results_information_folder):
  69. os.makedirs(cfg.results_information_folder)
  70. # save score into models_comparisons_keras.csv file
  71. with open(cfg.global_result_filepath_keras, "a") as f:
  72. f.write(model_name + ';' + str(len(y)) + ';' + str(coeff[0]) + ';' + str(mse_reconstructed_n_samples) + ';' + str(mse_ref_reconstructed_samples) + '\n')
  73. def main():
  74. parser = argparse.ArgumentParser(description="Train model and saved it")
  75. parser.add_argument('--scene', type=str, help='Scene name to reconstruct', choices=cfg.scenes_list)
  76. parser.add_argument('--data', type=str, help='Filename of dataset')
  77. parser.add_argument('--model_path', type=str, help='Json model file path')
  78. parser.add_argument('--n', type=str, help='Number of pixel values approximated to keep')
  79. parser.add_argument('--image_path', type=str, help="The image reconstructed to compare with")
  80. parser.add_argument('--iqa', type=str, help='Image to compare', choices=['ssim', 'mse', 'rmse', 'mae', 'psnr'])
  81. args = parser.parse_args()
  82. param_scene_name = args.scene
  83. param_data_file = args.data
  84. param_n = args.n
  85. param_model_path = args.model_path
  86. param_image_path = args.image_path
  87. param_iqa = args.iqa
  88. write_result(param_scene_name, param_data_file, param_model_path, param_n, param_image_path, param_iqa)
  89. if __name__== "__main__":
  90. main()