# svm_model_train.py (1.9 KB)

  import sys, os, getopt

  import numpy as np
  import pandas as pd
  import sklearn.svm as svm
  from sklearn.metrics import accuracy_score
  from sklearn.model_selection import GridSearchCV
  from sklearn.model_selection import train_test_split

  # sklearn.externals.joblib was removed in scikit-learn 0.23; prefer the
  # standalone joblib package and fall back to the vendored copy on old installs.
  try:
      import joblib
  except ImportError:
      from sklearn.externals import joblib
  # Directory that receives the trained model file; main() creates it if it
  # does not exist and writes "<output name>.joblib" into it.
  output_model_folder = './saved_models/'
  def get_best_model(X_train, y_train):
      """Grid-search an RBF-kernel SVC over C and return the winning estimator.

      The search tries every integer C from 1 to 19 with 5-fold
      cross-validation, scores candidates by accuracy, and logs progress at
      verbosity level 10.

      Args:
          X_train: Training features, passed straight to ``GridSearchCV.fit``.
          y_train: Training labels, passed straight to ``GridSearchCV.fit``.

      Returns:
          The ``best_estimator_`` of the fitted grid search.
      """
      param_grid = {'kernel': ['rbf'], 'C': np.arange(1, 20)}
      search = GridSearchCV(
          svm.SVC(gamma="scale"),
          param_grid,
          cv=5,
          scoring='accuracy',
          verbose=10,
      )
      # fit() returns the search object itself, so the result can be chained.
      return search.fit(X_train, y_train).best_estimator_
  def main(argv=None):
      """Train an SVM on a ';'-separated CSV file and save the best model.

      The CSV is read without a header row; column 0 is used as the label and
      all remaining columns as features. 60% of the rows feed the grid search
      in ``get_best_model`` and the held-out 40% are used to print an accuracy
      score. The fitted model is dumped to
      ``output_model_folder + <output> + '.joblib'``.

      Args:
          argv: Command-line arguments without the program name. Defaults to
              ``sys.argv[1:]``. Recognised options are ``-h/--help``,
              ``-d/--data <csv file>`` and ``-o/--output <model name>``.

      Raises:
          SystemExit: With status 2 when no arguments are given, an option is
              unknown or malformed, or ``--data`` / ``--output`` is missing;
              with a success status after printing the help text.
      """
      usage = 'python svm_model_train.py --data xxxx --output xxxx'
      if argv is None:
          argv = sys.argv[1:]

      if not argv:
          print('Run with default parameters...')
          print(usage)
          sys.exit(2)

      try:
          # "d:" and "o:" consume a value; "help" is a bare flag (no "=").
          opts, _ = getopt.getopt(argv, "hd:o:", ["help", "data=", "output="])
      except getopt.GetoptError:
          print(usage)
          sys.exit(2)

      p_data_file = None
      p_output = None
      for opt, value in opts:
          if opt in ("-h", "--help"):
              print(usage)
              sys.exit()
          elif opt in ("-d", "--data"):
              p_data_file = value
          elif opt in ("-o", "--output"):
              p_output = value

      # Fail fast: without this check a missing --output would only surface as
      # a NameError after the whole grid search had already been run.
      if not p_data_file or not p_output:
          print(usage)
          sys.exit(2)

      # exist_ok avoids the check-then-create race of os.path.exists().
      os.makedirs(output_model_folder, exist_ok=True)

      dataset = pd.read_csv(p_data_file, header=None, sep=";")
      # With header=None the columns are labelled 0..n-1, so positional .iloc
      # selects exactly what the removed .ix indexer used to select.
      y_dataset = dataset.iloc[:, 0]
      x_dataset = dataset.iloc[:, 1:]

      X_train, X_test, y_train, y_test = train_test_split(
          x_dataset, y_dataset, test_size=0.4, random_state=42
      )
      svm_model = get_best_model(X_train, y_train)

      y_pred = svm_model.predict(X_test)
      print("Accuracy found %s " % str(accuracy_score(y_test, y_pred)))

      joblib.dump(svm_model, output_model_folder + p_output + '.joblib')
  # Run the training pipeline only when executed as a script, not on import.
  if __name__ == "__main__":
      main()