write_result_keras.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import numpy as np
  2. import pandas as pd
  3. import json
  4. import os, sys, argparse, subprocess
  5. from keras.models import model_from_json
  6. from sklearn.model_selection import train_test_split
  7. from sklearn.preprocessing import MinMaxScaler
  8. import modules.config as cfg
  9. import modules.metrics as metrics
  10. from joblib import dump, load
  11. from PIL import Image
  12. import ipfml.iqa.fr as fr
  13. n_samples_image_name_postfix = "_samples_mean.png"
  14. reference_image_name_postfix = "_1000_samples_mean.png"
  15. def write_result(_scene_name, _data_file, _model_path, _n, _reconstructed_path, _iqa):
  16. # prepare data to get score information
  17. dataset=np.loadtxt(_data_file, delimiter=';')
  18. y = dataset[:,0]
  19. X = dataset[:,1:]
  20. y=np.reshape(y, (-1,1))
  21. scaler = MinMaxScaler()
  22. scaler.fit(X)
  23. scaler.fit(y)
  24. xscale=scaler.transform(X)
  25. yscale=scaler.transform(y)
  26. _, X_test, _, y_test = train_test_split(xscale, yscale)
  27. # prepare image path to compare
  28. n_samples_image_path = os.path.join(cfg.reconstructed_folder, _scene_name + '_' + _n + n_samples_image_name_postfix)
  29. reference_image_path = os.path.join(cfg.reconstructed_folder, _scene_name + reference_image_name_postfix)
  30. if not os.path.exists(n_samples_image_path):
  31. # call sub process to create 'n' samples img
  32. print("Creation of 'n' samples image : ", n_samples_image_path)
  33. subprocess.run(["python", "reconstruct_scene_mean.py", "--scene", _scene_name, "--n", _n, "--image_name", n_samples_image_path.split('/')[-1]])
  34. if not os.path.exists(reference_image_path):
  35. # call sub process to create 'reference' img
  36. print("Creation of reference image : ", reference_image_path)
  37. subprocess.run(["python", "reconstruct_scene_mean.py", "--scene", _scene_name, "--n", str(1000), "--image_name", reference_image_path.split('/')[-1]])
  38. # load the trained model
  39. with open(_model_path, 'r') as f:
  40. json_model = json.load(f)
  41. model = model_from_json(json_model)
  42. model.load_weights(_model_path.replace('.json', '.h5'))
  43. model.compile(loss='binary_crossentropy',
  44. optimizer='adam',
  45. metrics=['accuracy'])
  46. # get coefficient of determination score on test set
  47. y_predicted = model.predict(X_test)
  48. len_shape, _ = y_predicted.shape
  49. y_predicted = y_predicted.reshape(len_shape)
  50. coeff = metrics.coefficient_of_determination(y_test, y_predicted)
  51. # Get data information
  52. reference_image = Image.open(reference_image_path)
  53. reconstructed_image = Image.open(_reconstructed_path)
  54. n_samples_image = Image.open(n_samples_image_path)
  55. # Load expected IQA comparison
  56. try:
  57. fr_iqa = getattr(fr, _iqa)
  58. except AttributeError:
  59. raise NotImplementedError("FR IQA `{}` not implement `{}`".format(fr.__name__, _iqa))
  60. mse_ref_reconstructed_samples = fr_iqa(reference_image, reconstructed_image)
  61. mse_reconstructed_n_samples = fr_iqa(n_samples_image, reconstructed_image)
  62. model_name = _model_path.split('/')[-1].replace('.json', '')
  63. # save score into models_comparisons_keras.csv file
  64. with open(cfg.global_result_filepath_keras, "a") as f:
  65. f.write(model_name + ';' + str(len(y)) + ';' + str(coeff[0]) + ';' + str(mse_reconstructed_n_samples) + ';' + str(mse_ref_reconstructed_samples) + '\n')
  66. def main():
  67. parser = argparse.ArgumentParser(description="Train model and saved it")
  68. parser.add_argument('--scene', type=str, help='Scene name to reconstruct', choices=cfg.scenes_list)
  69. parser.add_argument('--data', type=str, help='Filename of dataset')
  70. parser.add_argument('--model_path', type=str, help='Json model file path')
  71. parser.add_argument('--n', type=str, help='Number of pixel values approximated to keep')
  72. parser.add_argument('--image_path', type=str, help="The image reconstructed to compare with")
  73. parser.add_argument('--iqa', type=str, help='Image to compare', choices=['ssim', 'mse', 'rmse', 'mae', 'psnr'])
  74. args = parser.parse_args()
  75. param_scene_name = args.scene
  76. param_data_file = args.data
  77. param_n = args.n
  78. param_model_path = args.model_path
  79. param_image_path = args.image_path
  80. param_iqa = args.iqa
  81. write_result(param_scene_name, param_data_file, param_model_path, param_n, param_image_path, param_iqa)
  82. if __name__== "__main__":
  83. main()