ensemble_model_train.py

from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
import sklearn.svm as svm
import joblib  # sklearn.externals.joblib was removed from scikit-learn; use the standalone joblib package
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
import sys, os, getopt

output_model_folder = './saved_models/'


def get_best_model(X_train, y_train):
    # grid-search an RBF SVM over C in [1, 20) and return the best cross-validated estimator;
    # probability=True is required so the soft-voting ensemble below can call predict_proba
    parameters = {'kernel': ['rbf'], 'C': np.arange(1, 20)}
    svc = svm.SVC(gamma="scale", probability=True)
    clf = GridSearchCV(svc, parameters, cv=5, scoring='accuracy', verbose=10)
    clf.fit(X_train, y_train)
    model = clf.best_estimator_
    return model


def main():
    if len(sys.argv) <= 1:
        print('Run with default parameters...')
        print('python ensemble_model_train.py --data xxxx --output xxxx')
        sys.exit(2)

    try:
        # -d/--data and -o/--output both take a value; -h/--help does not
        opts, args = getopt.getopt(sys.argv[1:], "hd:o:", ["help", "data=", "output="])
    except getopt.GetoptError:
        # print help information and exit
        print('python ensemble_model_train.py --data xxxx --output xxxx')
        sys.exit(2)

    for o, a in opts:
        if o in ("-h", "--help"):
            print('python ensemble_model_train.py --data xxxx --output xxxx')
            sys.exit()
        elif o in ("-d", "--data"):
            p_data_file = a
        elif o in ("-o", "--output"):
            p_output = a
        else:
            assert False, "unhandled option"

    if not os.path.exists(output_model_folder):
        os.makedirs(output_model_folder)

    # get and split data: first column is the label, remaining columns are features
    dataset = pd.read_csv(p_data_file, header=None, sep=";")
    y_dataset = dataset.iloc[:, 0]
    x_dataset = dataset.iloc[:, 1:]
    X_train, X_test, y_train, y_test = train_test_split(x_dataset, y_dataset, test_size=0.4, random_state=42)

    # base estimators: grid-searched SVM, logistic regression and random forest
    svm_model = get_best_model(X_train, y_train)
    lr_model = LogisticRegression(solver='lbfgs', multi_class='multinomial', random_state=1)
    rf_model = RandomForestClassifier(n_estimators=50, random_state=1)

    # soft-voting ensemble, with the SVM weighted twice as heavily as the other two
    ensemble_model = VotingClassifier(
        estimators=[('svm', svm_model), ('lr', lr_model), ('rf', rf_model)],
        voting='soft', weights=[2, 1, 1],
        flatten_transform=True)

    ensemble_model.fit(X_train, y_train)
    y_pred = ensemble_model.predict(X_test)
    print("Accuracy found %s" % str(accuracy_score(y_test, y_pred)))

    joblib.dump(ensemble_model, output_model_folder + p_output + '.joblib')


if __name__ == "__main__":
    main()
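
Example usage (illustrative, not part of the script above): once training finishes, the serialized ensemble can be loaded back with joblib and applied to new samples laid out like the training CSV (label in the first column, features after). The model name my_model and the file new_samples.csv are hypothetical placeholders for the --output argument and a new data file.

import joblib
import pandas as pd

# load the ensemble saved by ensemble_model_train.py (run with a hypothetical --output my_model)
model = joblib.load('./saved_models/my_model.joblib')

# score new samples: drop the label column, keep the feature columns
new_data = pd.read_csv('new_samples.csv', header=None, sep=';')
print(model.predict(new_data.iloc[:, 1:]))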