Parcourir la source

update lstm model architecture

Jérôme BUISINE il y a 3 ans
Parent
commit
9f445fa2a5
1 fichiers modifiés avec 9 ajouts et 9 suppressions
  1. 9 9
      train_lstm_weighted.py

+ 9 - 9
train_lstm_weighted.py

@@ -103,30 +103,30 @@ def create_model(_input_shape):
 
     model.add(ConvLSTM2D(filters=100, kernel_size=(3, 3),
                    input_shape=_input_shape,
-                   dropout=0.5,
-                   recurrent_dropout=0.5,
+                   dropout=0.4,
+                   #recurrent_dropout=0.5,
                    padding='same', return_sequences=True))
     model.add(BatchNormalization())
 
     model.add(ConvLSTM2D(filters=50, kernel_size=(3, 3),
-                    dropout=0.5,
-                    recurrent_dropout=0.5,
+                    dropout=0.4,
+                    #recurrent_dropout=0.5,
                     padding='same', return_sequences=True))
     model.add(BatchNormalization())
-    model.add(Dropout(0.5))
+    model.add(Dropout(0.4))
 
     model.add(Conv3D(filters=20, kernel_size=(3, 3, 3),
                 activation='sigmoid',
                 padding='same', data_format='channels_last'))
-    model.add(Dropout(0.5))
+    model.add(Dropout(0.4))
 
     model.add(Flatten())
     model.add(Dense(512, activation='sigmoid'))
-    model.add(Dropout(0.5))
+    model.add(Dropout(0.4))
     model.add(Dense(128, activation='sigmoid'))
-    model.add(Dropout(0.5))
+    model.add(Dropout(0.4))
     model.add(Dense(1, activation='sigmoid'))
-    model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])
+    model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
     print ('Compiling...')
     # model.compile(loss='binary_crossentropy',