cnn_models.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from keras.preprocessing.image import ImageDataGenerator
  2. from keras.models import Sequential
  3. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
  4. from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
  5. from keras import backend as K
  6. import tensorflow as tf
  7. from . import metrics
  8. from ..config import cnn_config as cfg
  9. def generate_model_2D(_input_shape):
  10. model = Sequential()
  11. model.add(Conv2D(60, (2, 2), input_shape=_input_shape))
  12. model.add(Activation('relu'))
  13. model.add(MaxPooling2D(pool_size=(2, 2)))
  14. model.add(Conv2D(40, (2, 2)))
  15. model.add(Activation('relu'))
  16. model.add(MaxPooling2D(pool_size=(2, 2)))
  17. model.add(Conv2D(20, (2, 2)))
  18. model.add(Activation('relu'))
  19. model.add(MaxPooling2D(pool_size=(2, 2)))
  20. model.add(Flatten())
  21. model.add(Dense(140))
  22. model.add(Activation('relu'))
  23. model.add(BatchNormalization())
  24. model.add(Dropout(0.4))
  25. model.add(Dense(120))
  26. model.add(Activation('relu'))
  27. model.add(BatchNormalization())
  28. model.add(Dropout(0.4))
  29. model.add(Dense(80))
  30. model.add(Activation('relu'))
  31. model.add(BatchNormalization())
  32. model.add(Dropout(0.4))
  33. model.add(Dense(40))
  34. model.add(Activation('relu'))
  35. model.add(BatchNormalization())
  36. model.add(Dropout(0.4))
  37. model.add(Dense(20))
  38. model.add(Activation('relu'))
  39. model.add(BatchNormalization())
  40. model.add(Dropout(0.4))
  41. model.add(Dense(1))
  42. model.add(Activation('sigmoid'))
  43. model.compile(loss='binary_crossentropy',
  44. optimizer='rmsprop',
  45. metrics=['accuracy', metrics.auc])
  46. return model
  47. def generate_model_3D(_input_shape):
  48. model = Sequential()
  49. print(_input_shape)
  50. model.add(Conv3D(60, (1, 2, 2), input_shape=_input_shape))
  51. model.add(Activation('relu'))
  52. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  53. model.add(Conv3D(40, (1, 2, 2)))
  54. model.add(Activation('relu'))
  55. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  56. model.add(Conv3D(20, (1, 2, 2)))
  57. model.add(Activation('relu'))
  58. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  59. model.add(Flatten())
  60. model.add(Dense(140))
  61. model.add(Activation('relu'))
  62. model.add(BatchNormalization())
  63. model.add(Dropout(0.4))
  64. model.add(Dense(120))
  65. model.add(Activation('relu'))
  66. model.add(BatchNormalization())
  67. model.add(Dropout(0.4))
  68. model.add(Dense(80))
  69. model.add(Activation('relu'))
  70. model.add(BatchNormalization())
  71. model.add(Dropout(0.4))
  72. model.add(Dense(40))
  73. model.add(Activation('relu'))
  74. model.add(BatchNormalization())
  75. model.add(Dropout(0.4))
  76. model.add(Dense(20))
  77. model.add(Activation('relu'))
  78. model.add(BatchNormalization())
  79. model.add(Dropout(0.4))
  80. model.add(Dense(1))
  81. model.add(Activation('sigmoid'))
  82. model.compile(loss='binary_crossentropy',
  83. optimizer='rmsprop',
  84. metrics=['accuracy', metrics.auc])
  85. return model
  86. def get_model(n_channels, _input_shape):
  87. if n_channels == 1:
  88. return generate_model_2D(_input_shape)
  89. if n_channels == 3:
  90. return generate_model_3D(_input_shape)