Parcourir la source

Load test dataset later

Jérôme BUISINE il y a 3 ans
Parent
commit
5363b58ded
1 fichiers modifiés avec 14 ajouts et 9 suppressions
  1. 14 9
      train_lstm_weighted.py

+ 14 - 9
train_lstm_weighted.py

@@ -248,10 +248,6 @@ def main():
     X_train_all = build_input(X_train_all, p_seq_norm, p_chanels)
     y_train_all = final_df_train.loc[:, 0].astype('int')
 
-    X_test = final_df_test.loc[:, 1:].apply(lambda x: x.astype(str).str.split('::'))
-    X_test = build_input(X_test, p_seq_norm, p_chanels)
-    y_test = final_df_test.loc[:, 0].astype('int')
-
     input_shape = (X_train_all.shape[1], X_train_all.shape[2], X_train_all.shape[3], X_train_all.shape[4])
     
     
@@ -269,7 +265,6 @@ def main():
     checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=0, mode='max')
     callbacks_list = [checkpoint]
 
-    
     # check if backup already exists
     backups = sorted(os.listdir(model_backup_folder))
 
@@ -318,20 +313,30 @@ def main():
     # print(train_acc)
     y_train_predict = model.predict(X_train, batch_size=1, verbose=1)
     y_val_predict = model.predict(X_val, batch_size=1, verbose=1)
-    y_test_predict = model.predict(X_test, batch_size=1, verbose=1)
 
     y_train_predict = [ 1 if l > 0.5 else 0 for l in y_train_predict ]
     y_val_predict = [ 1 if l > 0.5 else 0 for l in y_val_predict ]
-    y_test_predict = [ 1 if l > 0.5 else 0 for l in y_test_predict ]
 
     auc_train = roc_auc_score(y_train, y_train_predict)
     auc_val = roc_auc_score(y_val, y_val_predict)
-    auc_test = roc_auc_score(y_test, y_test_predict)
 
     acc_train = accuracy_score(y_train, y_train_predict)
     acc_val = accuracy_score(y_val, y_val_predict)
-    acc_test = accuracy_score(y_test, y_test_predict)
+
+    # remove unused variables
+    del X_train
+    del y_train
     
+    X_test = final_df_test.loc[:, 1:].apply(lambda x: x.astype(str).str.split('::'))
+    X_test = build_input(X_test, p_seq_norm, p_chanels)
+    y_test = final_df_test.loc[:, 0].astype('int')
+
+    y_test_predict = model.predict(X_test, batch_size=1, verbose=1)
+    y_test_predict = [ 1 if l > 0.5 else 0 for l in y_test_predict ]
+
+    acc_test = accuracy_score(y_test, y_test_predict)
+    auc_test = roc_auc_score(y_test, y_test_predict)
+
     print('Train ACC:', acc_train)
     print('Train AUC', auc_train)
     print('Val ACC:', acc_val)