|
@@ -53,22 +53,13 @@ def train(_data_file, _model_name):
|
|
|
model.summary()
|
|
|
|
|
|
# Set expected metrics
|
|
|
- # TODO : add coefficients of determination as metric
|
|
|
+ # TODO : add coefficients of determination as metric ? Or always use MSE/MAE
|
|
|
model.compile(loss='mse', optimizer='adam', metrics=['mse', 'mae'])
|
|
|
- history = model.fit(X_train, y_train, epochs=1, batch_size=50, verbose=1, validation_split=0.2)
|
|
|
+ history = model.fit(X_train, y_train, epochs=150, batch_size=50, verbose=1, validation_split=0.2)
|
|
|
|
|
|
# Save model
|
|
|
print(history.history.keys())
|
|
|
|
|
|
- # TODO : Save plot info and increase figure size
|
|
|
- plt.plot(history.history['loss'])
|
|
|
- plt.plot(history.history['val_loss'])
|
|
|
- plt.title('model loss', fontsize=20)
|
|
|
- plt.ylabel('loss', fontsize=16)
|
|
|
- plt.xlabel('epoch', fontsize=16)
|
|
|
- plt.legend(['train', 'validation'], loc='upper left', fontsize=16)
|
|
|
- #plt.show()
|
|
|
-
|
|
|
y_predicted = model.predict(X_test)
|
|
|
len_shape, _ = y_predicted.shape
|
|
|
y_predicted = y_predicted.reshape(len_shape)
|
|
@@ -88,12 +79,22 @@ def train(_data_file, _model_name):
|
|
|
|
|
|
model.save_weights(model_output_path.replace('.json', '.h5'))
|
|
|
|
|
|
- # TODO : Save test score into .csv files
|
|
|
# save score into global_result.csv file
|
|
|
with open(cfg.global_result_filepath, "a") as f:
|
|
|
f.write(_model_name + ';' + str(len(y)) + ';' + str(coeff[0]) + ';\n')
|
|
|
|
|
|
|
|
|
+ # Save plot info using model name
|
|
|
+ plt.figure(figsize=(30, 22))
|
|
|
+ plt.plot(history.history['loss'])
|
|
|
+ plt.plot(history.history['val_loss'])
|
|
|
+ plt.title('model loss', fontsize=20)
|
|
|
+ plt.ylabel('loss', fontsize=16)
|
|
|
+ plt.xlabel('epoch', fontsize=16)
|
|
|
+ plt.legend(['train', 'validation'], loc='upper left', fontsize=16)
|
|
|
+ plt.savefig(model_output_path.replace('.json', '.png'))
|
|
|
+
|
|
|
+
|
|
|
def main():
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="Train model and saved it")
|