Ver código fonte

Merge branch 'release/v0.1.7'

Jérôme BUISINE 5 anos atrás
pai
commit
486200b982
1 arquivos alterados com 84 adições e 2 exclusões
  1. 84 2
      models/cnn_models.py

+ 84 - 2
models/cnn_models.py

@@ -1,10 +1,13 @@
+# model imports
 from keras.preprocessing.image import ImageDataGenerator
 from keras.preprocessing.image import ImageDataGenerator
-from keras.models import Sequential
+from keras.models import Sequential, Model
 from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
 from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
 from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
 from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
+from keras.applications.vgg19 import VGG19
 from keras import backend as K
 from keras import backend as K
 import tensorflow as tf
 import tensorflow as tf
 
 
+# configuration imports
 from . import metrics
 from . import metrics
 from ..config import cnn_config as cfg
 from ..config import cnn_config as cfg
 
 
@@ -60,6 +63,7 @@ def generate_model_2D(_input_shape):
 
 
     return model
     return model
 
 
+
 def generate_model_3D(_input_shape):
 def generate_model_3D(_input_shape):
 
 
     model = Sequential()
     model = Sequential()
@@ -115,7 +119,85 @@ def generate_model_3D(_input_shape):
     return model
     return model
 
 
 
 
-def get_model(n_channels, _input_shape):
+# using transfer learning (VGG19)
+def generate_model_3D_TL(_input_shape):
+
+    # load pre-trained model
+    model = VGG19(weights='imagenet', include_top=False, input_shape=_input_shape)
+    # display model layers
+    model.summary()
+
+    # do not train convolutional layers
+    for layer in model.layers[:5]:
+        layer.trainable = False
+
+    predictions_model = Sequential(model)
+
+    #Adding custom Layers
+    '''predictions_model.add(Flatten(model.output))
+
+    predictions_model.add(Dense(1024))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    predictions_model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(512))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    predictions_model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(256))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(100))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    predictions_model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(20))
+    predictions_model.add(Activation('relu'))
+    predictions_model.add(BatchNormalization())
+    predictions_model.add(Dropout(0.5))
+
+    predictions_model.add(Dense(1))
+    predictions_model.add(Activation('sigmoid'))'''
+
+    # adding custom Layers 
+    x = model.output
+    x = Flatten()(x)
+    x = Dense(1024, activation="relu")(x)
+    x = BatchNormalization()(x)
+    x = Dropout(0.5)(x)
+    x = Dense(256, activation="relu")(x)
+    x = BatchNormalization()(x)
+    x = Dropout(0.5)(x)
+    x = Dense(64, activation="relu")(x)
+    x = BatchNormalization()(x)
+    x = Dropout(0.5)(x)
+    x = Dense(16, activation="relu")(x)
+    predictions = Dense(1, activation="softmax")(x)
+
+    # creating the final model 
+    model_final = Model(input=model.input, output=predictions)
+
+    model_final.summary()
+
+    model_final.compile(loss='binary_crossentropy',
+                  optimizer='rmsprop',
+                  metrics=['accuracy', metrics.auc])
+
+    return model_final
+
+
+def get_model(n_channels, _input_shape, tl=False):
+    
+    if tl:
+        if n_channels == 3:
+            return generate_model_3D_TL(_input_shape)
+        else:
+            print("Can't use transfer learning with only 1 channel")
 
 
     if n_channels == 1:
     if n_channels == 1:
         return generate_model_2D(_input_shape)
         return generate_model_2D(_input_shape)