Parcourir la source

svm can now be used for selector

Jérôme BUISINE il y a 4 ans
Parent
commit
a2fb893050
1 fichiers modifiés avec 23 ajouts et 8 suppressions
  1. 23 8
      find_best_attributes_from.py

+ 23 - 8
find_best_attributes_from.py

@@ -84,29 +84,40 @@ def main():
 
 
     parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
     parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .test)', required=True)
     parser.add_argument('--choice', type=str, help='model choice from list of choices', choices=models_list, required=True)
     parser.add_argument('--choice', type=str, help='model choice from list of choices', choices=models_list, required=True)
+    parser.add_argument('--selector', type=str, help='kind of model to use for selecting', choices=['svm', 'tree'], default='tree')
     parser.add_argument('--length', type=str, help='max data length (need to be specify for evaluator)', required=True)
     parser.add_argument('--length', type=str, help='max data length (need to be specify for evaluator)', required=True)
+    parser.add_argument('--output', type=str, help='output name expected for model results', required=True)
 
 
     args = parser.parse_args()
     args = parser.parse_args()
 
 
     p_data_file = args.data
     p_data_file = args.data
     p_choice    = args.choice
     p_choice    = args.choice
+    p_selector  = args.selector
     p_length    = args.length
     p_length    = args.length
+    p_output    = args.output
 
 
     print(p_data_file)
     print(p_data_file)
 
 
     # load data from file
     # load data from file
     x_train, y_train, x_test, y_test = loadDataset(p_data_file)
     x_train, y_train, x_test, y_test = loadDataset(p_data_file)
 
 
-    
-    # clf = ExtraTreesClassifier(n_estimators=100)
-    # clf = clf.fit(x_train, y_train)
-    # print(clf.feature_importances_)
+    for i in (np.arange(11) + 5):
 
 
+        model_to_fit = None
+        # use of svm here to fit well model
+        if p_selector == 'tree':
+            model_to_fit = ExtraTreesClassifier(n_estimators=100)
 
 
-    for i in (np.arange(11) + 5):
+        elif p_selector == 'svm':
+            Cs = [0.001, 0.01, 0.1, 1, 10, 100, 1000]
+            gammas = [0.001, 0.01, 0.1, 5, 10, 100]
+            param_grid = {'kernel':['rbf'], 'C': Cs, 'gamma' : gammas}
 
 
+            svc = svm.SVC(probability=True, class_weight='balanced')
+            #clf = GridSearchCV(svc, param_grid, cv=5, verbose=1, scoring=my_accuracy_scorer, n_jobs=-1)
+            model_to_fit = GridSearchCV(svc, param_grid, cv=5, verbose=1, scoring='roc_auc', n_jobs=-1)
 
 
-        model = SelectFromModel(ExtraTreesClassifier(n_estimators=100), max_features=i)
+        model = SelectFromModel(model_to_fit, max_features=i)
         selector = model.fit(x_train, y_train)
         selector = model.fit(x_train, y_train)
 
 
         binary_selection = [ 0 if x < selector.threshold_ else 1 for x in selector.estimator_.feature_importances_ ]
         binary_selection = [ 0 if x < selector.threshold_ else 1 for x in selector.estimator_.feature_importances_ ]
@@ -120,8 +131,12 @@ def main():
         y_test_model = svm_model.predict(X_test_new)
         y_test_model = svm_model.predict(X_test_new)
         test_roc_auc = roc_auc_score(y_test, y_test_model)
         test_roc_auc = roc_auc_score(y_test, y_test_model)
         
         
-        with open('data/results/selectFromModel.csv', 'a') as f:
-            line = str(len(binary_selection)) + ';'
+        if not os.path.exists(cfg.output_results_folder):
+            os.makedirs(cfg.output_results_folder)
+
+        # save model results into file
+        with open(os.path.join(cfg.output_results_folder, p_output), 'a') as f:
+            line = str(i) + ';'
             line += str(test_roc_auc) + ';'
             line += str(test_roc_auc) + ';'
             
             
             for index, b in enumerate(binary_selection):
             for index, b in enumerate(binary_selection):