'''This script goes along the blog post
"Building powerful image classification models using very little data"
from blog.keras.io.
```
data/
    train/
        final/
            final001.png
            final002.png
            ...
        noisy/
            noisy001.png
            noisy002.png
            ...
    validation/
        final/
            final001.png
            final002.png
            ...
        noisy/
            noisy001.png
            noisy002.png
            ...
```
'''
import sys, os, getopt
import json

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K
from keras.utils import plot_model

from modules.model_helper import plot_info

##########################################
# Global parameters (with default value) #
##########################################
img_width, img_height = 100, 100
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'
nb_train_samples = 7200
nb_validation_samples = 3600
epochs = 50
batch_size = 16
# default assumes 'channels_first'; main() re-derives this from the backend's
# image_data_format()
input_shape = (3, img_width, img_height)
###########################################


'''
Method which returns the model to train
@return : Sequential
'''
def generate_model():
    model = Sequential()

    model.add(Conv2D(60, (2, 2), input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(40, (2, 2)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(20, (2, 2)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(10, (2, 2)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Flatten())
    model.add(Dense(60))
    model.add(Activation('relu'))
    model.add(Dropout(0.4))
    model.add(Dense(30))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])

    return model
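
# For the default 100x100 RGB input the feature maps of the network above work
# out roughly as follows (2x2 convolutions with 'valid' padding, 2x2 max
# pooling, shapes shown in channels_last order):
#   (100, 100, 3) -> conv -> (99, 99, 60) -> pool -> (49, 49, 60)
#                 -> conv -> (48, 48, 40) -> pool -> (24, 24, 40)
#                 -> conv -> (23, 23, 20) -> pool -> (11, 11, 20)
#                 -> conv -> (10, 10, 10) -> pool -> (5, 5, 10)
#                 -> flatten -> 250 -> Dense(60) -> Dense(30) -> Dense(1, sigmoid)
# model.summary() in main() prints the exact shapes for the chosen --img size.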

'''
Method which loads train data
@return : DirectoryIterator
'''
def load_train_data():
    # this is the augmentation configuration we will use for training
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode='binary')

    return train_generator

'''
Method which loads validation data
@return : DirectoryIterator
'''
def load_validation_data():
    # this is the augmentation configuration we will use for testing:
    # only rescaling
    test_datagen = ImageDataGenerator(rescale=1. / 255)

    validation_generator = test_datagen.flow_from_directory(
        validation_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode='binary')

    return validation_generator
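
# Note: flow_from_directory assigns class indices to the sub-folders in
# alphabetical order, so with the layout documented above 'final' maps to 0
# and 'noisy' maps to 1 (binary labels for the sigmoid output).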

def main():
    # update the global variables, not local copies
    global batch_size
    global epochs
    global img_width
    global img_height
    global input_shape
    global train_data_dir
    global validation_data_dir
    global nb_train_samples
    global nb_validation_samples

    filename = None
    directory = None
    usage = 'classification_cnn_keras.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx'

    if len(sys.argv) <= 1:
        print('No arguments supplied, expected usage:')
        print(usage)
        sys.exit(2)

    try:
        opts, args = getopt.getopt(sys.argv[1:], "ho:d:b:e:i:",
                                   ["help", "output=", "directory=", "batch_size=", "epochs=", "img="])
    except getopt.GetoptError:
        # print help information and exit
        print(usage)
        sys.exit(2)

    for o, a in opts:
        if o in ("-h", "--help"):
            print(usage)
            sys.exit()
        elif o in ("-o", "--output"):
            filename = a
        elif o in ("-b", "--batch_size"):
            batch_size = int(a)
        elif o in ("-e", "--epochs"):
            epochs = int(a)
        elif o in ("-d", "--directory"):
            directory = a
        elif o in ("-i", "--img"):
            img_height = int(a)
            img_width = int(a)
        else:
            assert False, "unhandled option"
    # 3 because we have 3 color channels
    if K.image_data_format() == 'channels_first':
        input_shape = (3, img_width, img_height)
    else:
        input_shape = (img_width, img_height, 3)

    # configuration
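    # config.json is expected to provide the data directories and, for each
    # supported image size, the number of available samples. The sketch below
    # is illustrative only (paths and counts are assumptions, not taken from
    # the repository):
    #
    # {
    #     "train_data_dir": "data/train",
    #     "train_validation_dir": "data/validation",
    #     "100": {"nb_train_samples": 7200, "nb_validation_samples": 3600}
    # }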
    with open('config.json') as json_data:
        d = json.load(json_data)
        train_data_dir = d['train_data_dir']
        validation_data_dir = d['train_validation_dir']

        try:
            nb_train_samples = d[str(img_width)]['nb_train_samples']
            nb_validation_samples = d[str(img_width)]['nb_validation_samples']
        except KeyError:
            print("--img parameter missing or invalid (--img xx)")
            sys.exit(2)

    # load the model
    model = generate_model()
    model.summary()

    if directory:
        print('Your model information will be saved into %s...' % directory)

    history = model.fit_generator(
        load_train_data(),
        steps_per_epoch=nb_train_samples // batch_size,
        epochs=epochs,
        validation_data=load_validation_data(),
        validation_steps=nb_validation_samples // batch_size)

    # if the user asked for output files
    if filename:
        # prefix the filename with the output folder
        if directory:
            # create the folder if necessary
            if not os.path.exists(directory):
                os.makedirs(directory)
            filename = directory + "/" + filename

        # save the training history plot and the model weights
        plot_info.save(history, filename)
        plot_model(model, to_file='%s.png' % filename)
        model.save_weights('%s.h5' % filename)


if __name__ == "__main__":
    main()
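
# Example invocation (values are illustrative only; config.json must contain
# an entry for the chosen --img size):
#   python classification_cnn_keras.py --directory saved_models --output cnn_100 \
#       --batch_size 16 --epochs 50 --img 100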