cnn_models.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # main imports
  2. import sys
  3. # model imports
  4. from keras.preprocessing.image import ImageDataGenerator
  5. from keras.models import Sequential, Model
  6. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
  7. from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
  8. from keras.applications.vgg19 import VGG19
  9. from keras import backend as K
  10. import tensorflow as tf
  11. # configuration and modules imports
  12. sys.path.insert(0, '') # trick to enable import of main folder module
  13. import custom_config as cfg
  14. from models import metrics
  15. def generate_model_2D(_input_shape):
  16. model = Sequential()
  17. model.add(Conv2D(60, (2, 2), input_shape=_input_shape))
  18. model.add(Activation('relu'))
  19. model.add(MaxPooling2D(pool_size=(2, 2)))
  20. model.add(Conv2D(40, (2, 2)))
  21. model.add(Activation('relu'))
  22. model.add(MaxPooling2D(pool_size=(2, 2)))
  23. model.add(Conv2D(20, (2, 2)))
  24. model.add(Activation('relu'))
  25. model.add(MaxPooling2D(pool_size=(2, 2)))
  26. model.add(Flatten())
  27. model.add(Dense(140))
  28. model.add(Activation('relu'))
  29. model.add(BatchNormalization())
  30. model.add(Dropout(0.5))
  31. model.add(Dense(120))
  32. model.add(Activation('relu'))
  33. model.add(BatchNormalization())
  34. model.add(Dropout(0.5))
  35. model.add(Dense(80))
  36. model.add(Activation('relu'))
  37. model.add(BatchNormalization())
  38. model.add(Dropout(0.5))
  39. model.add(Dense(40))
  40. model.add(Activation('relu'))
  41. model.add(BatchNormalization())
  42. model.add(Dropout(0.5))
  43. model.add(Dense(20))
  44. model.add(Activation('relu'))
  45. model.add(BatchNormalization())
  46. model.add(Dropout(0.5))
  47. model.add(Dense(1))
  48. model.add(Activation('sigmoid'))
  49. model.compile(loss='binary_crossentropy',
  50. optimizer='rmsprop',
  51. metrics=['accuracy', metrics.auc])
  52. return model
  53. def generate_model_3D(_input_shape):
  54. model = Sequential()
  55. print(_input_shape)
  56. model.add(Conv3D(60, (1, 2, 2), input_shape=_input_shape))
  57. model.add(Activation('relu'))
  58. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  59. model.add(Conv3D(40, (1, 2, 2)))
  60. model.add(Activation('relu'))
  61. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  62. model.add(Conv3D(20, (1, 2, 2)))
  63. model.add(Activation('relu'))
  64. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  65. model.add(Flatten())
  66. model.add(Dense(140))
  67. model.add(Activation('relu'))
  68. model.add(BatchNormalization())
  69. model.add(Dropout(0.5))
  70. model.add(Dense(120))
  71. model.add(Activation('relu'))
  72. model.add(BatchNormalization())
  73. model.add(Dropout(0.5))
  74. model.add(Dense(80))
  75. model.add(Activation('relu'))
  76. model.add(BatchNormalization())
  77. model.add(Dropout(0.5))
  78. model.add(Dense(40))
  79. model.add(Activation('relu'))
  80. model.add(BatchNormalization())
  81. model.add(Dropout(0.5))
  82. model.add(Dense(20))
  83. model.add(Activation('relu'))
  84. model.add(BatchNormalization())
  85. model.add(Dropout(0.5))
  86. model.add(Dense(1))
  87. model.add(Activation('sigmoid'))
  88. model.compile(loss='binary_crossentropy',
  89. optimizer='rmsprop',
  90. metrics=['accuracy', metrics.auc])
  91. return model
  92. # using transfer learning (VGG19)
  93. def generate_model_3D_TL(_input_shape):
  94. # load pre-trained model
  95. model = VGG19(weights='imagenet', include_top=False, input_shape=_input_shape)
  96. # display model layers
  97. model.summary()
  98. # do not train convolutional layers
  99. for layer in model.layers[:5]:
  100. layer.trainable = False
  101. '''predictions_model = Sequential(model)
  102. predictions_model.add(Flatten(model.output))
  103. predictions_model.add(Dense(1024))
  104. predictions_model.add(Activation('relu'))
  105. predictions_model.add(BatchNormalization())
  106. predictions_model.add(Dropout(0.5))
  107. predictions_model.add(Dense(512))
  108. predictions_model.add(Activation('relu'))
  109. predictions_model.add(BatchNormalization())
  110. predictions_model.add(Dropout(0.5))
  111. predictions_model.add(Dense(256))
  112. predictions_model.add(Activation('relu'))
  113. predictions_model.add(BatchNormalization())
  114. model.add(Dropout(0.5))
  115. predictions_model.add(Dense(100))
  116. predictions_model.add(Activation('relu'))
  117. predictions_model.add(BatchNormalization())
  118. predictions_model.add(Dropout(0.5))
  119. predictions_model.add(Dense(20))
  120. predictions_model.add(Activation('relu'))
  121. predictions_model.add(BatchNormalization())
  122. predictions_model.add(Dropout(0.5))
  123. predictions_model.add(Dense(1))
  124. predictions_model.add(Activation('sigmoid'))'''
  125. # adding custom Layers
  126. x = model.output
  127. x = Flatten()(x)
  128. x = Dense(1024, activation="relu")(x)
  129. x = BatchNormalization()(x)
  130. x = Dropout(0.5)(x)
  131. x = Dense(256, activation="relu")(x)
  132. x = BatchNormalization()(x)
  133. x = Dropout(0.5)(x)
  134. x = Dense(64, activation="relu")(x)
  135. x = BatchNormalization()(x)
  136. x = Dropout(0.5)(x)
  137. x = Dense(16, activation="relu")(x)
  138. predictions = Dense(1, activation="softmax")(x)
  139. # creating the final model
  140. model_final = Model(input=model.input, output=predictions)
  141. model_final.summary()
  142. model_final.compile(loss='binary_crossentropy',
  143. optimizer='rmsprop',
  144. metrics=['accuracy', metrics.auc])
  145. return model_final
  146. def get_model(n_channels, _input_shape, tl=False):
  147. if tl:
  148. if n_channels == 3:
  149. return generate_model_3D_TL(_input_shape)
  150. else:
  151. print("Can't use transfer learning with only 1 channel")
  152. if n_channels == 1:
  153. return generate_model_2D(_input_shape)
  154. if n_channels == 3:
  155. return generate_model_3D(_input_shape)