
update use of custom every ls param

Jérôme BUISINE · 3 years ago
commit a94b9b8fca
2 files changed with 7 additions and 1 deletion
  1. find_best_attributes_surrogate_dl.py (+3 −1)
  2. optimization/ILSSurrogate.py (+4 −0)

+ 3 - 1
find_best_attributes_surrogate_dl.py

@@ -168,6 +168,7 @@ def main():
     parser.add_argument('--length', type=int, help='max data length (need to be specify for evaluator)', required=True)
     parser.add_argument('--ils', type=int, help='number of total iteration for ils algorithm', required=True)
     parser.add_argument('--ls', type=int, help='number of iteration for Local Search algorithm', required=True)
+    parser.add_argument('--every_ls', type=int, help='max number of local search iterations between surrogate model retrainings', required=True)
     parser.add_argument('--output', type=str, help='output surrogate model name')
 
     args = parser.parse_args()
@@ -177,6 +178,7 @@ def main():
     p_start     = args.start_surrogate
     p_ils_iteration = args.ils
     p_ls_iteration  = args.ls
+    p_every_ls      = args.every_ls
     p_output = args.output
 
     print(p_data_file)
@@ -279,7 +281,7 @@ def main():
                         _surrogate_file_path=surrogate_output_model,
                         _start_train_surrogate=p_start, # start learning and using surrogate after 1000 real evaluation
                         _solutions_file=surrogate_output_data,
-                        _ls_train_surrogate=1,
+                        _ls_train_surrogate=p_every_ls,
                         _maximise=True)
     
     algo.addCallback(BasicCheckpoint(_every=1, _filepath=backup_file_path))
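
The new --every_ls value is passed to the ILS as _ls_train_surrogate. A hypothetical invocation showing the flag (values are illustrative; only options visible in this diff are listed, and additional required arguments such as the input data file are omitted):

    python find_best_attributes_surrogate_dl.py \
        --length 30 \
        --ils 5000 \
        --ls 10 \
        --every_ls 5 \
        --output surrogate_model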

+ 4 - 0
optimization/ILSSurrogate.py

@@ -230,6 +230,10 @@ class ILSSurrogate(Algorithm):
             training_surrogate_every = int(r_squared * self.ls_train_surrogate)
             print(f"=> R^2 of surrogate is of {r_squared}. Retraining model every {training_surrogate_every} LS")
 
+            # guard against a zero interval when R^2 is low (fall back to retraining after every local search)
+            if training_surrogate_every <= 0:
+                training_surrogate_every = 1
+
             # check if necessary or not to train again surrogate
             if self.n_local_search % training_surrogate_every == 0 and self.start_train_surrogate <= self.getGlobalEvaluation():
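
Taken together, the two changes make the surrogate retraining frequency adaptive: the base interval comes from --every_ls (stored as ls_train_surrogate) and is scaled by the surrogate's current R², while the new guard keeps the interval at least 1. A minimal standalone sketch of that logic (the function name and example values are hypothetical, and the start_train_surrogate condition from the last context line above is omitted):

    def should_retrain(n_local_search, r_squared, ls_train_surrogate):
        # scale the user-provided interval (--every_ls) by the surrogate's fit quality:
        # a poorly fitting surrogate (low R^2) is retrained more often
        training_surrogate_every = int(r_squared * ls_train_surrogate)

        # guard added by this commit: keep the interval at least 1 so the
        # modulo below never divides by zero
        if training_surrogate_every <= 0:
            training_surrogate_every = 1

        return n_local_search % training_surrogate_every == 0

    # with --every_ls 5 and R^2 = 0.42, retraining happens every int(2.1) = 2 local searches
    assert should_retrain(4, 0.42, 5)
    assert not should_retrain(3, 0.42, 5)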