train_model.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # main imports
  2. import numpy as np
  3. import pandas as pd
  4. import os, sys, argparse
  5. # model imports
  6. from sklearn import linear_model
  7. from sklearn import svm
  8. from sklearn.utils import shuffle
  9. from joblib import dump, load
  10. # modules and config imports
  11. sys.path.insert(0, '') # trick to enable import of main folder module
  12. import custom_config as cfg
  13. from ipfml import metrics
  14. def get_model_choice(_model_name):
  15. """
  16. Bind choose model using String information
  17. """
  18. if _model_name == "SGD":
  19. clf = linear_model.SGDRegressor(max_iter=1000, tol=1e-3)
  20. if _model_name == "Ridge":
  21. clf = linear_model.Ridge(alpha=1.)
  22. if _model_name == "SVR":
  23. clf = svm.SVR()
  24. return clf
  25. def train(_data_file, _model_name):
  26. # prepare data
  27. dataset = pd.read_csv(_data_file, header=None, sep=";")
  28. dataset = shuffle(dataset)
  29. y = dataset.ix[:,0]
  30. X = dataset.ix[:,1:]
  31. clf = get_model_choice(_model_name)
  32. clf.fit(X, y)
  33. y_predicted = clf.predict(X)
  34. coeff = metrics.coefficient_of_determination(y, y_predicted)
  35. print("Predicted coefficient of determination for ", _model_name, " : ", coeff)
  36. # save the trained model, so check if saved folder exists
  37. if not os.path.exists(cfg.saved_models_folder):
  38. os.makedirs(cfg.saved_models_folder)
  39. # compute model filename_colum,n
  40. model_filename = _data_file.split('/')[-1].replace(cfg.output_file_prefix, '').replace('.csv', '')
  41. model_filename = model_filename + '_' + _model_name + '.joblib'
  42. model_file_path = os.path.join(cfg.saved_models_folder, model_filename)
  43. print("Model will be save into `", model_file_path, '`')
  44. dump(clf, model_file_path)
  45. if not os.path.exists(cfg.results_information_folder):
  46. os.makedirs(cfg.results_information_folder)
  47. # save score into global_result.csv file
  48. with open(cfg.global_result_filepath, "a") as f:
  49. f.write(model_filename.replace('.joblib', '') + ';' + str(len(y)) + ';' + str(coeff) + ';\n')
  50. def main():
  51. parser = argparse.ArgumentParser(description="Train model and saved it")
  52. parser.add_argument('--data', type=str, help='Filename of dataset')
  53. parser.add_argument('--model', type=str, help='Kind of model expected', choices=cfg.kind_of_models)
  54. args = parser.parse_args()
  55. param_data_file = args.data
  56. param_model = args.model
  57. train(param_data_file, param_model)
  58. if __name__== "__main__":
  59. main()