model_prediction_data_rf.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. # main imports
  2. import numpy as np
  3. import pandas as pd
  4. import sys, os, argparse
  5. # image processing
  6. from PIL import Image
  7. from ipfml import utils
  8. from ipfml.processing import transform, segmentation
  9. import matplotlib.pyplot as plt
  10. from sklearn.model_selection import train_test_split
  11. from sklearn.model_selection import GridSearchCV
  12. from sklearn.linear_model import LogisticRegression
  13. from sklearn.ensemble import RandomForestClassifier, VotingClassifier
  14. import joblib
  15. import sklearn.svm as svm
  16. from sklearn.utils import shuffle
  17. from sklearn.metrics import accuracy_score, roc_auc_score
  18. from sklearn.model_selection import cross_val_score
  19. # model imports
  20. import joblib
  21. # modules and config imports
  22. sys.path.insert(0, '') # trick to enable import of main folder module
  23. def write_progress(progress):
  24. barWidth = 180
  25. output_str = "["
  26. pos = barWidth * progress
  27. for i in range(barWidth):
  28. if i < pos:
  29. output_str = output_str + "="
  30. elif i == pos:
  31. output_str = output_str + ">"
  32. else:
  33. output_str = output_str + " "
  34. output_str = output_str + "] " + str(int(progress * 100.0)) + " %\r"
  35. print(output_str)
  36. sys.stdout.write("\033[F")
  37. def loadDataset(filename, n_step = 20):
  38. ########################
  39. # 1. Get and prepare data
  40. ########################
  41. # scene_name; zone_id; image_index_end; label; data
  42. head, folder_data = os.path.split(filename)
  43. dataset_train = pd.read_csv(os.path.join(filename, folder_data + '.train'), header=None, sep=";")
  44. dataset_test = pd.read_csv(os.path.join(filename, folder_data + '.test'), header=None, sep=";")
  45. # default first shuffle of data
  46. dataset_train = shuffle(dataset_train)
  47. dataset_test = shuffle(dataset_test)
  48. dataset_train = dataset_train[dataset_train.iloc[:, 2] % n_step == 0]
  49. dataset_test = dataset_test[dataset_test.iloc[:, 2] % n_step == 0]
  50. # get dataset with equal number of classes occurences
  51. noisy_df_train = dataset_train[dataset_train.iloc[:, 3] == 1]
  52. not_noisy_df_train = dataset_train[dataset_train.iloc[:, 3] == 0]
  53. #nb_noisy_train = len(noisy_df_train.index)
  54. noisy_df_test = dataset_test[dataset_test.iloc[:, 3] == 1]
  55. not_noisy_df_test = dataset_test[dataset_test.iloc[:, 3] == 0]
  56. #nb_noisy_test = len(noisy_df_test.index)
  57. # use of all data
  58. final_df_train = pd.concat([not_noisy_df_train, noisy_df_train])
  59. final_df_test = pd.concat([not_noisy_df_test, noisy_df_test])
  60. # shuffle data another time
  61. final_df_train = shuffle(final_df_train)
  62. final_df_test = shuffle(final_df_test)
  63. # use of the whole data set for training
  64. x_dataset_train = final_df_train.iloc[:, 4:]
  65. x_dataset_test = final_df_test.iloc[:, 4:]
  66. y_dataset_train = final_df_train.iloc[:, 3]
  67. y_dataset_test = final_df_test.iloc[:, 3]
  68. return x_dataset_train, y_dataset_train, x_dataset_test, y_dataset_test
  69. def train_model(p_data_file, p_solution):
  70. x_dataset_train, y_dataset_train, x_dataset_test, y_dataset_test = loadDataset(p_data_file)
  71. # get indices of filters data to use (filters selection from solution)
  72. indices = []
  73. print(p_solution)
  74. for index, value in enumerate(p_solution):
  75. if value == 1:
  76. indices.append(index)
  77. print(f'Selected indices are: {indices}')
  78. print(f"Train dataset size {len(x_dataset_train)}")
  79. print(f"Test dataset size {len(x_dataset_test)}")
  80. x_dataset_train = x_dataset_train.iloc[:, indices]
  81. x_dataset_test = x_dataset_test.iloc[:, indices]
  82. print("-------------------------------------------")
  83. # model = mdl.get_trained_model(p_choice, x_dataset_train, y_dataset_train)
  84. model = RandomForestClassifier(n_estimators=500, class_weight='balanced', bootstrap=True, max_samples=0.75, n_jobs=-1)
  85. model.fit(x_dataset_train, y_dataset_train)
  86. #######################
  87. # 3. Fit model : use of cross validation to fit model
  88. #######################
  89. val_scores = cross_val_score(model, x_dataset_train, y_dataset_train, cv=5)
  90. print("Accuracy: %0.2f (+/- %0.2f)" % (val_scores.mean(), val_scores.std() * 2))
  91. ######################
  92. # 4. Metrics
  93. ######################
  94. y_train_model = model.predict(x_dataset_train)
  95. y_test_model = model.predict(x_dataset_test)
  96. train_accuracy = accuracy_score(y_dataset_train, y_train_model)
  97. test_accuracy = accuracy_score(y_dataset_test, y_test_model)
  98. train_auc = roc_auc_score(y_dataset_train, y_train_model)
  99. test_auc = roc_auc_score(y_dataset_test, y_test_model)
  100. ###################
  101. # 5. Output : Print and write all information in csv
  102. ###################
  103. print("Train dataset size ", len(x_dataset_train))
  104. print("Train acc: ", train_accuracy)
  105. print("Train AUC: ", train_auc)
  106. print("Test dataset size ", len(x_dataset_test))
  107. print("Test acc: ", test_accuracy)
  108. print("Test AUC: ", test_auc)
  109. return model
  110. def main():
  111. parser = argparse.ArgumentParser(description="Read and compute entropy data file")
  112. # parser.add_argument('--solution', type=str, help='entropy file data with estimated threshold to read and compute')
  113. parser.add_argument('--data', type=str, help='dataset filename prefiloc (without .train and .test)', required=True)
  114. # parser.add_argument('--dataset', type=str, help='datasets file to load and predict from')
  115. parser.add_argument('--solution', type=str, help='Data of solution to specify filters to use')
  116. parser.add_argument('--output', type=str, help="output folder")
  117. args = parser.parse_args()
  118. # p_model = args.model
  119. p_data_file = args.data
  120. p_output = args.output
  121. p_solution = list(map(int, args.solution.split(' ')))
  122. # 2. load model and compile it
  123. model = train_model(p_data_file, p_solution)
  124. # begin prediction
  125. if not os.path.exists(p_output):
  126. os.makedirs(p_output)
  127. scene_predictions = {}
  128. data_lines = []
  129. dataset_files = os.listdir(p_data_file)
  130. for filename in dataset_files:
  131. filename_path = os.path.join(p_data_file, filename)
  132. with open(filename_path, 'r') as f:
  133. for line in f.readlines():
  134. data_lines.append(line)
  135. nlines = len(data_lines)
  136. ncounter = 0
  137. for line in data_lines:
  138. data = line.split(';')
  139. scene_name = data[0]
  140. zone_index = int(data[1])
  141. if scene_name not in scene_predictions:
  142. scene_predictions[scene_name] = []
  143. for _ in range(16):
  144. scene_predictions[scene_name].append([])
  145. # prepare input data
  146. # ToDo check data input
  147. input_data = np.array([ l.replace('\n', '').split(' ') for l in data[4:] ], 'float32').flatten()
  148. # print(input_data.flatten())
  149. input_data = np.expand_dims(input_data, axis=0)
  150. prob = model.predict(input_data)[0]
  151. scene_predictions[scene_name][zone_index].append(prob)
  152. ncounter += 1
  153. write_progress(float(ncounter / nlines))
  154. # 6. save predictions results
  155. for key, blocks_predictions in scene_predictions.items():
  156. output_file = os.path.join(p_output, key + '.csv')
  157. f = open(output_file, 'w')
  158. for i, data in enumerate(blocks_predictions):
  159. f.write(key + ';')
  160. f.write(str(i) + ';')
  161. for v in data:
  162. f.write(str(v) + ';')
  163. f.write('\n')
  164. f.close()
  165. if __name__== "__main__":
  166. main()