# cnn_models.py (5.6 KB)
  1. # main imports
  2. import sys
  3. # model imports
  4. # from keras.preprocessing.image import ImageDataGenerator
  5. from keras.models import Sequential, Model
  6. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
  7. from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
  8. # from keras.applications.vgg19 import VGG19
  9. from keras import backend as K
  10. import tensorflow as tf
  11. # configuration and modules imports
  12. sys.path.insert(0, '') # trick to enable import of main folder module
  13. import custom_config as cfg
  14. #from models import metrics
  15. def generate_model_2D(_input_shape):
  16. model = Sequential()
  17. model.add(Conv2D(140, (3, 3), input_shape=_input_shape))
  18. model.add(Activation('relu'))
  19. model.add(MaxPooling2D(pool_size=(2, 2)))
  20. model.add(Conv2D(70, (3, 3)))
  21. model.add(Activation('relu'))
  22. model.add(MaxPooling2D(pool_size=(2, 2)))
  23. model.add(Conv2D(20, (3, 3)))
  24. model.add(Activation('relu'))
  25. model.add(MaxPooling2D(pool_size=(2, 2)))
  26. model.add(Flatten())
  27. model.add(Dense(140))
  28. model.add(Activation('relu'))
  29. model.add(BatchNormalization())
  30. model.add(Dropout(0.5))
  31. # model.add(Dense(120))
  32. # model.add(Activation('sigmoid'))
  33. # model.add(BatchNormalization())
  34. # model.add(Dropout(0.5))
  35. model.add(Dense(80))
  36. model.add(Activation('relu'))
  37. model.add(BatchNormalization())
  38. model.add(Dropout(0.5))
  39. model.add(Dense(40))
  40. model.add(Activation('relu'))
  41. model.add(BatchNormalization())
  42. model.add(Dropout(0.5))
  43. model.add(Dense(20))
  44. model.add(Activation('relu'))
  45. model.add(BatchNormalization())
  46. model.add(Dropout(0.5))
  47. model.add(Dense(2))
  48. model.add(Activation('softmax'))
  49. model.compile(loss='categorical_crossentropy',
  50. optimizer='adam',
  51. #metrics=['accuracy', metrics.auc])
  52. metrics=['accuracy'])
  53. return model
  54. def generate_model_3D(_input_shape):
  55. model = Sequential()
  56. print(_input_shape)
  57. model.add(Conv3D(200, (1, 3, 3), input_shape=_input_shape))
  58. model.add(Activation('relu'))
  59. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  60. model.add(Conv3D(100, (1, 3, 3)))
  61. model.add(Activation('relu'))
  62. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  63. model.add(Conv3D(40, (1, 3, 3)))
  64. model.add(Activation('relu'))
  65. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  66. model.add(Flatten())
  67. model.add(Dense(256))
  68. model.add(BatchNormalization())
  69. model.add(Dropout(0.5))
  70. model.add(Activation('relu'))
  71. model.add(Dense(128))
  72. model.add(BatchNormalization())
  73. model.add(Dropout(0.5))
  74. model.add(Activation('relu'))
  75. model.add(Dense(64))
  76. model.add(BatchNormalization())
  77. model.add(Dropout(0.5))
  78. model.add(Activation('relu'))
  79. model.add(Dense(20))
  80. model.add(BatchNormalization())
  81. model.add(Dropout(0.5))
  82. model.add(Activation('relu'))
  83. model.add(Dense(2))
  84. model.add(Activation('sigmoid'))
  85. model.compile(loss='binary_crossentropy',
  86. optimizer='adam',
  87. #metrics=['accuracy', metrics.auc])
  88. metrics=['accuracy'])
  89. return model
# using transfer learning (VGG19)
# NOTE(review): everything below is disabled. It is wrapped in a bare
# triple-quoted string, so generate_model_3D_TL is never defined, and the
# VGG19 import at the top of the file is commented out. Problems visible in
# the disabled code that must be fixed before re-enabling it:
#   - `Flatten(model.output)` passes a tensor as Flatten's first argument,
#     which is `data_format`, not an input tensor.
#   - one line calls `model.add(Dropout(0.5))` instead of
#     `predictions_model.add(...)`.
#   - `predictions_model` is built but never used; the function returns
#     `model_final`.
#   - `Dense(1, activation="softmax")` always outputs 1.0 for a single unit;
#     a one-unit binary head normally uses "sigmoid".
#   - `Model(input=..., output=...)` uses legacy keyword names; current Keras
#     expects `inputs=` / `outputs=`.
#   - the "do not train convolutional layers" comment is inaccurate: only
#     `model.layers[:5]` are frozen.
'''def generate_model_3D_TL(_input_shape):
    # load pre-trained model
    model = VGG19(weights='imagenet', include_top=False, input_shape=_input_shape)
    # display model layers
    model.summary()
    # do not train convolutional layers
    for layer in model.layers[:5]:
        layer.trainable = False
    predictions_model = Sequential(model)
    predictions_model.add(Flatten(model.output))
    predictions_model.add(Dense(1024))
    predictions_model.add(Activation('relu'))
    predictions_model.add(BatchNormalization())
    predictions_model.add(Dropout(0.5))
    predictions_model.add(Dense(512))
    predictions_model.add(Activation('relu'))
    predictions_model.add(BatchNormalization())
    predictions_model.add(Dropout(0.5))
    predictions_model.add(Dense(256))
    predictions_model.add(Activation('relu'))
    predictions_model.add(BatchNormalization())
    model.add(Dropout(0.5))
    predictions_model.add(Dense(100))
    predictions_model.add(Activation('relu'))
    predictions_model.add(BatchNormalization())
    predictions_model.add(Dropout(0.5))
    predictions_model.add(Dense(20))
    predictions_model.add(Activation('relu'))
    predictions_model.add(BatchNormalization())
    predictions_model.add(Dropout(0.5))
    predictions_model.add(Dense(1))
    predictions_model.add(Activation('sigmoid'))
    # adding custom Layers
    x = model.output
    x = Flatten()(x)
    x = Dense(1024, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    x = Dense(256, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    x = Dense(64, activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    x = Dense(16, activation="relu")(x)
    predictions = Dense(1, activation="softmax")(x)
    # creating the final model
    model_final = Model(input=model.input, output=predictions)
    model_final.summary()
    model_final.compile(loss='binary_crossentropy',
                        optimizer='rmsprop',
                        # metrics=['accuracy', metrics.auc])
                        metrics=['accuracy'])
    return model_final'''
  145. def get_model(n_channels, _input_shape, _tl=False):
  146. # if _tl:
  147. # if n_channels == 3:
  148. # return generate_model_3D_TL(_input_shape)
  149. # else:
  150. # print("Can't use transfer learning with only 1 channel")
  151. if n_channels == 1:
  152. return generate_model_2D(_input_shape)
  153. if n_channels >= 2:
  154. return generate_model_3D(_input_shape)