prediction_model.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # main imports
  2. import numpy as np
  3. import pandas as pd
  4. import sys, os, argparse
  5. import json
  6. # model imports
  7. import cnn_models as models
  8. import tensorflow as tf
  9. import keras
  10. from keras import backend as K
  11. from keras.callbacks import ModelCheckpoint
  12. from keras.models import model_from_json
  13. from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
  14. # image processing imports
  15. import cv2
  16. from sklearn.utils import shuffle
  17. import seaborn as sns
  18. import matplotlib.pyplot as plt
  19. # config imports
  20. sys.path.insert(0, '') # trick to enable import of main folder module
  21. import custom_config as cfg
  22. def main():
  23. parser = argparse.ArgumentParser(description="Train Keras model and save it into .json file")
  24. parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
  25. parser.add_argument('--model', type=str, help='.json file of keras model')
  26. args = parser.parse_args()
  27. p_data_file = args.data
  28. p_model_file = args.model
  29. ########################
  30. # 1. Get and prepare data
  31. ########################
  32. print("Preparing data...")
  33. dataset = pd.read_csv(p_data_file, header=None, sep=";")
  34. print("Dataset size : ", len(dataset))
  35. # default first shuffle of data
  36. dataset = shuffle(dataset)
  37. print("Reading all images data...")
  38. # getting number of chanel
  39. n_channels = len(dataset[1][1].split('::'))
  40. print("Number of channels : ", n_channels)
  41. img_width, img_height = cfg.keras_img_size
  42. # specify the number of dimensions
  43. if K.image_data_format() == 'channels_first':
  44. if n_channels > 1:
  45. input_shape = (1, n_channels, img_width, img_height)
  46. else:
  47. input_shape = (n_channels, img_width, img_height)
  48. else:
  49. if n_channels > 1:
  50. input_shape = (1, img_width, img_height, n_channels)
  51. else:
  52. input_shape = (img_width, img_height, n_channels)
  53. # `:` is the separator used for getting each img path
  54. if n_channels > 1:
  55. dataset[1] = dataset[1].apply(lambda x: [cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in x.split('::')])
  56. else:
  57. dataset[1] = dataset[1].apply(lambda x: cv2.imread(x, cv2.IMREAD_GRAYSCALE))
  58. # reshape array data
  59. dataset[1] = dataset[1].apply(lambda x: np.array(x).reshape(input_shape))
  60. # use of the whole data set for training
  61. x_dataset = dataset.iloc[:,1:]
  62. y_dataset = dataset.iloc[:,0]
  63. x_data = []
  64. for item in x_dataset.values:
  65. #print("Item is here", item)
  66. x_data.append(item[0])
  67. x_data = np.array(x_data)
  68. print("End of loading data..")
  69. #######################
  70. # 2. Getting model
  71. #######################
  72. with open(p_model_file, 'r') as f:
  73. json_model = json.load(f)
  74. model = model_from_json(json_model)
  75. model.load_weights(p_model_file.replace('.json', '.h5'))
  76. model.compile(loss='categorical_crossentropy',
  77. optimizer='adam')
  78. # Get results obtained from model
  79. y_data_prediction = model.predict(x_data)
  80. y_prediction = np.argmax(y_data_prediction, axis=1)
  81. acc_score = accuracy_score(y_dataset, y_prediction)
  82. f1_data_score = f1_score(y_dataset, y_prediction)
  83. recall_data_score = recall_score(y_dataset, y_prediction)
  84. pres_score = precision_score(y_dataset, y_prediction)
  85. roc_score = roc_auc_score(y_dataset, y_prediction)
  86. # save model performance
  87. if not os.path.exists(cfg.results_information_folder):
  88. os.makedirs(cfg.results_information_folder)
  89. perf_file_path = os.path.join(cfg.results_information_folder, cfg.perf_prediction_model_path)
  90. # write header if necessary
  91. if not os.path.exists(perf_file_path):
  92. with open(perf_file_path, 'w') as f:
  93. f.write(cfg.perf_prediction_header_file)
  94. # add information into file
  95. with open(perf_file_path, 'a') as f:
  96. line = p_data_file + ';' + str(len(dataset)) + ';' + p_model_file + ';' + str(acc_score) + ';' + str(f1_data_score) + ';' + str(recall_data_score) + ';' + str(pres_score) + ';' + str(roc_score) + ';\n'
  97. f.write(line)
  98. if __name__== "__main__":
  99. main()