Parcourir la source

Update of pure CNN models

jbuisine il y a 5 ans
Parent
commit
cb31d045d6
2 fichiers modifiés avec 37 ajouts et 10 suppressions
  1. 22 4
      classification_cnn_keras.py
  2. 15 6
      classification_cnn_keras_cross_validation.py

+ 22 - 4
classification_cnn_keras.py

@@ -29,7 +29,7 @@ import json
 from keras.preprocessing.image import ImageDataGenerator
 from keras.models import Sequential
 from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
-from keras.layers import Activation, Dropout, Flatten, Dense
+from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
 from keras import backend as K
 from keras.utils import plot_model
 
@@ -77,12 +77,30 @@ def generate_model():
     model.add(MaxPooling2D(pool_size=(2, 2)))
 
     model.add(Flatten())
-    model.add(Dense(60))
+
+    model.add(Dense(140))
+    model.add(Activation('relu'))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.3))
+
+    model.add(Dense(120))
+    model.add(Activation('relu'))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.3))
+
+    model.add(Dense(80))
     model.add(Activation('relu'))
-    model.add(Dropout(0.4))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.2))
+
+    model.add(Dense(40))
+    model.add(Activation('relu'))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.2))
 
-    model.add(Dense(30))
+    model.add(Dense(20))
     model.add(Activation('relu'))
+    model.add(BatchNormalization())
     model.add(Dropout(0.2))
 
     model.add(Dense(1))

+ 15 - 6
classification_cnn_keras_cross_validation.py

@@ -80,21 +80,30 @@ def generate_model():
 
     model.add(Flatten())
 
-    model.add(Dense(256))
+    model.add(Dense(140))
     model.add(Activation('relu'))
-    model.add(Dropout(0.2))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.3))
+
+    model.add(Dense(120))
+    model.add(Activation('relu'))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.3))
 
-    model.add(Dense(128))
+    model.add(Dense(80))
     model.add(Activation('relu'))
+    model.add(BatchNormalization())
     model.add(Dropout(0.2))
 
-    model.add(Dense(64))
+    model.add(Dense(40))
     model.add(Activation('relu'))
+    model.add(BatchNormalization())
     model.add(Dropout(0.2))
 
-    model.add(Dense(32))
+    model.add(Dense(20))
     model.add(Activation('relu'))
-    model.add(Dropout(0.05))
+    model.add(BatchNormalization())
+    model.add(Dropout(0.2))
 
     model.add(Dense(1))
     model.add(Activation('sigmoid'))