Merge branch 'release/v0.4.6'

Jérôme BUISINE, 3 years ago
Commit 19b95830e0
1 file changed with 37 additions and 21 deletions

train_model.py  +37 -21

@@ -31,9 +31,8 @@ def main():
     parser.add_argument('--data', type=str, help='dataset filename prefix (without .train and .val)', required=True)
     parser.add_argument('--output', type=str, help='output file name desired for model (without .json extension)', required=True)
     parser.add_argument('--tl', type=int, help='use or not of transfer learning (`VGG network`)', default=0, choices=[0, 1])
-    parser.add_argument('--batch_size', type=int, help='batch size used as model input', default=cfg.keras_batch)
-    parser.add_argument('--epochs', type=int, help='number of epochs used for training model', default=cfg.keras_epochs)
-    parser.add_argument('--balancing', type=int, help='specify if balacing of classes is done or not', default="1")
+    parser.add_argument('--batch_size', type=int, help='batch size used as model input', default=64)
+    parser.add_argument('--epochs', type=int, help='number of epochs used for training model', default=30)
     parser.add_argument('--chanels', type=int, help="given number of chanels if necessary", default=0)
     parser.add_argument('--size', type=str, help="Size of input images", default="100, 100")
     parser.add_argument('--val_size', type=float, help='percent of validation data during training process', default=0.3)
@@ -46,7 +45,6 @@ def main():
     p_tl          = args.tl
     p_batch_size  = args.batch_size
     p_epochs      = args.epochs
-    p_balancing   = bool(args.balancing)
     p_chanels     = args.chanels
     p_size        = args.size.split(',')
     p_val_size    = args.val_size
@@ -92,23 +90,35 @@ def main():
         else:
             input_shape = (img_width, img_height, n_chanels)
 
-    # get dataset with equal number of classes occurences if wished
-    if p_balancing:
-        print("Balancing of data")
-        noisy_df_train = dataset_train[dataset_train.iloc[:, 0] == 1]
-        not_noisy_df_train = dataset_train[dataset_train.iloc[:, 0] == 0]
-        nb_noisy_train = len(noisy_df_train.index)
+    # compute class weights over the whole dataset
+    noisy_df_train = dataset_train[dataset_train.iloc[:, 0] == 1]
+    not_noisy_df_train = dataset_train[dataset_train.iloc[:, 0] == 0]
+    nb_noisy_train = len(noisy_df_train.index)
+    nb_not_noisy_train = len(not_noisy_df_train.index)
 
-        noisy_df_val = dataset_test[dataset_test.iloc[:, 0] == 1]
-        not_noisy_df_val = dataset_test[dataset_test.iloc[:, 0] == 0]
-        nb_noisy_val = len(noisy_df_val.index)
+    noisy_df_test = dataset_test[dataset_test.iloc[:, 0] == 1]
+    not_noisy_df_test = dataset_test[dataset_test.iloc[:, 0] == 0]
+    nb_noisy_test = len(noisy_df_test.index)
+    nb_not_noisy_test = len(not_noisy_df_test.index)
 
-        final_df_train = pd.concat([not_noisy_df_train[0:nb_noisy_train], noisy_df_train])
-        final_df_val = pd.concat([not_noisy_df_val[0:nb_noisy_val], noisy_df_val])
-    else:
-        print("No balancing of data")
-        final_df_train = dataset_train
-        final_df_test = dataset_test
+    noisy_samples = nb_noisy_test + nb_noisy_train
+    not_noisy_samples = nb_not_noisy_test + nb_not_noisy_train
+
+    total_samples = noisy_samples + not_noisy_samples
+
+    print('noisy', noisy_samples)
+    print('not_noisy', not_noisy_samples)
+    print('total', total_samples)
+
+    class_weight = {
+        0: (noisy_samples / float(total_samples)),
+        1: (not_noisy_samples / float(total_samples)),
+    }
+
+
+
+    final_df_train = dataset_train
+    final_df_test = dataset_test
 
     # check if specific number of chanels is used
     if p_chanels == 0:
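
For comparison with the cross-proportion weights built in the hunk above, the same idea is often expressed with scikit-learn's "balanced" strategy. The sketch below is not part of this commit; it assumes, as in the code above, that the first dataframe column holds the 0/1 labels. Balanced weights are inversely proportional to class frequency, so they are close in spirit to, but not numerically identical with, the class_weight dictionary computed here.

    # Sketch (not part of this commit): alternative class weights via scikit-learn.
    import numpy as np
    from sklearn.utils.class_weight import compute_class_weight

    labels = np.concatenate([dataset_train.iloc[:, 0].values,
                             dataset_test.iloc[:, 0].values])
    weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=labels)
    class_weight_balanced = {0: weights[0], 1: weights[1]}
    print('balanced class_weight', class_weight_balanced)
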
@@ -221,8 +231,14 @@ def main():
     y_val = to_categorical(y_val)
     y_test = to_categorical(y_dataset_test)
 
-    # validation split parameter will use the last `%` data, so here, data will really validate our model
-    model.fit(X_train, y_train, validation_data=(X_val, y_val), initial_epoch=initial_epoch, epochs=p_epochs, batch_size=p_batch_size, callbacks=callbacks_list)
+    print("Fitting model with custom class_weight", class_weight)
+    model.fit(X_train, y_train, 
+        validation_data=(X_val, y_val), 
+        initial_epoch=initial_epoch, 
+        epochs=p_epochs, 
+        batch_size=p_batch_size, 
+        callbacks=callbacks_list, 
+        class_weight=class_weight)
 
     score = model.evaluate(X_val, y_val, batch_size=p_batch_size)
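
A note on the new class_weight argument: Keras only uses it to scale each sample's contribution to the training loss, so the model.evaluate call above is unaffected. A minimal sketch of the roughly equivalent per-sample formulation, assuming a hypothetical integer-label array y_train_labels (the label values before to_categorical was applied):

    # Sketch (not part of this commit): class_weight expressed as per-sample weights.
    # y_train_labels is a hypothetical 0/1 label array, not a variable from the script.
    import numpy as np

    sample_weight = np.array([class_weight[int(label)] for label in y_train_labels])
    # model.fit(X_train, y_train,
    #           validation_data=(X_val, y_val),
    #           epochs=p_epochs, batch_size=p_batch_size,
    #           sample_weight=sample_weight,
    #           callbacks=callbacks_list)
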