train_model.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # main imports
  2. import numpy as np
  3. import pandas as pd
  4. import sys, os, argparse
  5. import json
  6. # model imports
  7. import cnn_models as models
  8. from keras import backend as K
  9. from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
  10. # image processing imports
  11. import cv2
  12. from sklearn.utils import shuffle
  13. # config imports
  14. sys.path.insert(0, '') # trick to enable import of main folder module
  15. import custom_config as cfg
  16. def main():
  17. parser = argparse.ArgumentParser(description="Train Keras model and save it into .json file")
  18. parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
  19. parser.add_argument('--output', type=str, help='output file name desired for model (without .json extension)', required=True)
  20. parser.add_argument('--tl', type=int, help='use or not of transfer learning (`VGG network`)', default=0, choices=[0, 1])
  21. parser.add_argument('--batch_size', type=int, help='batch size used as model input', default=cfg.keras_batch)
  22. parser.add_argument('--epochs', type=int, help='number of epochs used for training model', default=cfg.keras_epochs)
  23. parser.add_argument('--val_size', type=int, help='percent of validation data during training process', default=cfg.val_dataset_size)
  24. args = parser.parse_args()
  25. p_data_file = args.data
  26. p_output = args.output
  27. p_tl = args.tl
  28. p_batch_size = args.batch_size
  29. p_epochs = args.epochs
  30. p_val_size = args.val_size
  31. ########################
  32. # 1. Get and prepare data
  33. ########################
  34. print("Preparing data...")
  35. dataset_train = pd.read_csv(p_data_file + '.train', header=None, sep=";")
  36. dataset_test = pd.read_csv(p_data_file + '.test', header=None, sep=";")
  37. print("Train set size : ", len(dataset_train))
  38. print("Test set size : ", len(dataset_test))
  39. # default first shuffle of data
  40. dataset_train = shuffle(dataset_train)
  41. dataset_test = shuffle(dataset_test)
  42. print("Reading all images data...")
  43. # getting number of chanel
  44. n_channels = len(dataset_train[1][1].split('::'))
  45. print("Number of channels : ", n_channels)
  46. img_width, img_height = cfg.keras_img_size
  47. # specify the number of dimensions
  48. if K.image_data_format() == 'channels_first':
  49. if n_channels > 1:
  50. input_shape = (1, n_channels, img_width, img_height)
  51. else:
  52. input_shape = (n_channels, img_width, img_height)
  53. else:
  54. if n_channels > 1:
  55. input_shape = (1, img_width, img_height, n_channels)
  56. else:
  57. input_shape = (img_width, img_height, n_channels)
  58. # `:` is the separator used for getting each img path
  59. if n_channels > 1:
  60. dataset_train[1] = dataset_train[1].apply(lambda x: [cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in x.split('::')])
  61. dataset_test[1] = dataset_test[1].apply(lambda x: [cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in x.split('::')])
  62. else:
  63. dataset_train[1] = dataset_train[1].apply(lambda x: cv2.imread(x, cv2.IMREAD_GRAYSCALE))
  64. dataset_test[1] = dataset_test[1].apply(lambda x: cv2.imread(x, cv2.IMREAD_GRAYSCALE))
  65. # reshape array data
  66. dataset_train[1] = dataset_train[1].apply(lambda x: np.array(x).reshape(input_shape))
  67. dataset_test[1] = dataset_test[1].apply(lambda x: np.array(x).reshape(input_shape))
  68. # get dataset with equal number of classes occurences
  69. noisy_df_train = dataset_train[dataset_train.ix[:, 0] == 1]
  70. not_noisy_df_train = dataset_train[dataset_train.ix[:, 0] == 0]
  71. nb_noisy_train = len(noisy_df_train.index)
  72. noisy_df_test = dataset_test[dataset_test.ix[:, 0] == 1]
  73. not_noisy_df_test = dataset_test[dataset_test.ix[:, 0] == 0]
  74. nb_noisy_test = len(noisy_df_test.index)
  75. final_df_train = pd.concat([not_noisy_df_train[0:nb_noisy_train], noisy_df_train])
  76. final_df_test = pd.concat([not_noisy_df_test[0:nb_noisy_test], noisy_df_test])
  77. # shuffle data another time
  78. final_df_train = shuffle(final_df_train)
  79. final_df_test = shuffle(final_df_test)
  80. final_df_train_size = len(final_df_train.index)
  81. final_df_test_size = len(final_df_test.index)
  82. # use of the whole data set for training
  83. x_dataset_train = final_df_train.ix[:,1:]
  84. x_dataset_test = final_df_test.ix[:,1:]
  85. y_dataset_train = final_df_train.ix[:,0]
  86. y_dataset_test = final_df_test.ix[:,0]
  87. x_data_train = []
  88. for item in x_dataset_train.values:
  89. #print("Item is here", item)
  90. x_data_train.append(item[0])
  91. x_data_train = np.array(x_data_train)
  92. x_data_test = []
  93. for item in x_dataset_test.values:
  94. #print("Item is here", item)
  95. x_data_test.append(item[0])
  96. x_data_test = np.array(x_data_test)
  97. print("End of loading data..")
  98. print("Train set size (after balancing) : ", final_df_train_size)
  99. print("Test set size (after balancing) : ", final_df_test_size)
  100. #######################
  101. # 2. Getting model
  102. #######################
  103. model = models.get_model(n_channels, input_shape, p_tl)
  104. model.summary()
  105. model.fit(x_data_train, y_dataset_train.values, validation_split=p_val_size, epochs=p_epochs, batch_size=p_batch_size)
  106. score = model.evaluate(x_data_test, y_dataset_test, batch_size=p_batch_size)
  107. print("Accuracy score on test dataset ", score)
  108. if not os.path.exists(cfg.saved_models_folder):
  109. os.makedirs(cfg.saved_models_folder)
  110. # save the model into HDF5 file
  111. model_output_path = os.path.join(cfg.saved_models_folder, p_output + '.json')
  112. json_model_content = model.to_json()
  113. with open(model_output_path, 'w') as f:
  114. print("Model saved into ", model_output_path)
  115. json.dump(json_model_content, f, indent=4)
  116. model.save_weights(model_output_path.replace('.json', '.h5'))
  117. # Get results obtained from model
  118. y_train_prediction = model.predict(x_data_train)
  119. y_test_prediction = model.predict(x_data_test)
  120. y_train_prediction = [1 if x > 0.5 else 0 for x in y_train_prediction]
  121. y_test_prediction = [1 if x > 0.5 else 0 for x in y_test_prediction]
  122. acc_train_score = accuracy_score(y_dataset_train, y_train_prediction)
  123. acc_test_score = accuracy_score(y_dataset_test, y_test_prediction)
  124. f1_train_score = f1_score(y_dataset_train, y_train_prediction)
  125. f1_test_score = f1_score(y_dataset_test, y_test_prediction)
  126. recall_train_score = recall_score(y_dataset_train, y_train_prediction)
  127. recall_test_score = recall_score(y_dataset_test, y_test_prediction)
  128. pres_train_score = precision_score(y_dataset_train, y_train_prediction)
  129. pres_test_score = precision_score(y_dataset_test, y_test_prediction)
  130. roc_train_score = roc_auc_score(y_dataset_train, y_train_prediction)
  131. roc_test_score = roc_auc_score(y_dataset_test, y_test_prediction)
  132. # save model performance
  133. if not os.path.exists(cfg.results_information_folder):
  134. os.makedirs(cfg.results_information_folder)
  135. perf_file_path = os.path.join(cfg.results_information_folder, cfg.csv_model_comparisons_filename)
  136. with open(perf_file_path, 'a') as f:
  137. line = p_output + ';' + str(len(dataset_train)) + ';' + str(len(dataset_test)) + ';' \
  138. + str(final_df_train_size) + ';' + str(final_df_test_size) + ';' \
  139. + str(acc_train_score) + ';' + str(acc_test_score) + ';' \
  140. + str(f1_train_score) + ';' + str(f1_test_score) + ';' \
  141. + str(recall_train_score) + ';' + str(recall_test_score) + ';' \
  142. + str(pres_train_score) + ';' + str(pres_test_score) + ';' \
  143. + str(roc_train_score) + ';' + str(roc_test_score) + '\n'
  144. f.write(line)
  145. if __name__== "__main__":
  146. main()