浏览代码

use of random forest

Jérôme BUISINE 3 年之前
父节点
当前提交
38ca51bff9
共有 1 个文件被更改,包括 3 次插入3 次删除
  1. 3 3
      find_best_attributes_surrogate.py

+ 3 - 3
find_best_attributes_surrogate.py

@@ -175,9 +175,9 @@ def main():
             y_train_filters = self._data['y_train']
             x_test_filters = self._data['x_test'].iloc[:, indices]
             
-            model = _get_best_model(x_train_filters, y_train_filters)
-            # model = RandomForestClassifier(n_estimators=500, class_weight='balanced', bootstrap=True, max_samples=0.75, n_jobs=-1)
-            # model = model.fit(x_train_filters, y_train_filters)
+            # model = _get_best_model(x_train_filters, y_train_filters)
+            model = RandomForestClassifier(n_estimators=500, class_weight='balanced', bootstrap=True, max_samples=0.75, n_jobs=-1)
+            model = model.fit(x_train_filters, y_train_filters)
             
             y_test_model = model.predict(x_test_filters)
             test_roc_auc = roc_auc_score(self._data['y_test'], y_test_model)