  1. '''This script goes along the blog post
  2. "Building powerful image classification models using very little data"
  3. from blog.keras.io.
  4. ```
  5. data/
  6. train/
  7. final/
  8. final001.png
  9. final002.png
  10. ...
  11. noisy/
  12. noisy001.png
  13. noisy002.png
  14. ...
  15. validation/
  16. final/
  17. final001.png
  18. final002.png
  19. ...
  20. noisy/
  21. noisy001.png
  22. noisy002.png
  23. ...
  24. ```
  25. '''
  26. import sys, os, getopt
  27. from keras.preprocessing.image import ImageDataGenerator
  28. from keras.models import Sequential
  29. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
  30. from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
  31. from keras.optimizers import Adam
  32. from keras.regularizers import l2
  33. from keras import backend as K
  34. from keras.utils import plot_model
  35. import matplotlib.pyplot as plt
  36. import tensorflow as tf
  37. import numpy as np
  38. from modules.model_helper import plot_info
  39. from modules.image_metrics import svd_metric
  40. # configuration
  41. # dimensions of our images.
  42. img_width, img_height = 100, 1
  43. train_data_dir = 'data/train'
  44. validation_data_dir = 'data/validation'
  45. nb_train_samples = 7200
  46. nb_validation_samples = 3600
  47. epochs = 200
  48. batch_size = 30
  49. if K.image_data_format() == 'channels_first':
  50. input_shape = (3, img_width, img_height)
  51. else:
  52. input_shape = (img_width, img_height, 3)
  53. '''
  54. Method which returns model to train
  55. @return : DirectoryIterator
  56. '''
  57. def generate_model():
  58. model = Sequential()
  59. model.add(Conv2D(100, (2, 1), input_shape=input_shape))
  60. model.add(Activation('relu'))
  61. model.add(MaxPooling2D(pool_size=(2, 1)))
  62. model.add(Conv2D(80, (2, 1)))
  63. model.add(Activation('relu'))
  64. model.add(AveragePooling2D(pool_size=(2, 1)))
  65. model.add(Conv2D(50, (2, 1)))
  66. model.add(Activation('relu'))
  67. model.add(MaxPooling2D(pool_size=(2, 1)))
  68. model.add(Flatten())
  69. model.add(Dense(50, kernel_regularizer=l2(0.01)))
  70. model.add(Activation('relu'))
  71. model.add(BatchNormalization())
  72. model.add(Dropout(0.1))
  73. model.add(Dense(100, kernel_regularizer=l2(0.01)))
  74. model.add(Activation('relu'))
  75. model.add(BatchNormalization())
  76. model.add(Dropout(0.1))
  77. model.add(Dense(200, kernel_regularizer=l2(0.01)))
  78. model.add(Activation('relu'))
  79. model.add(BatchNormalization())
  80. model.add(Dropout(0.2))
  81. model.add(Dense(300, kernel_regularizer=l2(0.01)))
  82. model.add(Activation('relu'))
  83. model.add(BatchNormalization())
  84. model.add(Dropout(0.3))
  85. model.add(Dense(200, kernel_regularizer=l2(0.01)))
  86. model.add(Activation('relu'))
  87. model.add(BatchNormalization())
  88. model.add(Dropout(0.2))
  89. model.add(Dense(100, kernel_regularizer=l2(0.01)))
  90. model.add(Activation('relu'))
  91. model.add(BatchNormalization())
  92. model.add(Dropout(0.1))
  93. model.add(Dense(50, kernel_regularizer=l2(0.01)))
  94. model.add(Activation('relu'))
  95. model.add(BatchNormalization())
  96. model.add(Dropout(0.1))
  97. model.add(Dense(20, kernel_regularizer=l2(0.01)))
  98. model.add(Activation('relu'))
  99. model.add(BatchNormalization())
  100. model.add(Dropout(0.1))
  101. model.add(Dense(1))
  102. model.add(Activation('sigmoid'))
  103. model.compile(loss='binary_crossentropy',
  104. optimizer='rmsprop',
  105. metrics=['accuracy'])
  106. return model
  107. '''
  108. Method which loads train data
  109. @return : DirectoryIterator
  110. '''
  111. def load_train_data():
  112. # this is the augmentation configuration we will use for training
  113. train_datagen = ImageDataGenerator(
  114. rescale=1. / 255,
  115. #shear_range=0.2,
  116. #zoom_range=0.2,
  117. #horizontal_flip=True,
  118. preprocessing_function=svd_metric.get_s_model_data)
  119. train_generator = train_datagen.flow_from_directory(
  120. train_data_dir,
  121. target_size=(img_width, img_height),
  122. batch_size=batch_size,
  123. class_mode='binary')
  124. return train_generator
  125. '''
  126. Method which loads validation data
  127. @return : DirectoryIterator
  128. '''
  129. def load_validation_data():
  130. # this is the augmentation configuration we will use for testing:
  131. # only rescaling
  132. test_datagen = ImageDataGenerator(
  133. rescale=1. / 255,
  134. preprocessing_function=svd_metric.get_s_model_data)
  135. validation_generator = test_datagen.flow_from_directory(
  136. validation_data_dir,
  137. target_size=(img_width, img_height),
  138. batch_size=batch_size,
  139. class_mode='binary')
  140. return validation_generator
  141. def main():
  142. global batch_size
  143. global epochs
  144. if len(sys.argv) <= 1:
  145. print('No output file defined...')
  146. print('classification_cnn_keras_svd.py --output xxxxx')
  147. sys.exit(2)
  148. try:
  149. opts, args = getopt.getopt(sys.argv[1:], "ho:b:e:d", ["help", "directory=", "output=", "batch_size=", "epochs="])
  150. except getopt.GetoptError:
  151. # print help information and exit:
  152. print('classification_cnn_keras_svd.py --output xxxxx')
  153. sys.exit(2)
  154. for o, a in opts:
  155. if o == "-h":
  156. print('classification_cnn_keras_svd.py --output xxxxx')
  157. sys.exit()
  158. elif o in ("-o", "--output"):
  159. filename = a
  160. elif o in ("-b", "--batch_size"):
  161. batch_size = int(a)
  162. elif o in ("-e", "--epochs"):
  163. epochs = int(a)
  164. elif o in ("-d", "--directory"):
  165. directory = a
  166. else:
  167. assert False, "unhandled option"
  168. # load of model
  169. model = generate_model()
  170. model.summary()
  171. if(directory):
  172. print('Your model information will be saved into %s...' % directory)
  173. history = model.fit_generator(
  174. load_train_data(),
  175. steps_per_epoch=nb_train_samples // batch_size,
  176. epochs=epochs,
  177. validation_data=load_validation_data(),
  178. validation_steps=nb_validation_samples // batch_size)
  179. # if user needs output files
  180. if(filename):
  181. # update filename by folder
  182. if(directory):
  183. # create folder if necessary
  184. if not os.path.exists(directory):
  185. os.makedirs(directory)
  186. filename = directory + "/" + filename
  187. # save plot file history
  188. plot_info.save(history, filename)
  189. plot_model(model, to_file=str(('%s.png' % filename)))
  190. model.save_weights(str('%s.h5' % filename))
  191. if __name__ == "__main__":
  192. main()