# classification_cnn_keras_cross_validation.py
'''This script goes along the blog post
"Building powerful image classification models using very little data"
from blog.keras.io.

Expected data layout:
```
data/
    train/
        final/
            final001.png
            final002.png
            ...
        noisy/
            noisy001.png
            noisy002.png
            ...
    validation/
        final/
            final001.png
            final002.png
            ...
        noisy/
            noisy001.png
            noisy002.png
            ...
```
'''
  26. import sys, os, getopt
  27. from keras.preprocessing.image import ImageDataGenerator
  28. from keras.models import Sequential
  29. from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
  30. from keras.layers import Activation, Dropout, Flatten, Dense
  31. from keras import backend as K
  32. from keras.utils import plot_model
  33. from modules.model_helper import plot_info
  34. # dimensions of our images.
  35. img_width, img_height = 100, 100
  36. train_data_dir = 'data/train'
  37. validation_data_dir = 'data/validation'
  38. nb_train_samples = 7200
  39. nb_validation_samples = 3600
  40. epochs = 50
  41. batch_size = 16
  42. if K.image_data_format() == 'channels_first':
  43. input_shape = (3, img_width, img_height)
  44. else:
  45. input_shape = (img_width, img_height, 3)
  46. '''
  47. Method which returns model to train
  48. @return : DirectoryIterator
  49. '''
  50. def generate_model():
  51. # create your model using this function
  52. model = Sequential()
  53. model.add(Conv2D(60, (2, 2), input_shape=input_shape))
  54. model.add(Activation('relu'))
  55. model.add(MaxPooling2D(pool_size=(2, 2)))
  56. model.add(Conv2D(40, (2, 2)))
  57. model.add(Activation('relu'))
  58. model.add(MaxPooling2D(pool_size=(2, 2)))
  59. model.add(Conv2D(20, (2, 2)))
  60. model.add(Activation('relu'))
  61. model.add(MaxPooling2D(pool_size=(2, 2)))
  62. model.add(Conv2D(40, (2, 2)))
  63. model.add(Activation('relu'))
  64. model.add(MaxPooling2D(pool_size=(2, 2)))
  65. model.add(Conv2D(20, (2, 2)))
  66. model.add(Activation('relu'))
  67. model.add(MaxPooling2D(pool_size=(2, 2)))
  68. model.add(Flatten())
  69. model.add(Dense(256))
  70. model.add(Activation('relu'))
  71. model.add(Dropout(0.2))
  72. model.add(Dense(128))
  73. model.add(Activation('relu'))
  74. model.add(Dropout(0.2))
  75. model.add(Dense(64))
  76. model.add(Activation('relu'))
  77. model.add(Dropout(0.2))
  78. model.add(Dense(32))
  79. model.add(Activation('relu'))
  80. model.add(Dropout(0.05))
  81. model.add(Dense(1))
  82. model.add(Activation('sigmoid'))
  83. model.compile(loss='binary_crossentropy',
  84. optimizer='rmsprop',
  85. metrics=['accuracy'])
  86. return model
  87. def load_data():
  88. # load your data using this function
  89. # this is the augmentation configuration we will use for training
  90. train_datagen = ImageDataGenerator(
  91. rescale=1. / 255,
  92. shear_range=0.2,
  93. zoom_range=0.2,
  94. horizontal_flip=True)
  95. train_generator = train_datagen.flow_from_directory(
  96. train_data_dir,
  97. target_size=(img_width, img_height),
  98. batch_size=batch_size,
  99. class_mode='binary')
  100. return train_generator
  101. def train_and_evaluate_model(model, data_train, data_test):
  102. return model.fit_generator(
  103. data_train,
  104. steps_per_epoch=nb_train_samples // batch_size,
  105. epochs=epochs,
  106. shuffle=True,
  107. validation_data=data_test,
  108. validation_steps=nb_validation_samples // batch_size)
  109. def main():
  110. global batch_size
  111. global epochs
  112. if len(sys.argv) <= 1:
  113. print('No output file defined...')
  114. print('classification_cnn_keras_svd.py --output xxxxx')
  115. sys.exit(2)
  116. try:
  117. opts, args = getopt.getopt(sys.argv[1:], "ho:b:e:d", ["help", "directory=", "output=", "batch_size=", "epochs="])
  118. except getopt.GetoptError:
  119. # print help information and exit:
  120. print('classification_cnn_keras_svd.py --output xxxxx')
  121. sys.exit(2)
  122. for o, a in opts:
  123. if o == "-h":
  124. print('classification_cnn_keras_svd.py --output xxxxx')
  125. sys.exit()
  126. elif o in ("-o", "--output"):
  127. filename = a
  128. elif o in ("-b", "--batch_size"):
  129. batch_size = int(a)
  130. elif o in ("-e", "--epochs"):
  131. epochs = int(a)
  132. elif o in ("-d", "--directory"):
  133. directory = a
  134. else:
  135. assert False, "unhandled option"
  136. # load of model
  137. model = generate_model()
  138. model.summary()
  139. n_folds = 10
  140. data_generator = ImageDataGenerator(rescale=1./255, validation_split=0.33)
  141. # check if possible to not do this thing each time
  142. train_generator = data_generator.flow_from_directory(train_data_dir, target_size=(img_width, img_height), shuffle=True, seed=13,
  143. class_mode='binary', batch_size=batch_size, subset="training")
  144. validation_generator = data_generator.flow_from_directory(train_data_dir, target_size=(img_width, img_height), shuffle=True, seed=13,
  145. class_mode='binary', batch_size=batch_size, subset="validation")
  146. # now run model
  147. history = train_and_evaluate_model(model, train_generator, validation_generator)
  148. print("directory %s " % directory)
  149. if(directory):
  150. print('Your model information will be saved into %s...' % directory)
  151. # if user needs output files
  152. if(filename):
  153. # update filename by folder
  154. if(directory):
  155. # create folder if necessary
  156. if not os.path.exists(directory):
  157. os.makedirs(directory)
  158. filename = directory + "/" + filename
  159. # save plot file history
  160. plot_info.save(history, filename)
  161. plot_model(model, to_file=str(('%s.png' % filename)))
  162. model.save_weights(str('%s.h5' % filename))
  163. if __name__ == "__main__":
  164. main()