Parcourir la source

update model params

Jérôme BUISINE il y a 3 ans
Parent
commit
a9fa043e7c
1 fichiers modifiés avec 1 ajouts et 1 suppressions
  1. 1 1
      train_lstm_weighted.py

+ 1 - 1
train_lstm_weighted.py

@@ -200,7 +200,7 @@ def main():
     model.summary()
 
     print("Fitting model with custom class_weight", class_weight)
-    history = model.fit(X_train, y_train, batch_size=16, epochs=3, validation_split = 0.30, verbose=1, shuffle=True, class_weight=class_weight)
+    history = model.fit(X_train, y_train, batch_size=64, epochs=30, validation_split = 0.30, verbose=1, shuffle=True, class_weight=class_weight)
 
     # list all data in history
     # print(history.history.keys())