浏览代码

predict test input update

Jérôme BUISINE 3 年之前
父节点
当前提交
74a168da04
共有 1 个文件被更改,包括 4 次插入4 次删除
  1. 4 4
      train_model.py

+ 4 - 4
train_model.py

@@ -281,11 +281,11 @@ def main():
     model_output_path = os.path.join(cfg.output_models, p_output + '.h5')
     model.save(model_output_path)
 
-    print('Input prediction shape', X_train.shape)
+    print('Begin of prediction score on the whole dataset:')
     # Get results obtained from model
-    y_train_prediction = model.predict(list(X_train))
-    y_val_prediction = model.predict(list(X_val))
-    y_test_prediction = model.predict(list(x_dataset_test))
+    y_train_prediction = model.predict(X_train, verbose=1)
+    y_val_prediction = model.predict(X_val, verbose=1)
+    y_test_prediction = model.predict(x_data_test, verbose=1)
 
     y_train_prediction = np.argmax(y_train_prediction, axis=1)
     y_val_prediction = np.argmax(y_val_prediction, axis=1)