Parcourir la source

add of dropout for LSTM 2D layers

Jérôme BUISINE il y a 4 ans
Parent
commit
df4acbba4a
1 fichiers modifiés avec 7 ajouts et 2 suppressions
  1. 7 2
      train_lstm_weighted.py

+ 7 - 2
train_lstm_weighted.py

@@ -103,17 +103,22 @@ def create_model(_input_shape):
 
     model.add(ConvLSTM2D(filters=100, kernel_size=(3, 3),
                    input_shape=_input_shape,
+                   dropout=0.4,
+                   recurrent_dropout=0.4,
                    padding='same', return_sequences=True))
     model.add(BatchNormalization())
-    model.add(Dropout(0.4))
 
     model.add(ConvLSTM2D(filters=50, kernel_size=(3, 3),
+                    dropout=0.4,
+                    recurrent_dropout=0.4,
                     padding='same', return_sequences=True))
     model.add(BatchNormalization())
     model.add(Dropout(0.4))
 
     model.add(Conv3D(filters=20, kernel_size=(3, 3, 3),
                 activation='sigmoid',
+                dropout=0.4,
+                recurrent_dropout=0.4,
                 padding='same', data_format='channels_last'))
     model.add(Dropout(0.4))
 
@@ -200,7 +205,7 @@ def main():
     model.summary()
 
     print("Fitting model with custom class_weight", class_weight)
-    history = model.fit(X_train, y_train, batch_size=64, epochs=30, validation_split = 0.30, verbose=1, shuffle=True, class_weight=class_weight)
+    history = model.fit(X_train, y_train, batch_size=64, epochs=1, validation_split = 0.30, verbose=1, shuffle=True, class_weight=class_weight)
 
     # list all data in history
     # print(history.history.keys())