Parcourir la source

cnn 3D model with less params

Jérôme BUISINE il y a 3 ans
Parent
commit
ec54aa8403
1 fichiers modifiés avec 14 ajouts et 11 suppressions
  1. 14 11
      cnn_models.py

+ 14 - 11
cnn_models.py

@@ -6,6 +6,7 @@ import sys
 from keras.models import Sequential, Model
 from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
 from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
+from tensorflow.keras import regularizers
 # from keras.applications.vgg19 import VGG19
 from keras import backend as K
 import tensorflow as tf
@@ -77,36 +78,38 @@ def generate_model_3D(_input_shape):
 
     print(_input_shape)
 
-    model.add(Conv3D(200, (1, 3, 3), input_shape=_input_shape))
+    model.add(Conv3D(60, (1, 3, 3), input_shape=_input_shape))
     model.add(Activation('relu'))
     model.add(MaxPooling3D(pool_size=(1, 2, 2)))
 
-    model.add(Conv3D(100, (1, 3, 3)))
+    model.add(Conv3D(40, (1, 3, 3)))
     model.add(Activation('relu'))
     model.add(MaxPooling3D(pool_size=(1, 2, 2)))
 
-    model.add(Conv3D(40, (1, 3, 3)))
+    model.add(Conv3D(20, (1, 3, 3)))
     model.add(Activation('relu'))
     model.add(MaxPooling3D(pool_size=(1, 2, 2)))
 
     model.add(Flatten())
 
-    model.add(Dense(256))
     model.add(BatchNormalization())
     model.add(Dropout(0.5))
     model.add(Activation('relu'))
 
-    model.add(Dense(128))
+    model.add(Dense(64, 
+        kernel_regularizer=regularizers.l1_l2(l1=1e-5, l2=1e-4),
+        bias_regularizer=regularizers.l2(1e-4),
+        activity_regularizer=regularizers.l2(1e-5)))
+        
     model.add(BatchNormalization())
     model.add(Dropout(0.5))
     model.add(Activation('relu'))
 
-    model.add(Dense(64))
-    model.add(BatchNormalization())
-    model.add(Dropout(0.5))
-    model.add(Activation('relu'))
-
-    model.add(Dense(20))
+    model.add(Dense(20, 
+        kernel_regularizer=regularizers.l1_l2(l1=1e-5, l2=1e-4),
+        bias_regularizer=regularizers.l2(1e-4),
+        activity_regularizer=regularizers.l2(1e-5)))
+        
     model.add(BatchNormalization())
     model.add(Dropout(0.5))
     model.add(Activation('relu'))