Parcourir la source

model prediction updated for LSTM

Jérôme BUISINE il y a 3 ans
Parent
commit
153f06bb97
1 fichiers modifiés avec 6 ajouts et 5 suppressions
  1. 6 5
      train_lstm_weighted.py

+ 6 - 5
train_lstm_weighted.py

@@ -315,12 +315,13 @@ def main():
     # train_score, train_acc = model.evaluate(X_train, y_train, batch_size=1)
 
     # print(train_acc)
-    y_train_predict = model.predict_classes(X_train)
-    y_val_predict = model.predict_classes(X_val)
-    y_test_predict = model.predict_classes(X_test)
+    y_train_predict = model.predict(X_train, batch_size=1, verbose=1)
+    y_val_predict = model.predict(X_val, batch_size=1, verbose=1)
+    y_test_predict = model.predict(X_test, batch_size=1, verbose=1)
 
-    print(y_train_predict)
-    print(y_test_predict)
+    y_train_predict = [ 1 if l > 0.5 else 0 for l in y_train_predict ]
+    y_val_predict = [ 1 if l > 0.5 else 0 for l in y_val_predict ]
+    y_test_predict = [ 1 if l > 0.5 else 0 for l in y_test_predict ]
 
     auc_train = roc_auc_score(y_train, y_train_predict)
     auc_val = roc_auc_score(y_val, y_val_predict)