Browse Source

use of random forest classifier now

Jérôme BUISINE 1 year ago
parent
commit
3a3a55bbd9
1 changed files with 6 additions and 6 deletions
  1. 6 6
      find_best_attributes_surrogate.py

+ 6 - 6
find_best_attributes_surrogate.py

@@ -154,7 +154,7 @@ def main():
         return BinarySolution.random(p_length, validator)
 
 
-    class SVMEvaluator(Evaluator):
+    class RandomForestEvaluator(Evaluator):
 
         # define evaluate function here (need of data information)
         def compute(self, solution):
@@ -172,9 +172,9 @@ def main():
             y_train_filters = self._data['y_train']
             x_test_filters = self._data['x_test'].iloc[:, indices]
             
-            model = _get_best_model(x_train_filters, y_train_filters)
-            #model = RandomForestClassifier(n_estimators=10)
-            #model = model.fit(x_train_filters, y_train_filters)
+            # model = _get_best_model(x_train_filters, y_train_filters)
+            model = RandomForestClassifier(n_estimators=300, class_weight='balanced', n_jobs=-1)
+            model = model.fit(x_train_filters, y_train_filters)
             
             y_test_model = model.predict(x_test_filters)
             test_roc_auc = roc_auc_score(self._data['y_test'], y_test_model)
@@ -183,7 +183,7 @@ def main():
 
             diff = end - start
 
-            #print("Real evaluation took: {}, score found: {}".format(divmod(diff.days * 86400 + diff.seconds, 60), test_roc_auc))
+            print("Real evaluation took: {}, score found: {}".format(divmod(diff.days * 86400 + diff.seconds, 60), test_roc_auc))
 
             return test_roc_auc
 
@@ -216,7 +216,7 @@ def main():
 
     # custom ILS for surrogate use
     algo = ILSPopSurrogate(initalizer=init, 
-                        evaluator=SVMEvaluator(data={'x_train': x_train, 'y_train': y_train, 'x_test': x_test, 'y_test': y_test}), # same evaluator by default, as we will use the surrogate function
+                        evaluator=RandomForestEvaluator(data={'x_train': x_train, 'y_train': y_train, 'x_test': x_test, 'y_test': y_test}), # same evaluator by default, as we will use the surrogate function
                         operators=operators, 
                         policy=policy, 
                         validator=validator,