cnn_models.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. from keras.preprocessing.image import ImageDataGenerator
  2. from keras.models import Sequential
  3. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
  4. from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
  5. from keras import backend as K
  6. import tensorflow as tf
  7. # trick to enable import of config
  8. import sys
  9. sys.path.insert(0, '..')
  10. from . import metrics
  11. from config import cnn_config as cfg
  12. def generate_model_2D(_input_shape):
  13. model = Sequential()
  14. model.add(Conv2D(60, (2, 2), input_shape=_input_shape))
  15. model.add(Activation('relu'))
  16. model.add(MaxPooling2D(pool_size=(2, 2)))
  17. model.add(Conv2D(40, (2, 2)))
  18. model.add(Activation('relu'))
  19. model.add(MaxPooling2D(pool_size=(2, 2)))
  20. model.add(Conv2D(20, (2, 2)))
  21. model.add(Activation('relu'))
  22. model.add(MaxPooling2D(pool_size=(2, 2)))
  23. model.add(Flatten())
  24. model.add(Dense(140))
  25. model.add(Activation('relu'))
  26. model.add(BatchNormalization())
  27. model.add(Dropout(0.4))
  28. model.add(Dense(120))
  29. model.add(Activation('relu'))
  30. model.add(BatchNormalization())
  31. model.add(Dropout(0.4))
  32. model.add(Dense(80))
  33. model.add(Activation('relu'))
  34. model.add(BatchNormalization())
  35. model.add(Dropout(0.4))
  36. model.add(Dense(40))
  37. model.add(Activation('relu'))
  38. model.add(BatchNormalization())
  39. model.add(Dropout(0.4))
  40. model.add(Dense(20))
  41. model.add(Activation('relu'))
  42. model.add(BatchNormalization())
  43. model.add(Dropout(0.4))
  44. model.add(Dense(1))
  45. model.add(Activation('sigmoid'))
  46. model.compile(loss='binary_crossentropy',
  47. optimizer='rmsprop',
  48. metrics=['accuracy', metrics.auc])
  49. return model
  50. def generate_model_3D(_input_shape):
  51. model = Sequential()
  52. print(_input_shape)
  53. model.add(Conv3D(60, (1, 2, 2), input_shape=_input_shape))
  54. model.add(Activation('relu'))
  55. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  56. model.add(Conv3D(40, (1, 2, 2)))
  57. model.add(Activation('relu'))
  58. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  59. model.add(Conv3D(20, (1, 2, 2)))
  60. model.add(Activation('relu'))
  61. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  62. model.add(Flatten())
  63. model.add(Dense(140))
  64. model.add(Activation('relu'))
  65. model.add(BatchNormalization())
  66. model.add(Dropout(0.4))
  67. model.add(Dense(120))
  68. model.add(Activation('relu'))
  69. model.add(BatchNormalization())
  70. model.add(Dropout(0.4))
  71. model.add(Dense(80))
  72. model.add(Activation('relu'))
  73. model.add(BatchNormalization())
  74. model.add(Dropout(0.4))
  75. model.add(Dense(40))
  76. model.add(Activation('relu'))
  77. model.add(BatchNormalization())
  78. model.add(Dropout(0.4))
  79. model.add(Dense(20))
  80. model.add(Activation('relu'))
  81. model.add(BatchNormalization())
  82. model.add(Dropout(0.4))
  83. model.add(Dense(1))
  84. model.add(Activation('sigmoid'))
  85. model.compile(loss='binary_crossentropy',
  86. optimizer='rmsprop',
  87. metrics=['accuracy', metrics.auc])
  88. return model
  89. def get_model(n_channels, _input_shape):
  90. if n_channels == 1:
  91. return generate_model_2D(_input_shape)
  92. if n_channels == 3:
  93. return generate_model_3D(_input_shape)