|
@@ -24,6 +24,7 @@ data/
|
|
|
```
|
|
|
'''
|
|
|
import sys, os, getopt
|
|
|
+import json
|
|
|
|
|
|
from keras.preprocessing.image import ImageDataGenerator
|
|
|
from keras.models import Sequential
|
|
@@ -35,7 +36,9 @@ from keras.utils import plot_model
|
|
|
from modules.model_helper import plot_info
|
|
|
|
|
|
|
|
|
-# dimensions of our images.
|
|
|
+##########################################
|
|
|
+# Global parameters (with default value) #
|
|
|
+##########################################
|
|
|
img_width, img_height = 100, 100
|
|
|
|
|
|
train_data_dir = 'data/train'
|
|
@@ -43,12 +46,11 @@ validation_data_dir = 'data/validation'
|
|
|
nb_train_samples = 7200
|
|
|
nb_validation_samples = 3600
|
|
|
epochs = 50
|
|
|
-batch_size = 30
|
|
|
+batch_size = 16
|
|
|
|
|
|
-if K.image_data_format() == 'channels_first':
|
|
|
- input_shape = (3, img_width, img_height)
|
|
|
-else:
|
|
|
- input_shape = (img_width, img_height, 3)
|
|
|
+input_shape = (3, img_width, img_height)
|
|
|
+
|
|
|
+###########################################
|
|
|
|
|
|
'''
|
|
|
Method which returns model to train
|
|
@@ -131,22 +133,30 @@ def load_validation_data():
|
|
|
|
|
|
def main():
|
|
|
|
|
|
+ # update global variable and not local
|
|
|
global batch_size
|
|
|
- global epochs
|
|
|
+ global epochs
|
|
|
+ global img_width
|
|
|
+ global img_height
|
|
|
+ global input_shape
|
|
|
+ global train_data_dir
|
|
|
+ global validation_data_dir
|
|
|
+ global nb_train_samples
|
|
|
+ global nb_validation_samples
|
|
|
|
|
|
if len(sys.argv) <= 1:
|
|
|
- print('No output file defined...')
|
|
|
- print('classification_cnn_keras_svd.py --output xxxxx')
|
|
|
+ print('Run with default parameters...')
|
|
|
+ print('classification_cnn_keras_svd.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx')
|
|
|
sys.exit(2)
|
|
|
try:
|
|
|
- opts, args = getopt.getopt(sys.argv[1:], "ho:b:e:d", ["help", "directory=", "output=", "batch_size=", "epochs="])
|
|
|
+ opts, args = getopt.getopt(sys.argv[1:], "ho:d:b:e:i", ["help", "output=", "directory=", "batch_size=", "epochs=", "img="])
|
|
|
except getopt.GetoptError:
|
|
|
# print help information and exit:
|
|
|
- print('classification_cnn_keras_svd.py --output xxxxx')
|
|
|
+ print('classification_cnn_keras_svd.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx')
|
|
|
sys.exit(2)
|
|
|
for o, a in opts:
|
|
|
if o == "-h":
|
|
|
- print('classification_cnn_keras_svd.py --output xxxxx')
|
|
|
+ print('classification_cnn_keras_svd.py --directory xxxx --output xxxxx --batch_size xx --epochs xx --img xx')
|
|
|
sys.exit()
|
|
|
elif o in ("-o", "--output"):
|
|
|
filename = a
|
|
@@ -156,15 +166,36 @@ def main():
|
|
|
epochs = int(a)
|
|
|
elif o in ("-d", "--directory"):
|
|
|
directory = a
|
|
|
+ elif o in ("-i", "--img"):
|
|
|
+ img_height = int(a)
|
|
|
+ img_width = int(a)
|
|
|
else:
|
|
|
assert False, "unhandled option"
|
|
|
|
|
|
-
|
|
|
+ # 3 because we have 3 color canals
|
|
|
+ if K.image_data_format() == 'channels_first':
|
|
|
+ input_shape = (3, img_width, img_height)
|
|
|
+ else:
|
|
|
+ input_shape = (img_width, img_height, 3)
|
|
|
+
|
|
|
+ # configuration
|
|
|
+ with open('config.json') as json_data:
|
|
|
+ d = json.load(json_data)
|
|
|
+ train_data_dir = d['train_data_dir']
|
|
|
+ validation_data_dir = d['train_validation_dir']
|
|
|
+
|
|
|
+ try:
|
|
|
+ nb_train_samples = d[str(img_width)]['nb_train_samples']
|
|
|
+ nb_validation_samples = d[str(img_width)]['nb_validation_samples']
|
|
|
+ except:
|
|
|
+ print("--img parameter missing of invalid (--image_width xx --img_height xx)")
|
|
|
+ sys.exit(2)
|
|
|
+
|
|
|
# load of model
|
|
|
model = generate_model()
|
|
|
model.summary()
|
|
|
|
|
|
- if(directory):
|
|
|
+ if 'directory' in locals():
|
|
|
print('Your model information will be saved into %s...' % directory)
|
|
|
|
|
|
history = model.fit_generator(
|