cnn_models.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # main imports
  2. import sys
  3. # model imports
  4. from keras.preprocessing.image import ImageDataGenerator
  5. from keras.models import Sequential, Model
  6. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
  7. from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
  8. from keras.applications.vgg19 import VGG19
  9. from keras import backend as K
  10. import tensorflow as tf
  11. # configuration and modules imports
  12. sys.path.insert(0, '') # trick to enable import of main folder module
  13. import custom_config as cfg
  14. #from models import metrics
  15. def generate_model_2D(_input_shape):
  16. model = Sequential()
  17. model.add(Conv2D(140, (3, 3), input_shape=_input_shape))
  18. model.add(Activation('relu'))
  19. model.add(MaxPooling2D(pool_size=(2, 2)))
  20. model.add(Conv2D(70, (3, 3)))
  21. model.add(Activation('relu'))
  22. model.add(MaxPooling2D(pool_size=(2, 2)))
  23. model.add(Conv2D(20, (3, 3)))
  24. model.add(Activation('relu'))
  25. model.add(MaxPooling2D(pool_size=(2, 2)))
  26. model.add(Flatten())
  27. model.add(Dense(140))
  28. model.add(Activation('relu'))
  29. model.add(BatchNormalization())
  30. model.add(Dropout(0.5))
  31. # model.add(Dense(120))
  32. # model.add(Activation('sigmoid'))
  33. # model.add(BatchNormalization())
  34. # model.add(Dropout(0.5))
  35. model.add(Dense(80))
  36. model.add(Activation('relu'))
  37. model.add(BatchNormalization())
  38. model.add(Dropout(0.5))
  39. model.add(Dense(40))
  40. model.add(Activation('relu'))
  41. model.add(BatchNormalization())
  42. model.add(Dropout(0.5))
  43. model.add(Dense(20))
  44. model.add(Activation('relu'))
  45. model.add(BatchNormalization())
  46. model.add(Dropout(0.5))
  47. model.add(Dense(2))
  48. model.add(Activation('softmax'))
  49. model.compile(loss='categorical_crossentropy',
  50. optimizer='adam',
  51. #metrics=['accuracy', metrics.auc])
  52. metrics=['accuracy'])
  53. return model
  54. def generate_model_3D(_input_shape):
  55. model = Sequential()
  56. print(_input_shape)
  57. model.add(Conv3D(200, (1, 3, 3), input_shape=_input_shape))
  58. model.add(Activation('relu'))
  59. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  60. model.add(Conv3D(100, (1, 3, 3)))
  61. model.add(Activation('relu'))
  62. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  63. model.add(Conv3D(40, (1, 3, 3)))
  64. model.add(Activation('relu'))
  65. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  66. model.add(Flatten())
  67. model.add(Dense(256))
  68. model.add(Activation('relu'))
  69. model.add(BatchNormalization())
  70. model.add(Dropout(0.5))
  71. model.add(Dense(128))
  72. model.add(Activation('relu'))
  73. model.add(BatchNormalization())
  74. model.add(Dropout(0.5))
  75. model.add(Dense(64))
  76. model.add(Activation('relu'))
  77. model.add(BatchNormalization())
  78. model.add(Dropout(0.5))
  79. model.add(Dense(20))
  80. model.add(Activation('relu'))
  81. model.add(BatchNormalization())
  82. model.add(Dropout(0.5))
  83. model.add(Dense(2))
  84. model.add(Activation('sigmoid'))
  85. model.compile(loss='categorical_crossentropy',
  86. optimizer='rmsprop',
  87. #metrics=['accuracy', metrics.auc])
  88. metrics=['accuracy'])
  89. return model
  90. # using transfer learning (VGG19)
  91. def generate_model_3D_TL(_input_shape):
  92. # load pre-trained model
  93. model = VGG19(weights='imagenet', include_top=False, input_shape=_input_shape)
  94. # display model layers
  95. model.summary()
  96. # do not train convolutional layers
  97. for layer in model.layers[:5]:
  98. layer.trainable = False
  99. '''predictions_model = Sequential(model)
  100. predictions_model.add(Flatten(model.output))
  101. predictions_model.add(Dense(1024))
  102. predictions_model.add(Activation('relu'))
  103. predictions_model.add(BatchNormalization())
  104. predictions_model.add(Dropout(0.5))
  105. predictions_model.add(Dense(512))
  106. predictions_model.add(Activation('relu'))
  107. predictions_model.add(BatchNormalization())
  108. predictions_model.add(Dropout(0.5))
  109. predictions_model.add(Dense(256))
  110. predictions_model.add(Activation('relu'))
  111. predictions_model.add(BatchNormalization())
  112. model.add(Dropout(0.5))
  113. predictions_model.add(Dense(100))
  114. predictions_model.add(Activation('relu'))
  115. predictions_model.add(BatchNormalization())
  116. predictions_model.add(Dropout(0.5))
  117. predictions_model.add(Dense(20))
  118. predictions_model.add(Activation('relu'))
  119. predictions_model.add(BatchNormalization())
  120. predictions_model.add(Dropout(0.5))
  121. predictions_model.add(Dense(1))
  122. predictions_model.add(Activation('sigmoid'))'''
  123. # adding custom Layers
  124. x = model.output
  125. x = Flatten()(x)
  126. x = Dense(1024, activation="relu")(x)
  127. x = BatchNormalization()(x)
  128. x = Dropout(0.5)(x)
  129. x = Dense(256, activation="relu")(x)
  130. x = BatchNormalization()(x)
  131. x = Dropout(0.5)(x)
  132. x = Dense(64, activation="relu")(x)
  133. x = BatchNormalization()(x)
  134. x = Dropout(0.5)(x)
  135. x = Dense(16, activation="relu")(x)
  136. predictions = Dense(1, activation="softmax")(x)
  137. # creating the final model
  138. model_final = Model(input=model.input, output=predictions)
  139. model_final.summary()
  140. model_final.compile(loss='binary_crossentropy',
  141. optimizer='rmsprop',
  142. # metrics=['accuracy', metrics.auc])
  143. metrics=['accuracy'])
  144. return model_final
  145. def get_model(n_channels, _input_shape, _tl=False):
  146. if _tl:
  147. if n_channels == 3:
  148. return generate_model_3D_TL(_input_shape)
  149. else:
  150. print("Can't use transfer learning with only 1 channel")
  151. if n_channels == 1:
  152. return generate_model_2D(_input_shape)
  153. if n_channels >= 2:
  154. return generate_model_3D(_input_shape)