classification_cnn_keras_cross_validation.py

'''This script goes along the blog post
"Building powerful image classification models using very little data"
from blog.keras.io.

Expected data layout:
```
data/
    train/
        final/
            final001.png
            final002.png
            ...
        noisy/
            noisy001.png
            noisy002.png
            ...
    validation/
        final/
            final001.png
            final002.png
            ...
        noisy/
            noisy001.png
            noisy002.png
            ...
```
'''
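# Example invocation, matching the options parsed in main() below
# (placeholder values; adjust the paths and sizes to your own dataset):
#   python classification_cnn_keras_cross_validation.py --directory saved_models \
#       --output cnn_model --batch_size 16 --epochs 50 --img 100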
import sys, os, getopt
import json
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
from keras import backend as K
from keras.utils import plot_model
# local functions import (metrics preprocessing)
import preprocessing_functions
##########################################
# Global parameters (with default value) #
##########################################
img_width, img_height = 100, 100
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'
nb_train_samples = 7200
nb_validation_samples = 3600
epochs = 50
batch_size = 16
# channels-first default; recomputed in main() from K.image_data_format()
input_shape = (3, img_width, img_height)
###########################################


'''
Method which builds the model to train
@return : compiled Keras Sequential model
'''
def generate_model():
    # create the model: stacked Conv2D/MaxPooling2D blocks for feature extraction
    model = Sequential()
    model.add(Conv2D(60, (2, 2), input_shape=input_shape, dilation_rate=1))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(40, (2, 2), dilation_rate=1))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(20, (2, 2), dilation_rate=1))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # dense classifier head: layers of decreasing width, each followed by
    # batch normalization and dropout
    model.add(Flatten())
    model.add(Dense(140))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    model.add(Dense(120))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))
    model.add(Dense(80))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    model.add(Dense(40))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    model.add(Dense(20))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))
    # single sigmoid unit for binary (final vs. noisy) classification
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model


'''
Method which builds the augmented training data generator
@return : DirectoryIterator
'''
def load_data():
    # this is the augmentation configuration we will use for training
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode='binary')
    return train_generator
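
# Note: main() below builds its own train/validation generators with
# ImageDataGenerator(validation_split=...), so this stand-alone loader is
# only needed with the separate data/train layout shown in the module
# docstring, e.g. (illustrative):
#   model = generate_model()
#   model.fit_generator(load_data(),
#                       steps_per_epoch=nb_train_samples // batch_size,
#                       epochs=epochs)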


'''
Method which trains the model and validates it after each epoch
@return : Keras History object
'''
def train_and_evaluate_model(model, data_train, data_test):
    return model.fit_generator(
        data_train,
        steps_per_epoch=nb_train_samples // batch_size,
        epochs=epochs,
        shuffle=True,
        validation_data=data_test,
        validation_steps=nb_validation_samples // batch_size)
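
# A minimal sketch for plotting the returned History (assumes matplotlib is
# installed; depending on the Keras version the metric keys are
# 'acc'/'val_acc' or 'accuracy'/'val_accuracy'):
#   import matplotlib.pyplot as plt
#   history = train_and_evaluate_model(model, train_gen, val_gen)
#   plt.plot(history.history['acc'], label='train')
#   plt.plot(history.history['val_acc'], label='validation')
#   plt.legend()
#   plt.savefig('accuracy.png')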


def main():
    # update the global variables, not local copies
    global batch_size
    global epochs
    global img_width
    global img_height
    global input_shape
    global train_data_dir
    global validation_data_dir
    global nb_train_samples
    global nb_validation_samples
    # optional output locations, set by --output / --directory
    filename = None
    directory = None
    if len(sys.argv) <= 1:
        print('Missing parameters, expected usage:')
        print('classification_cnn_keras_cross_validation.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx')
        sys.exit(2)
    try:
        # note the trailing colon on "i:": --img takes a value
        opts, args = getopt.getopt(sys.argv[1:], "ho:d:b:e:i:", ["help", "output=", "directory=", "batch_size=", "epochs=", "img="])
    except getopt.GetoptError:
        # print help information and exit
        print('classification_cnn_keras_cross_validation.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx')
        sys.exit(2)
    for o, a in opts:
        if o in ("-h", "--help"):
            print('classification_cnn_keras_cross_validation.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx')
            sys.exit()
        elif o in ("-o", "--output"):
            filename = a
        elif o in ("-b", "--batch_size"):
            batch_size = int(a)
        elif o in ("-e", "--epochs"):
            epochs = int(a)
        elif o in ("-d", "--directory"):
            directory = a
        elif o in ("-i", "--img"):
            img_height = int(a)
            img_width = int(a)
        else:
            assert False, "unhandled option"
    # 3 because the images have 3 color channels
    if K.image_data_format() == 'channels_first':
        input_shape = (3, img_width, img_height)
    else:
        input_shape = (img_width, img_height, 3)
    # configuration
    with open('config.json') as json_data:
        d = json.load(json_data)
        train_data_dir = d['train_data_dir']
        validation_data_dir = d['train_validation_dir']
        try:
            nb_train_samples = d[str(img_width)]['nb_train_samples']
            nb_validation_samples = d[str(img_width)]['nb_validation_samples']
        except KeyError:
            print("--img parameter missing or invalid (--img xx)")
            sys.exit(2)
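    # Expected config.json layout (a sketch inferred from the keys read
    # above; the per-width sample counts are illustrative defaults):
    # {
    #     "train_data_dir": "data/train",
    #     "train_validation_dir": "data/validation",
    #     "100": {"nb_train_samples": 7200, "nb_validation_samples": 3600}
    # }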
    # load the model
    model = generate_model()
    model.summary()
    # split train_data_dir into training and validation subsets (33% held out)
    data_generator = ImageDataGenerator(rescale=1. / 255, validation_split=0.33)
    # TODO: check whether these generators can be built once and reused
    train_generator = data_generator.flow_from_directory(
        train_data_dir, target_size=(img_width, img_height), shuffle=True, seed=13,
        class_mode='binary', batch_size=batch_size, subset="training")
    validation_generator = data_generator.flow_from_directory(
        train_data_dir, target_size=(img_width, img_height), shuffle=True, seed=13,
        class_mode='binary', batch_size=batch_size, subset="validation")
    # now run the model
    history = train_and_evaluate_model(model, train_generator, validation_generator)
    print("directory %s " % directory)
    if directory:
        print('Your model information will be saved into %s...' % directory)
    # if the user asked for output files
    if filename:
        # prefix the output filename with the folder
        if directory:
            # create the folder if necessary
            if not os.path.exists(directory):
                os.makedirs(directory)
            filename = directory + "/" + filename
        # save the model plot and the trained weights
        # tf_model_helper.save(history, filename)
        plot_model(model, to_file='%s.png' % filename)
        model.save_weights('%s.h5' % filename)


if __name__ == "__main__":
    main()