Browse Source

Update of scorer used

Jérôme BUISINE 10 months ago
parent
commit
da6c1aea12
2 changed files with 10 additions and 10 deletions
  1. 1 1
      data_processing/generateAndTrain_maxwell_custom.sh
  2. 9 9
      models.py

+ 1 - 1
data_processing/generateAndTrain_maxwell_custom.sh

@@ -47,7 +47,7 @@ for nb_zones in {10,11,12}; do
         MODEL_NAME="${model}_N${size}_B${start}_E${end}_nb_zones_${nb_zones}_${feature}_${mode}_${data}"
         CUSTOM_MIN_MAX_FILENAME="N${size}_B${start}_E${end}_nb_zones_${nb_zones}_${feature}_${mode}_${data}_min_max"
 
-        echo $FILENAME
+        # echo $FILENAME
 
         # only compute if necessary (perhaps server will fall.. Just in case)
         if grep -q "${MODEL_NAME}" "${result_file_path}"; then

+ 9 - 9
models.py

@@ -5,8 +5,14 @@ from sklearn.ensemble import RandomForestClassifier, VotingClassifier
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.ensemble import GradientBoostingClassifier
 from sklearn.feature_selection import RFECV
+from sklearn.metrics import roc_auc_score
 import sklearn.svm as svm
 
+def _roc_auc_scorer(estimator, X, y):
+    
+    y_pred = estimator.predict(X)
+    
+    return roc_auc_score(y, y_pred)
 
 def _get_best_model(X_train, y_train):
 
@@ -15,7 +21,7 @@ def _get_best_model(X_train, y_train):
     param_grid = {'kernel':['rbf'], 'C': Cs, 'gamma' : gammas}
 
     svc = svm.SVC(probability=True)
-    clf = GridSearchCV(svc, param_grid, cv=10, scoring='accuracy', verbose=0)
+    clf = GridSearchCV(svc, param_grid, cv=10, scoring=_roc_auc_scorer, verbose=0)
 
     clf.fit(X_train, y_train)
 
@@ -34,8 +40,8 @@ def rfe_svm_model(X_train, y_train, n_components=1):
     param_grid = [{'estimator__C': Cs, 'estimator__gamma' : gammas}]
 
     estimator = svm.SVC(kernel="linear")
-    selector = RFECV(estimator, step=1, cv=5, verbose=0)
-    clf = GridSearchCV(selector, param_grid, cv=10, verbose=1)
+    selector = RFECV(estimator, step=1, cv=4, verbose=1)
+    clf = GridSearchCV(selector, param_grid, cv=5, verbose=1, scoring=_roc_auc_scorer)
     clf.fit(X_train, y_train)
 
     return clf.best_estimator_
@@ -46,11 +52,5 @@ def get_trained_model(choice, X_train, y_train):
     if choice == 'svm_model':
         return svm_model(X_train, y_train)
 
-    if choice == 'ensemble_model':
-        return ensemble_model(X_train, y_train)
-
-    if choice == 'ensemble_model_v2':
-        return ensemble_model_v2(X_train, y_train)
-
     if choice == 'rfe_svm_model':
         return rfe_svm_model(X_train, y_train)