models.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # models imports
  2. from sklearn.model_selection import GridSearchCV
  3. from sklearn.linear_model import LogisticRegression
  4. from sklearn.ensemble import RandomForestClassifier, VotingClassifier
  5. from sklearn.neighbors import KNeighborsClassifier
  6. from sklearn.ensemble import GradientBoostingClassifier
  7. from sklearn.feature_selection import RFECV
  8. import sklearn.svm as svm
  9. def _get_best_model(X_train, y_train):
  10. Cs = [0.001, 0.01, 0.1, 1, 10, 100, 1000]
  11. gammas = [0.001, 0.01, 0.1, 1, 5, 10, 100]
  12. param_grid = {'kernel':['rbf'], 'C': Cs, 'gamma' : gammas}
  13. svc = svm.SVC(probability=True)
  14. clf = GridSearchCV(svc, param_grid, cv=10, scoring='accuracy', verbose=0)
  15. clf.fit(X_train, y_train)
  16. model = clf.best_estimator_
  17. return model
  18. def svm_model(X_train, y_train):
  19. return _get_best_model(X_train, y_train)
  20. def rfe_svm_model(X_train, y_train, n_components=1):
  21. Cs = [0.001, 0.01, 0.1, 1, 10, 100, 1000]
  22. gammas = [0.001, 0.01, 0.1, 1, 5, 10, 100]
  23. param_grid = [{'estimator__C': Cs, 'estimator__gamma' : gammas}]
  24. estimator = svm.SVC(kernel="linear")
  25. selector = RFECV(estimator, step=1, cv=5, verbose=0)
  26. clf = GridSearchCV(selector, param_grid, cv=10, verbose=1)
  27. clf.fit(X_train, y_train)
  28. return clf.best_estimator_
  29. def get_trained_model(choice, X_train, y_train):
  30. if choice == 'svm_model':
  31. return svm_model(X_train, y_train)
  32. if choice == 'ensemble_model':
  33. return ensemble_model(X_train, y_train)
  34. if choice == 'ensemble_model_v2':
  35. return ensemble_model_v2(X_train, y_train)
  36. if choice == 'rfe_svm_model':
  37. return rfe_svm_model(X_train, y_train)