浏览代码

fix with use of numpy for mean

Jérôme BUISINE 4 年之前
父节点
当前提交
38f647b7fb
共有 1 个文件被更改,包括 3 次插入0 次删除
  1. 3 0
      models.py

+ 3 - 0
models.py

@@ -1,4 +1,6 @@
 # models imports
+import numpy as np
+
 from sklearn.model_selection import GridSearchCV
 from sklearn.linear_model import LogisticRegression
 from sklearn.ensemble import RandomForestClassifier, VotingClassifier
@@ -68,6 +70,7 @@ def _get_best_gpu_model(X_train, y_train):
             svc.fit(X_train, y_train)
 
             score = cross_val_score(svc, X_train, y_train, cv=k_fold, n_jobs=-1)
+            score = np.mean(score)
 
             # keep track of best model
             if score > bestScore: