prediction_model.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # main imports
  2. import numpy as np
  3. import pandas as pd
  4. import sys, os, argparse
  5. import json
  6. # model imports
  7. import cnn_models as models
  8. import tensorflow as tf
  9. import keras
  10. from keras import backend as K
  11. from keras.callbacks import ModelCheckpoint
  12. from keras.models import model_from_json
  13. from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
  14. # image processing imports
  15. import cv2
  16. from sklearn.utils import shuffle
  17. # config imports
  18. sys.path.insert(0, '') # trick to enable import of main folder module
  19. import custom_config as cfg
  def main():
      """Evaluate a saved Keras model on a labelled image dataset.

      Command line arguments:
          --data: path of a `;`-separated CSV file without header. Column 0
              holds the expected label and column 1 the image path (several
              paths joined with `::` for multi-channel samples).
          --model: path of the model `.json` file. The weights are loaded
              from the file with the same base name and the `.h5` extension.

      Accuracy, F1, recall, precision and ROC AUC scores are appended as one
      `;`-separated line to the performance file located in
      `cfg.results_information_folder` (its header is written on creation).

      Raises:
          OSError: if one of the listed images cannot be read.
      """
      parser = argparse.ArgumentParser(description="Evaluate a saved Keras model (.json + .h5 files) on a dataset file")

      parser.add_argument('--data', type=str, help='dataset file (`;` separated: label, then image path(s) joined by `::`)', required=True)
      # the model file is opened unconditionally below, so it is mandatory
      parser.add_argument('--model', type=str, help='.json file of keras model', required=True)

      args = parser.parse_args()

      p_data_file = args.data
      p_model_file = args.model

      ########################
      # 1. Get and prepare data
      ########################
      print("Preparing data...")
      dataset = pd.read_csv(p_data_file, header=None, sep=";")

      print("Dataset size : ", len(dataset))

      # default first shuffle of data
      dataset = shuffle(dataset)

      print("Reading all images data...")

      # number of channels = number of `::`-separated image paths of a sample;
      # positional access is required because the index labels are no longer
      # ordered (and label 1 may not even exist) once the data is shuffled
      n_channels = len(dataset[1].iloc[0].split('::'))
      print("Number of channels : ", n_channels)

      img_width, img_height = cfg.keras_img_size

      # specify the number of dimensions
      if K.image_data_format() == 'channels_first':
          input_shape = (n_channels, img_width, img_height)
      else:
          input_shape = (img_width, img_height, n_channels)

      # multi-channel samples get an extra leading axis of size 1
      # NOTE(review): presumably the layout expected by the multi-channel
      # models of `cnn_models` -- confirm against the training script
      if n_channels > 1:
          input_shape = (1,) + input_shape

      def load_sample(paths):
          """Load the image(s) of one sample and reshape them to `input_shape`."""
          # `::` is the separator used for getting each img path
          if n_channels > 1:
              pixels = [_read_grayscale(path) for path in paths.split('::')]
          else:
              pixels = _read_grayscale(paths)

          # NOTE(review): `reshape` only reinterprets the buffer, it does not
          # transpose axes; kept as is since the model was presumably trained
          # on inputs built the same way -- confirm
          return np.array(pixels).reshape(input_shape)

      # use of the whole data set for evaluation
      x_data = np.array([load_sample(paths) for paths in dataset[1]])
      y_dataset = dataset.iloc[:, 0]

      print("End of loading data..")

      #######################
      # 2. Getting model
      #######################
      # NOTE(review): the .json file presumably stores the architecture as a
      # JSON-encoded string, hence `json.load` before `model_from_json`
      with open(p_model_file, 'r') as f:
          json_model = json.load(f)

      model = model_from_json(json_model)

      # swap only the file extension; a plain `str.replace` would also rewrite
      # any `.json` occurring elsewhere in the path
      model.load_weights(os.path.splitext(p_model_file)[0] + '.h5')

      model.compile(loss='binary_crossentropy',
                    optimizer='rmsprop')

      # Get results obtained from model
      y_scores = np.asarray(model.predict(x_data)).ravel()
      y_prediction = (y_scores > 0.5).astype(int)

      acc_score = accuracy_score(y_dataset, y_prediction)
      f1_data_score = f1_score(y_dataset, y_prediction)
      recall_data_score = recall_score(y_dataset, y_prediction)
      pres_score = precision_score(y_dataset, y_prediction)
      # ROC AUC is a ranking metric: it needs the raw model scores, with
      # thresholded labels it degenerates to the balanced accuracy
      roc_score = roc_auc_score(y_dataset, y_scores)

      # save model performance
      os.makedirs(cfg.results_information_folder, exist_ok=True)

      perf_file_path = os.path.join(cfg.results_information_folder, cfg.perf_prediction_model_path)

      # write header if necessary (i.e. when the file is about to be created)
      write_header = not os.path.exists(perf_file_path)

      fields = [p_data_file, len(dataset), p_model_file, acc_score,
                f1_data_score, recall_data_score, pres_score, roc_score]

      # add information into file
      with open(perf_file_path, 'a') as f:
          if write_header:
              f.write(cfg.perf_prediction_header_file)
          f.write(';'.join(str(field) for field in fields) + ';\n')


  def _read_grayscale(path):
      """Read the image stored at `path` as a grayscale array.

      Raises:
          OSError: if OpenCV cannot read the file (`cv2.imread` returns None
              instead of raising for a missing or undecodable image).
      """
      image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

      if image is None:
          raise OSError("Unable to read image file: " + str(path))

      return image
  # Script entry point: run the evaluation only when this file is executed
  # directly, not when it is imported as a module.
  if __name__ == "__main__":
      main()