'''This script goes along the blog post
"Building powerful image classification models using very little data"
from blog.keras.io.

```
data/
    train/
        final/
            final001.png
            final002.png
            ...
        noisy/
            noisy001.png
            noisy002.png
            ...
    validation/
        final/
            final001.png
            final002.png
            ...
        noisy/
            noisy001.png
            noisy002.png
            ...
```
'''
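# Example invocation (argument values are illustrative, not prescriptive):
#   python classification_cnn_keras_svd_img.py --directory models --output svd_run \
#          --batch_size 16 --epochs 50 --img 100 --generate y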
import sys, os, getopt
import json

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.utils import plot_model

import tensorflow as tf
import numpy as np

from modules.model_helper import plot_info
from modules.image_metrics import svd_metric

import matplotlib.pyplot as plt

# preprocessing of images
from path import Path
from PIL import Image
import shutil
import time
##########################################
# Global parameters (with default value) #
##########################################
img_width, img_height = 100, 100

train_data_dir = 'data_svd_**img_size**/train'
validation_data_dir = 'data_svd_**img_size**/validation'
nb_train_samples = 7200
nb_validation_samples = 3600
epochs = 50
batch_size = 16
input_shape = (3, img_width, img_height)
##########################################
def init_directory(img_size, generate_data):
    img_size_str = str(img_size)
    svd_data_folder = str('data_svd_' + img_size_str)

    # remove all previously generated data if regeneration is requested
    if os.path.exists(svd_data_folder) and 'y' in generate_data:
        print("Removing all previous data...")
        shutil.rmtree(svd_data_folder)

    if not os.path.exists(svd_data_folder):
        print("Creating new data... Just take coffee... Or two...")
        os.makedirs(str(train_data_dir.replace('**img_size**', img_size_str) + '/final'))
        os.makedirs(str(train_data_dir.replace('**img_size**', img_size_str) + '/noisy'))
        os.makedirs(str(validation_data_dir.replace('**img_size**', img_size_str) + '/final'))
        os.makedirs(str(validation_data_dir.replace('**img_size**', img_size_str) + '/noisy'))

        for f in Path('./data').walkfiles():
            if 'png' in f:
                img = Image.open(f)
                new_img = svd_metric.get_s_model_data_img(img)
                new_img_path = f.replace('./data', str('./' + svd_data_folder))
                new_img.save(new_img_path)
                print(new_img_path)
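# Note: the SVD transform (svd_metric.get_s_model_data_img) is applied once here,
# offline, rather than per batch through the generators' preprocessing_function
# (which is left commented out in the data loaders below).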
'''
Method which builds the model to train
@return : Sequential model
'''
def generate_model():
    model = Sequential()

    model.add(Conv2D(100, (2, 2), input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(60, (2, 2)))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(40, (2, 2)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(30, (2, 2)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Flatten())

    model.add(Dense(150, kernel_regularizer=l2(0.01)))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(0.2))

    model.add(Dense(120, kernel_regularizer=l2(0.01)))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(0.2))

    model.add(Dense(80, kernel_regularizer=l2(0.01)))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(0.2))

    model.add(Dense(40, kernel_regularizer=l2(0.01)))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(0.2))

    model.add(Dense(20, kernel_regularizer=l2(0.01)))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(0.1))

    model.add(Dense(1))
    model.add(Activation('sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])

    return model
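# Rough shape walkthrough for the default 100x100 input (valid padding: each
# 2x2 convolution trims one row/column, each 2x2 max-pooling halves the size):
#   100x100 -> conv -> 99x99 -> pool -> 49x49 -> conv -> 48x48 -> pool -> 24x24
#   -> conv -> 23x23 -> pool -> 11x11 -> conv -> 10x10 -> pool -> 5x5
#   -> flatten -> 5 * 5 * 30 = 750 features feeding the dense stack.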
'''
Method which loads the training data
@return : DirectoryIterator
'''
def load_train_data():
    # this is the augmentation configuration we will use for training
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        #shear_range=0.2,
        #zoom_range=0.2,
        #horizontal_flip=True,
        #preprocessing_function=svd_metric.get_s_model_data_img
    )

    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode='binary')

    return train_generator
'''
Method which loads the validation data
@return : DirectoryIterator
'''
def load_validation_data():
    # this is the augmentation configuration we will use for testing:
    # only rescaling
    test_datagen = ImageDataGenerator(
        rescale=1. / 255,
        #preprocessing_function=svd_metric.get_s_model_data_img
    )

    validation_generator = test_datagen.flow_from_directory(
        validation_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode='binary')

    return validation_generator
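# Each generator yields (images, labels) batches; with the defaults above and
# channels_last data format (an illustrative check, not part of the pipeline):
#   x_batch, y_batch = next(load_train_data())
#   # x_batch.shape == (16, 100, 100, 3), y_batch.shape == (16,)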
def main():
    # update the global variables and not local ones
    global batch_size
    global epochs
    global input_shape
    global train_data_dir
    global validation_data_dir
    global nb_train_samples
    global nb_validation_samples

    # defaults when optional flags are omitted
    filename = None
    directory = None
    image_size = img_width
    generate_data = 'n'

    if len(sys.argv) <= 1:
        print('Missing parameters...')
        print('classification_cnn_keras_svd_img.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx --generate (y/n)')
        sys.exit(2)

    try:
        opts, args = getopt.getopt(sys.argv[1:], "ho:d:b:e:i:g:", ["help", "output=", "directory=", "batch_size=", "epochs=", "img=", "generate="])
    except getopt.GetoptError:
        # print help information and exit
        print('classification_cnn_keras_svd_img.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx --generate (y/n)')
        sys.exit(2)

    for o, a in opts:
        if o in ("-h", "--help"):
            print('classification_cnn_keras_svd_img.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx --generate (y/n)')
            sys.exit()
        elif o in ("-o", "--output"):
            filename = a
        elif o in ("-b", "--batch_size"):
            batch_size = int(a)
        elif o in ("-e", "--epochs"):
            epochs = int(a)
        elif o in ("-d", "--directory"):
            directory = a
        elif o in ("-i", "--img"):
            image_size = int(a)
        elif o in ("-g", "--generate"):
            generate_data = a
        else:
            assert False, "unhandled option"

    # 3 because we have 3 color channels
    if K.image_data_format() == 'channels_first':
        input_shape = (3, img_width, img_height)
    else:
        input_shape = (img_width, img_height, 3)

    img_str_size = str(image_size)
    train_data_dir = str(train_data_dir.replace('**img_size**', img_str_size))
    validation_data_dir = str(validation_data_dir.replace('**img_size**', img_str_size))

    # configuration
    with open('config.json') as json_data:
        d = json.load(json_data)

        try:
            nb_train_samples = d[str(image_size)]['nb_train_samples']
            nb_validation_samples = d[str(image_size)]['nb_validation_samples']
        except KeyError:
            print("--img parameter missing or invalid (no entry for this image size in config.json)")
            sys.exit(2)

    init_directory(image_size, generate_data)

    # load of model
    model = generate_model()
    model.summary()

    if directory:
        print('Your model information will be saved into %s...' % directory)

    history = model.fit_generator(
        load_train_data(),
        steps_per_epoch=nb_train_samples // batch_size,
        epochs=epochs,
        validation_data=load_validation_data(),
        validation_steps=nb_validation_samples // batch_size)

    # if user needs output files
    if filename:
        # update filename by folder
        if directory:
            # create folder if necessary
            if not os.path.exists(directory):
                os.makedirs(directory)
            filename = directory + "/" + filename

        fig_size = plt.rcParams["figure.figsize"]
        fig_size[0] = 9
        fig_size[1] = 9
        plt.rcParams["figure.figsize"] = fig_size

        # save plot file of the training history
        plot_info.save(history, filename)

        plot_model(model, to_file=str('%s.png' % filename), show_shapes=True)
        model.save_weights(str('%s.h5' % filename))


if __name__ == "__main__":
    main()