Parcourir la source

fix issue when using metric

Jérôme BUISINE il y a 3 ans
Parent
commit
82da620360
1 fichiers modifiés avec 7 ajouts et 7 suppressions
  1. 7 7
      train_model.py

+ 7 - 7
train_model.py

@@ -255,15 +255,15 @@ def main():
     # prepare train and validation dataset
     X_train, X_val, y_train, y_val = train_test_split(x_data_train, y_dataset_train, test_size=p_val_size, shuffle=False)
 
-    y_train = to_categorical(y_train)
-    y_val = to_categorical(y_val)
-    y_test = to_categorical(y_dataset_test)
+    y_train_cat = to_categorical(y_train)
+    y_val_cat = to_categorical(y_val)
+    y_test_cat = to_categorical(y_dataset_test)
 
     print('-----------------------------')
     print("-- Fitting model with custom class_weight", class_weight)
     print('-----------------------------')
-    model.fit(X_train, y_train, 
-        validation_data=(X_val, y_val), 
+    model.fit(X_train, y_train_cat, 
+        validation_data=(X_val, y_val_cat), 
         initial_epoch=initial_epoch, 
         epochs=p_epochs, 
         batch_size=p_batch_size, 
@@ -293,11 +293,11 @@ def main():
 
     acc_train_score = accuracy_score(y_train, y_train_prediction)
     acc_val_score = accuracy_score(y_val, y_val_prediction)
-    acc_test_score = accuracy_score(y_test, y_test_prediction)
+    acc_test_score = accuracy_score(y_dataset_test, y_test_prediction)
 
     roc_train_score = roc_auc_score(y_train, y_train_prediction)
     roc_val_score = roc_auc_score(y_val, y_val_prediction)
-    roc_test_score = roc_auc_score(y_test, y_val_prediction)
+    roc_test_score = roc_auc_score(y_dataset_test, y_val_prediction)
 
     # save model performance
     if not os.path.exists(cfg.output_results_folder):