Parcourir la source

SVD model updated

jbuisine il y a 5 ans
Parent
commit
6cdc2f0226
2 fichiers modifiés avec 11 ajouts et 5 suppressions
  1. 5 0
      classification_cnn_keras.py
  2. 6 5
      classification_cnn_keras_svd.py

+ 5 - 0
classification_cnn_keras.py

@@ -59,6 +59,7 @@ Method which returns model to train
 def generate_model():
 def generate_model():
 
 
     model = Sequential()
     model = Sequential()
+
     model.add(Conv2D(60, (2, 2), input_shape=input_shape))
     model.add(Conv2D(60, (2, 2), input_shape=input_shape))
     model.add(Activation('relu'))
     model.add(Activation('relu'))
     model.add(MaxPooling2D(pool_size=(2, 2)))
     model.add(MaxPooling2D(pool_size=(2, 2)))
@@ -71,6 +72,10 @@ def generate_model():
     model.add(Activation('relu'))
     model.add(Activation('relu'))
     model.add(MaxPooling2D(pool_size=(2, 2)))
     model.add(MaxPooling2D(pool_size=(2, 2)))
 
 
+    model.add(Conv2D(10, (2, 2)))
+    model.add(Activation('relu'))
+    model.add(MaxPooling2D(pool_size=(2, 2)))
+
     model.add(Flatten())
     model.add(Flatten())
     model.add(Dense(60))
     model.add(Dense(60))
     model.add(Activation('relu'))
     model.add(Activation('relu'))

+ 6 - 5
classification_cnn_keras_svd.py

@@ -78,6 +78,7 @@ def generate_model():
     model.add(Activation('relu'))
     model.add(Activation('relu'))
     model.add(MaxPooling2D(pool_size=(2, 1)))
     model.add(MaxPooling2D(pool_size=(2, 1)))
 
 
+    model.add(Flatten())
     model.add(Dense(70, kernel_regularizer=l2(0.01)))
     model.add(Dense(70, kernel_regularizer=l2(0.01)))
     model.add(BatchNormalization())
     model.add(BatchNormalization())
     model.add(Activation('relu'))
     model.add(Activation('relu'))
@@ -115,10 +116,10 @@ def load_train_data():
 
 
     # this is the augmentation configuration we will use for training
     # this is the augmentation configuration we will use for training
     train_datagen = ImageDataGenerator(
     train_datagen = ImageDataGenerator(
-        rescale=1. / 255,
-        shear_range=0.2,
-        zoom_range=0.2,
-        horizontal_flip=True,
+        #rescale=1. / 255,
+        #shear_range=0.2,
+        #zoom_range=0.2,
+        #horizontal_flip=True,
         preprocessing_function=svd_metric.get_s_model_data)
         preprocessing_function=svd_metric.get_s_model_data)
 
 
     train_generator = train_datagen.flow_from_directory(
     train_generator = train_datagen.flow_from_directory(
@@ -138,7 +139,7 @@ def load_validation_data():
     # this is the augmentation configuration we will use for testing:
     # this is the augmentation configuration we will use for testing:
     # only rescaling
     # only rescaling
     test_datagen = ImageDataGenerator(
     test_datagen = ImageDataGenerator(
-        rescale=1. / 255,
+        #rescale=1. / 255,
         preprocessing_function=svd_metric.get_s_model_data)
         preprocessing_function=svd_metric.get_s_model_data)
 
 
     validation_generator = test_datagen.flow_from_directory(
     validation_generator = test_datagen.flow_from_directory(