models.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from keras.preprocessing.image import ImageDataGenerator
  2. from keras.models import Sequential
  3. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv3D, MaxPooling3D, AveragePooling3D
  4. from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
  5. from keras import backend as K
  6. import tensorflow as tf
  7. from modules.utils import config as cfg
  8. from modules.models import metrics
  9. def generate_model_2D(_input_shape):
  10. model = Sequential()
  11. model.add(Conv2D(60, (2, 2), input_shape=_input_shape))
  12. model.add(Activation('relu'))
  13. model.add(MaxPooling2D(pool_size=(2, 2)))
  14. model.add(Conv2D(40, (2, 2)))
  15. model.add(Activation('relu'))
  16. model.add(MaxPooling2D(pool_size=(2, 2)))
  17. model.add(Conv2D(20, (2, 2)))
  18. model.add(Activation('relu'))
  19. model.add(MaxPooling2D(pool_size=(2, 2)))
  20. model.add(Flatten())
  21. model.add(Dense(140))
  22. model.add(Activation('relu'))
  23. model.add(BatchNormalization())
  24. model.add(Dropout(0.4))
  25. model.add(Dense(120))
  26. model.add(Activation('relu'))
  27. model.add(BatchNormalization())
  28. model.add(Dropout(0.4))
  29. model.add(Dense(80))
  30. model.add(Activation('relu'))
  31. model.add(BatchNormalization())
  32. model.add(Dropout(0.4))
  33. model.add(Dense(40))
  34. model.add(Activation('relu'))
  35. model.add(BatchNormalization())
  36. model.add(Dropout(0.4))
  37. model.add(Dense(20))
  38. model.add(Activation('relu'))
  39. model.add(BatchNormalization())
  40. model.add(Dropout(0.4))
  41. model.add(Dense(1))
  42. model.add(Activation('sigmoid'))
  43. model.compile(loss='binary_crossentropy',
  44. optimizer='rmsprop',
  45. metrics=['accuracy', metrics.auc])
  46. return model
  47. def generate_model_3D(_input_shape):
  48. model = Sequential()
  49. print(_input_shape)
  50. model.add(Conv3D(60, (1, 2, 2), input_shape=_input_shape))
  51. model.add(Activation('relu'))
  52. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  53. model.add(Conv3D(40, (1, 2, 2)))
  54. model.add(Activation('relu'))
  55. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  56. model.add(Conv3D(20, (1, 2, 2)))
  57. model.add(Activation('relu'))
  58. model.add(MaxPooling3D(pool_size=(1, 2, 2)))
  59. model.add(Flatten())
  60. model.add(Dense(140))
  61. model.add(Activation('relu'))
  62. model.add(BatchNormalization())
  63. model.add(Dropout(0.4))
  64. model.add(Dense(120))
  65. model.add(Activation('relu'))
  66. model.add(BatchNormalization())
  67. model.add(Dropout(0.4))
  68. model.add(Dense(80))
  69. model.add(Activation('relu'))
  70. model.add(BatchNormalization())
  71. model.add(Dropout(0.4))
  72. model.add(Dense(40))
  73. model.add(Activation('relu'))
  74. model.add(BatchNormalization())
  75. model.add(Dropout(0.4))
  76. model.add(Dense(20))
  77. model.add(Activation('relu'))
  78. model.add(BatchNormalization())
  79. model.add(Dropout(0.4))
  80. model.add(Dense(1))
  81. model.add(Activation('sigmoid'))
  82. model.compile(loss='binary_crossentropy',
  83. optimizer='rmsprop',
  84. metrics=['accuracy', metrics.auc])
  85. return model
  86. def get_model(n_channels, _input_shape):
  87. if n_channels == 1:
  88. return generate_model_2D(_input_shape)
  89. if n_channels == 3:
  90. return generate_model_3D(_input_shape)