Parcourir la source

list of prediction

Jérôme BUISINE il y a 3 ans
Parent
commit
91339204bd
1 fichiers modifiés avec 3 ajouts et 3 suppressions
  1. 3 3
      train_model.py

+ 3 - 3
train_model.py

@@ -282,9 +282,9 @@ def main():
     model.save(model_output_path)
 
     # Get results obtained from model
-    y_train_prediction = model.predict(tf.convert_to_tensor(np.asarray(X_train)))
-    y_val_prediction = model.predict(tf.convert_to_tensor(np.asarray(X_val)))
-    y_test_prediction = model.predict(tf.convert_to_tensor(np.asarray(x_dataset_test)))
+    y_train_prediction = model.predict(list(X_train))
+    y_val_prediction = model.predict(list(X_val))
+    y_test_prediction = model.predict(list(x_dataset_test))
 
     y_train_prediction = np.argmax(y_train_prediction, axis=1)
     y_val_prediction = np.argmax(y_val_prediction, axis=1)