|
@@ -28,7 +28,7 @@ import json
|
|
|
|
|
|
from keras.preprocessing.image import ImageDataGenerator
|
|
|
from keras.models import Sequential
|
|
|
-from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
|
|
|
+from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Cropping2D
|
|
|
from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization
|
|
|
from keras.optimizers import Adam
|
|
|
from keras.regularizers import l2
|
|
@@ -83,7 +83,7 @@ def init_directory(img_size, generate_data):
|
|
|
|
|
|
os.makedirs(str(validation_data_dir.replace('**img_size**', img_size_str) + '/final'))
|
|
|
os.makedirs(str(validation_data_dir.replace('**img_size**', img_size_str) + '/noisy'))
|
|
|
-
|
|
|
+
|
|
|
for f in Path('./data').walkfiles():
|
|
|
if 'png' in f:
|
|
|
img = Image.open(f)
|
|
@@ -101,30 +101,23 @@ def generate_model():
|
|
|
|
|
|
model = Sequential()
|
|
|
|
|
|
- model.add(Conv2D(50, (2, 2), input_shape=input_shape))
|
|
|
- model.add(Activation('relu'))
|
|
|
- model.add(MaxPooling2D(pool_size=(2, 2)))
|
|
|
+ model.add(Cropping2D(cropping=((20, 20), (20, 20)), input_shape=input_shape))
|
|
|
|
|
|
- model.add(Conv2D(30, (2, 2)))
|
|
|
+ model.add(Conv2D(50, (2, 2)))
|
|
|
model.add(Activation('relu'))
|
|
|
- model.add(MaxPooling2D(pool_size=(2, 2)))
|
|
|
+ model.add(AveragePooling2D(pool_size=(2, 2)))
|
|
|
|
|
|
model.add(Flatten())
|
|
|
|
|
|
model.add(Dense(100, kernel_regularizer=l2(0.01)))
|
|
|
- model.add(BatchNormalization())
|
|
|
model.add(Activation('relu'))
|
|
|
+ model.add(BatchNormalization())
|
|
|
model.add(Dropout(0.2))
|
|
|
|
|
|
model.add(Dense(100, kernel_regularizer=l2(0.01)))
|
|
|
- model.add(BatchNormalization())
|
|
|
model.add(Activation('relu'))
|
|
|
- model.add(Dropout(0.2))
|
|
|
-
|
|
|
- model.add(Dense(20, kernel_regularizer=l2(0.01)))
|
|
|
model.add(BatchNormalization())
|
|
|
- model.add(Activation('relu'))
|
|
|
- model.add(Dropout(0.1))
|
|
|
+ model.add(Dropout(0.5))
|
|
|
|
|
|
model.add(Dense(1))
|
|
|
model.add(Activation('sigmoid'))
|
|
@@ -183,12 +176,12 @@ def main():
|
|
|
|
|
|
# update global variable and not local
|
|
|
global batch_size
|
|
|
- global epochs
|
|
|
+ global epochs
|
|
|
global input_shape
|
|
|
global train_data_dir
|
|
|
global validation_data_dir
|
|
|
global nb_train_samples
|
|
|
- global nb_validation_samples
|
|
|
+ global nb_validation_samples
|
|
|
|
|
|
if len(sys.argv) <= 1:
|
|
|
print('Run with default parameters...')
|