瀏覽代碼

update model params

Jérôme BUISINE 4 年之前
父節點
當前提交
a9fa043e7c
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      train_lstm_weighted.py

+ 1 - 1
train_lstm_weighted.py

@@ -200,7 +200,7 @@ def main():
     model.summary()
 
     print("Fitting model with custom class_weight", class_weight)
-    history = model.fit(X_train, y_train, batch_size=16, epochs=3, validation_split = 0.30, verbose=1, shuffle=True, class_weight=class_weight)
+    history = model.fit(X_train, y_train, batch_size=64, epochs=30, validation_split = 0.30, verbose=1, shuffle=True, class_weight=class_weight)
 
     # list all data in history
     # print(history.history.keys())