浏览代码

Merge branch 'release/v0.1.7'

Jérôme BUISINE 4 年之前
父节点
当前提交
ae52ec5b34
共有 1 个文件被更改,包括 84 次插入2 次删除
  1. 84 2
      models/cnn_models.py

+ 84 - 2
models/cnn_models.py

@@ -1,10 +1,13 @@
+# model imports
 from keras.preprocessing.image import ImageDataGenerator
-from keras.models import Sequential
+from keras.models import Sequential, Model
 from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
 from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
+from keras.applications.vgg19 import VGG19
 from keras import backend as K
 import tensorflow as tf
 
+# configuration imports
 from . import metrics
 from ..config import cnn_config as cfg
 
@@ -60,6 +63,7 @@ def generate_model_2D(_input_shape):
 
     return model
 
+
 def generate_model_3D(_input_shape):
 
     model = Sequential()
@@ -115,7 +119,85 @@ def generate_model_3D(_input_shape):
     return model
 
 
-def get_model(n_channels, _input_shape):
+# using transfer learning (VGG19)
+def generate_model_3D_TL(_input_shape):
+
+    # load pre-trained model
+    model = VGG19(weights='imagenet', include_top=False, input_shape=_input_shape)
+    # display model layers
+    model.summary()
+
+    # do not train convolutional layers
+    for layer in model.layers[:5]:
+        layer.trainable = False
+
+    predictions_model = Sequential(model)
+
+    #Adding custom Layers
+    '''predictions_model.add(Flatten(model.output))
+
+    predictions_model.add(Dense(1024))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    predictions_model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(512))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    predictions_model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(256))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(100))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    predictions_model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(20))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    predictions_model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(1))
+    predictions_model.add(Activation('sigmoid'))'''
+
+    # adding custom Layers 
+    x = model.output
+    x = Flatten()(x)
+    x = Dense(1024, activation="relu")(x)
+    x = BatchNormalization()(x)
+    x = Dropout(0.5)(x)
+    x = Dense(256, activation="relu")(x)
+    x = BatchNormalization()(x)
+    x = Dropout(0.5)(x)
+    x = Dense(64, activation="relu")(x)
+    x = BatchNormalization()(x)
+    x = Dropout(0.5)(x)
+    x = Dense(16, activation="relu")(x)
+    predictions = Dense(1, activation="softmax")(x)
+
+    # creating the final model 
+    model_final = Model(input=model.input, output=predictions)
+
+    model_final.summary()
+
+    model_final.compile(loss='binary_crossentropy',
+                  optimizer='rmsprop',
+                  metrics=['accuracy', metrics.auc])
+
+    return model_final
+
+
+def get_model(n_channels, _input_shape, tl=False):
+    
+    if tl:
+        if n_channels == 3:
+            return generate_model_3D_TL(_input_shape)
+        else:
+            print("Can't use transfer learning with only 1 channel")
 
     if n_channels == 1:
         return generate_model_2D(_input_shape)