cnn_models.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # model imports
  2. from keras.preprocessing.image import ImageDataGenerator
  3. from keras.models import Sequential, Model
  4. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
  5. from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
  6. from keras.applications.vgg19 import VGG19
  7. from keras import backend as K
  8. import tensorflow as tf
  9. # configuration imports
  10. from . import metrics
  11. from ..config import cnn_config as cfg
  12. def generate_model_2D(_input_shape):
  13. model = Sequential()
  14. model.add(Conv2D(60, (2, 2), input_shape=_input_shape))
  15. model.add(Activation('relu'))
  16. model.add(MaxPooling2D(pool_size=(2, 2)))
  17. model.add(Conv2D(40, (2, 2)))
  18. model.add(Activation('relu'))
  19. model.add(MaxPooling2D(pool_size=(2, 2)))
  20. model.add(Conv2D(20, (2, 2)))
  21. model.add(Activation('relu'))
  22. model.add(MaxPooling2D(pool_size=(2, 2)))
  23. model.add(Flatten())
  24. model.add(Dense(140))
  25. model.add(Activation('relu'))
  26. model.add(BatchNormalization())
  27. model.add(Dropout(0.5))
  28. model.add(Dense(120))
  29. model.add(Activation('relu'))
  30. model.add(BatchNormalization())
  31. model.add(Dropout(0.5))
  32. model.add(Dense(80))
  33. model.add(Activation('relu'))
  34. model.add(BatchNormalization())
  35. model.add(Dropout(0.5))
  36. model.add(Dense(40))
  37. model.add(Activation('relu'))
  38. model.add(BatchNormalization())
  39. model.add(Dropout(0.5))
  40. model.add(Dense(20))
  41. model.add(Activation('relu'))
  42. model.add(BatchNormalization())
  43. model.add(Dropout(0.5))
  44. model.add(Dense(1))
  45. model.add(Activation('sigmoid'))
  46. model.compile(loss='binary_crossentropy',
  47. optimizer='rmsprop',
  48. metrics=['accuracy', metrics.auc])
  49. return model
  50. def generate_model_3D(_input_shape):
  51. model = Sequential()
  52. print(_input_shape)
  53. model.add(Conv3D(60, (1, 2, 2), input_shape=_input_shape))
  54. model.add(Activation('relu'))
  55. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  56. model.add(Conv3D(40, (1, 2, 2)))
  57. model.add(Activation('relu'))
  58. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  59. model.add(Conv3D(20, (1, 2, 2)))
  60. model.add(Activation('relu'))
  61. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  62. model.add(Flatten())
  63. model.add(Dense(140))
  64. model.add(Activation('relu'))
  65. model.add(BatchNormalization())
  66. model.add(Dropout(0.5))
  67. model.add(Dense(120))
  68. model.add(Activation('relu'))
  69. model.add(BatchNormalization())
  70. model.add(Dropout(0.5))
  71. model.add(Dense(80))
  72. model.add(Activation('relu'))
  73. model.add(BatchNormalization())
  74. model.add(Dropout(0.5))
  75. model.add(Dense(40))
  76. model.add(Activation('relu'))
  77. model.add(BatchNormalization())
  78. model.add(Dropout(0.5))
  79. model.add(Dense(20))
  80. model.add(Activation('relu'))
  81. model.add(BatchNormalization())
  82. model.add(Dropout(0.5))
  83. model.add(Dense(1))
  84. model.add(Activation('sigmoid'))
  85. model.compile(loss='binary_crossentropy',
  86. optimizer='rmsprop',
  87. metrics=['accuracy', metrics.auc])
  88. return model
  89. # using transfer learning (VGG19)
  90. def generate_model_3D_TL(_input_shape):
  91. # load pre-trained model
  92. model = VGG19(weights='imagenet', include_top=False, input_shape=_input_shape)
  93. # display model layers
  94. model.summary()
  95. # do not train convolutional layers
  96. for layer in model.layers[:5]:
  97. layer.trainable = False
  98. predictions_model = Sequential(model)
  99. #Adding custom Layers
  100. '''predictions_model.add(Flatten(model.output))
  101. predictions_model.add(Dense(1024))
  102. predictions_model.add(Activation('relu'))
  103. predictions_model.add(BatchNormalization())
  104. predictions_model.add(Dropout(0.5))
  105. predictions_model.add(Dense(512))
  106. predictions_model.add(Activation('relu'))
  107. predictions_model.add(BatchNormalization())
  108. predictions_model.add(Dropout(0.5))
  109. predictions_model.add(Dense(256))
  110. predictions_model.add(Activation('relu'))
  111. predictions_model.add(BatchNormalization())
  112. model.add(Dropout(0.5))
  113. predictions_model.add(Dense(100))
  114. predictions_model.add(Activation('relu'))
  115. predictions_model.add(BatchNormalization())
  116. predictions_model.add(Dropout(0.5))
  117. predictions_model.add(Dense(20))
  118. predictions_model.add(Activation('relu'))
  119. predictions_model.add(BatchNormalization())
  120. predictions_model.add(Dropout(0.5))
  121. predictions_model.add(Dense(1))
  122. predictions_model.add(Activation('sigmoid'))'''
  123. # adding custom Layers
  124. x = model.output
  125. x = Flatten()(x)
  126. x = Dense(1024, activation="relu")(x)
  127. x = BatchNormalization()(x)
  128. x = Dropout(0.5)(x)
  129. x = Dense(256, activation="relu")(x)
  130. x = BatchNormalization()(x)
  131. x = Dropout(0.5)(x)
  132. x = Dense(64, activation="relu")(x)
  133. x = BatchNormalization()(x)
  134. x = Dropout(0.5)(x)
  135. x = Dense(16, activation="relu")(x)
  136. predictions = Dense(1, activation="softmax")(x)
  137. # creating the final model
  138. model_final = Model(input=model.input, output=predictions)
  139. model_final.summary()
  140. model_final.compile(loss='binary_crossentropy',
  141. optimizer='rmsprop',
  142. metrics=['accuracy', metrics.auc])
  143. return model_final
  144. def get_model(n_channels, _input_shape, tl=False):
  145. if tl:
  146. if n_channels == 3:
  147. return generate_model_3D_TL(_input_shape)
  148. else:
  149. print("Can't use transfer learning with only 1 channel")
  150. if n_channels == 1:
  151. return generate_model_2D(_input_shape)
  152. if n_channels == 3:
  153. return generate_model_3D(_input_shape)