# find_best_attributes_surrogate.py

# main imports
import os
import sys
import argparse
import pandas as pd
import logging
import datetime

# model imports
from sklearn.utils import shuffle
from sklearn.metrics import roc_auc_score

# modules and config imports
sys.path.insert(0, '') # trick to enable import of main folder module

import custom_config as cfg
import models as mdl

from optimization.ILSSurrogate import ILSSurrogate
from macop.solutions.BinarySolution import BinarySolution

from macop.operators.mutators.SimpleMutation import SimpleMutation
from macop.operators.mutators.SimpleBinaryMutation import SimpleBinaryMutation
from macop.operators.crossovers.SimpleCrossover import SimpleCrossover
from macop.operators.crossovers.RandomSplitCrossover import RandomSplitCrossover

from macop.operators.policies.UCBPolicy import UCBPolicy

from macop.callbacks.BasicCheckpoint import BasicCheckpoint
from macop.callbacks.UCBCheckpoint import UCBCheckpoint
# variables and parameters
models_list = cfg.models_names_list


# default validator: require at least 5 selected attributes
def validator(solution):

    if list(solution.data).count(1) < 5:
        return False

    return True
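
# Illustrative behaviour (assuming `solution.data` is a binary vector): a
# candidate with fewer than five selected attributes, e.g. [1, 0, 1, 0, 1, 0],
# is rejected, while any candidate with five or more 1s is accepted.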

def loadDataset(filename):

    ########################
    # 1. Get and prepare data
    ########################
    dataset_train = pd.read_csv(filename + '.train', header=None, sep=";")
    dataset_test = pd.read_csv(filename + '.test', header=None, sep=";")

    # default first shuffle of data
    dataset_train = shuffle(dataset_train)
    dataset_test = shuffle(dataset_test)

    # split each dataset by class label (first column): noisy (1), not noisy (0)
    noisy_df_train = dataset_train[dataset_train.iloc[:, 0] == 1]
    not_noisy_df_train = dataset_train[dataset_train.iloc[:, 0] == 0]

    noisy_df_test = dataset_test[dataset_test.iloc[:, 0] == 1]
    not_noisy_df_test = dataset_test[dataset_test.iloc[:, 0] == 0]

    # use of all data (both classes are kept whole)
    final_df_train = pd.concat([not_noisy_df_train, noisy_df_train])
    final_df_test = pd.concat([not_noisy_df_test, noisy_df_test])

    # shuffle data another time
    final_df_train = shuffle(final_df_train)
    final_df_test = shuffle(final_df_test)

    # split attributes (remaining columns) from labels (first column)
    x_dataset_train = final_df_train.iloc[:, 1:]
    x_dataset_test = final_df_test.iloc[:, 1:]

    y_dataset_train = final_df_train.iloc[:, 0]
    y_dataset_test = final_df_test.iloc[:, 0]

    return x_dataset_train, y_dataset_train, x_dataset_test, y_dataset_test
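
# Expected input layout (illustrative): each row of `<filename>.train` and
# `<filename>.test` is semicolon-separated, with the class label first:
#
#   label;attr_1;attr_2;...;attr_n      (label: 1 = noisy, 0 = not noisy)
#
# x_train, y_train, x_test, y_test = loadDataset('data/datasets/my_dataset')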

def main():

    parser = argparse.ArgumentParser(description="Train and find best filters to use for model")

    parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
    parser.add_argument('--choice', type=str, help='model choice from list of choices', choices=models_list, required=True)
    parser.add_argument('--length', type=int, help='max data length (needs to be specified for the evaluator)', required=True)
    parser.add_argument('--surrogate', type=str, help='surrogate .joblib model to approximate fitness', required=True)
    parser.add_argument('--solutions', type=str, help='solutions file required to train the surrogate model', required=True)
    parser.add_argument('--ils', type=int, help='total number of iterations for the ILS algorithm', required=True)
    parser.add_argument('--ls', type=int, help='number of iterations for the local search algorithm', required=True)

    args = parser.parse_args()

    p_data_file = args.data
    p_choice = args.choice
    p_length = args.length
    p_surrogate = args.surrogate
    p_solutions = args.solutions
    p_ils_iteration = args.ils
    p_ls_iteration = args.ls

    print(p_data_file)

    # load data from file
    x_train, y_train, x_test, y_test = loadDataset(p_data_file)

    # create `logs` folder if necessary
    if not os.path.exists(cfg.output_logs_folder):
        os.makedirs(cfg.output_logs_folder)

    _, data_file_name = os.path.split(p_data_file)
    logging.basicConfig(format='%(asctime)s %(message)s', filename=os.path.join(cfg.output_logs_folder, '{0}.log'.format(data_file_name)), level=logging.DEBUG)

    # init solution (`n` attributes)
    def init():
        return BinarySolution([], p_length).random(validator)

    # define evaluate function here (needs data information)
    def evaluate(solution, use_surrogate=True):
        start = datetime.datetime.now()

        # get indices of attributes to keep (filters selection from solution)
        indices = []

        for index, value in enumerate(solution.data):
            if value == 1:
                indices.append(index)

        # keep only selected filters from solution
        x_train_filters = x_train.iloc[:, indices]
        y_train_filters = y_train
        x_test_filters = x_test.iloc[:, indices]

        # TODO : use of GPU implementation of SVM
        model = mdl.get_trained_model(p_choice, x_train_filters, y_train_filters)

        y_test_model = model.predict(x_test_filters)
        test_roc_auc = roc_auc_score(y_test, y_test_model)

        end = datetime.datetime.now()
        diff = end - start
        minutes, seconds = divmod(diff.days * 86400 + diff.seconds, 60)
        print("Real evaluation took: {}m{}s, score found: {}".format(minutes, seconds, test_roc_auc))

        return test_roc_auc

    backup_model_folder = os.path.join(cfg.output_backup_folder, data_file_name)

    if not os.path.exists(backup_model_folder):
        os.makedirs(backup_model_folder)

    backup_file_path = os.path.join(backup_model_folder, data_file_name + '.csv')
    ucb_backup_file_path = os.path.join(backup_model_folder, data_file_name + '_ucbPolicy.csv')

    # prepare optimization algorithm
    operators = [SimpleBinaryMutation(), SimpleMutation(), SimpleCrossover(), RandomSplitCrossover()]
    policy = UCBPolicy(operators)

    # custom ILS which scores candidates with the surrogate model and falls
    # back on `evaluate` (the real, expensive evaluation) to retrain it
    algo = ILSSurrogate(_initalizer=init,
                        _evaluator=None, # by default no evaluator, as we will use the surrogate function
                        _operators=operators,
                        _policy=policy,
                        _validator=validator,
                        _surrogate_file_path=p_surrogate,
                        _solutions_file=p_solutions,
                        _ls_train_surrogate=1,
                        _real_evaluator=evaluate,
                        _maximise=True)

    algo.addCallback(BasicCheckpoint(_every=1, _filepath=backup_file_path))
    algo.addCallback(UCBCheckpoint(_every=1, _filepath=ucb_backup_file_path))

    bestSol = algo.run(p_ils_iteration, p_ls_iteration)

    # print best solution found
    print("Found ", bestSol)

    # save model information into .csv file
    if not os.path.exists(cfg.results_information_folder):
        os.makedirs(cfg.results_information_folder)

    filename_path = os.path.join(cfg.results_information_folder, cfg.optimization_attributes_result_filename)

    # count number of filters used: attributes come in pairs, and a filter is
    # counted as soon as at least one of its two attributes is selected
    filters_counter = 0

    for index, item in enumerate(bestSol.data):
        if index % 2 == 1:
            if item == 1 or bestSol.data[index - 1] == 1:
                filters_counter += 1

    line_info = p_data_file + ';' + str(p_ils_iteration) + ';' + str(p_ls_iteration) + ';' + str(bestSol.data) + ';' + str(list(bestSol.data).count(1)) + ';' + str(filters_counter) + ';' + str(bestSol.fitness())

    with open(filename_path, 'a') as f:
        f.write(line_info + '\n')

    print('Result saved into %s' % filename_path)


if __name__ == "__main__":
    main()
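
# Example invocation (all paths and the model name below are illustrative,
# --choice must be one of cfg.models_names_list):
#
#   python find_best_attributes_surrogate.py \
#       --data data/datasets/my_dataset \
#       --choice svm_model \
#       --length 30 \
#       --surrogate data/surrogate/my_surrogate.joblib \
#       --solutions data/solutions/my_solutions.csv \
#       --ils 1000 \
#       --ls 100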