train_model.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import numpy as np
  2. import pandas as pd
  3. import os, sys, argparse
  4. from sklearn import linear_model
  5. from sklearn import svm
  6. from sklearn.utils import shuffle
  7. import modules.config as cfg
  8. import modules.metrics as metrics
  9. from joblib import dump, load
  10. def get_model_choice(_model_name):
  11. """
  12. Bind choose model using String information
  13. """
  14. if _model_name == "SGD":
  15. clf = linear_model.SGDRegressor(max_iter=1000, tol=1e-3)
  16. if _model_name == "Ridge":
  17. clf = linear_model.Ridge(alpha=1.)
  18. if _model_name == "SVR":
  19. clf = svm.SVR()
  20. return clf
  21. def train(_data_file, _model_name):
  22. # prepare data
  23. dataset = pd.read_csv(_data_file, header=None, sep=";")
  24. dataset = shuffle(dataset)
  25. y = dataset.ix[:,0]
  26. X = dataset.ix[:,1:]
  27. clf = get_model_choice(_model_name)
  28. clf.fit(X, y)
  29. y_predicted = clf.predict(X)
  30. coeff = metrics.coefficient_of_determination(y, y_predicted)
  31. print("Predicted coefficient of determination for ", _model_name, " : ", coeff)
  32. # save the trained model, so check if saved folder exists
  33. if not os.path.exists(cfg.saved_models_folder):
  34. os.makedirs(cfg.saved_models_folder)
  35. # compute model filename
  36. model_filename = _data_file.split('/')[-1].replace(cfg.output_file_prefix, '').replace('.csv', '')
  37. model_filename = model_filename + '_' + _model_name + '.joblib'
  38. model_file_path = os.path.join(cfg.saved_models_folder, model_filename)
  39. print("Model will be save into `", model_file_path, '`')
  40. dump(clf, model_file_path)
  41. # save score into global_result.csv file
  42. with open(cfg.global_result_filepath, "a") as f:
  43. f.write(model_filename.replace('.joblib', '') + ';' + str(len(y)) + ';' + str(coeff) + ';\n')
  44. def main():
  45. parser = argparse.ArgumentParser(description="Train model and saved it")
  46. parser.add_argument('--data', type=str, help='Filename of dataset')
  47. parser.add_argument('--model', type=str, help='Kind of model expected', choices=cfg.kind_of_models)
  48. args = parser.parse_args()
  49. param_data_file = args.data
  50. param_model = args.model
  51. train(param_data_file, param_model)
  52. if __name__== "__main__":
  53. main()