Parcourir la source

reinit ls fixed

Jérôme BUISINE il y a 3 ans
Parent
commit
eec7fb0f8c
1 fichiers modifiés avec 7 ajouts et 2 suppressions
  1. 7 2
      optimization/ILSPopSurrogate.py

+ 7 - 2
optimization/ILSPopSurrogate.py

@@ -79,6 +79,7 @@ class ILSPopSurrogate(Algorithm):
                 validator, maximise, parent)
 
         self._n_local_search = 0
+        self._ls_local_search = 0
         self._main_evaluator = evaluator
 
         self._surrogate_file_path = surrogate_file_path
@@ -321,13 +322,13 @@ class ILSPopSurrogate(Algorithm):
                 training_surrogate_every = int(r_squared * self._ls_train_surrogate)
                 print(f"=> R² of surrogate is of {r_squared}.")
                 print(f"=> MAE of surrogate is of {mae}.")
-                print(f'=> Retraining model every {training_surrogate_every} LS ({self._n_local_search % training_surrogate_every} of {training_surrogate_every})')
+                print(f'=> Retraining model every {training_surrogate_every} LS ({self._ls_local_search % training_surrogate_every} of {training_surrogate_every})')
                 # avoid issue when lauching every each local search
                 if training_surrogate_every <= 0:
                     training_surrogate_every = 1
 
                 # check if necessary or not to train again surrogate
-                if self._n_local_search % training_surrogate_every == 0 and self._start_train_surrogate <= self.getGlobalEvaluation():
+                if self._ls_local_search % training_surrogate_every == 0 and self._start_train_surrogate <= self.getGlobalEvaluation():
 
                     # train again surrogate on real evaluated solutions file
                     start_training = time.time()
@@ -339,8 +340,12 @@ class ILSPopSurrogate(Algorithm):
                     # reload new surrogate function
                     self.load_surrogate()
 
+                    # reinit ls search
+                    self._ls_local_search = 0
+
                 # increase number of local search done
                 self._n_local_search += 1
+                self._ls_local_search += 1
 
                 self.information()