Parcourir la source

change predict input type

Jérôme BUISINE il y a 3 ans
Parent
commit
25b939efe3
1 fichiers modifiés avec 6 ajouts et 4 suppressions
  1. 6 4
      train_model.py

+ 6 - 4
train_model.py

@@ -238,7 +238,8 @@ def main():
 
     if len(backups) > 0:
         last_backup_file = backups[-1]
-        model = load_model(last_backup_file)
+        last_backup_file_path = os.path.join(model_backup_folder, last_backup_file)
+        model = load_model(last_backup_file_path)
 
         # get initial epoch
         initial_epoch = int(last_backup_file.split('_')[-1].replace('.h5', ''))
@@ -281,12 +282,13 @@ def main():
     model.save(model_output_path)
 
     # Get results obtained from model
-    y_train_prediction = model.predict(X_train)
-    y_val_prediction = model.predict(X_val)
-    y_test_prediction = model.predict(x_dataset_test)
+    y_train_prediction = model.predict(np.array(X_train))
+    y_val_prediction = model.predict(np.array(X_val))
+    y_test_prediction = model.predict(np.array(x_dataset_test))
 
     y_train_prediction = np.argmax(y_train_prediction, axis=1)
     y_val_prediction = np.argmax(y_val_prediction, axis=1)
+    y_test_prediction = np.argmax(y_test_prediction, axis=1)
 
     acc_train_score = accuracy_score(y_train, y_train_prediction)
     acc_val_score = accuracy_score(y_val, y_val_prediction)