
Add dynamic retraining parameter

Jérôme BUISINE, 3 years ago
Parent
commit ba4c000358
2 files changed with 8 additions and 4 deletions
  1. find_best_attributes_surrogate_openML.py (+3 -3)
  2. optimization/ILSSurrogate.py (+5 -1)

+ 3 - 3
find_best_attributes_surrogate_openML.py

@@ -111,7 +111,7 @@ def main():
     parser = argparse.ArgumentParser(description="Train and find best filters to use for model")
 
     parser.add_argument('--data', type=str, help='open ml dataset filename prefix', required=True)
-    #parser.add_argument('--start_surrogate', type=int, help='number of evalution before starting surrogare model', default=100)
+    parser.add_argument('--every_ls', type=int, help='retrain the surrogate model every N local search iterations', default=50)
     parser.add_argument('--ils', type=int, help='number of total iteration for ils algorithm', required=True)
     parser.add_argument('--ls', type=int, help='number of iteration for Local Search algorithm', required=True)
     parser.add_argument('--output', type=str, help='output surrogate model name')
@@ -119,7 +119,7 @@ def main():
     args = parser.parse_args()
 
     p_data_file = args.data
-    #p_start     = args.start_surrogate
+    p_every_ls     = args.every_ls
     p_ils_iteration = args.ils
     p_ls_iteration  = args.ls
     p_output = args.output
@@ -216,7 +216,7 @@ def main():
                         _surrogate_file_path=surrogate_output_model,
                         _start_train_surrogate=p_start, # start learning and using surrogate after 1000 real evaluation
                         _solutions_file=surrogate_output_data,
-                        _ls_train_surrogate=1,
+                        _ls_train_surrogate=p_every_ls, # retrain surrogate every p_every_ls local search iterations
                         _maximise=True)
     
     algo.addCallback(BasicCheckpoint(_every=1, _filepath=backup_file_path))
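
With this change the surrogate retraining interval is exposed on the command line instead of being hard-coded to 1. A hypothetical invocation of the updated script (the dataset prefix, iteration counts and output name below are placeholder values, not taken from the repository):

    python find_best_attributes_surrogate_openML.py --data <dataset_prefix> --every_ls 50 --ils 100 --ls 10 --output <surrogate_model_name>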

+ 5 - 1
optimization/ILSSurrogate.py

@@ -83,6 +83,7 @@ class ILSSurrogate(Algorithm):
         surrogate = WalshSurrogate(order=2, size=problem.size, model=model)
         analysis = FitterAnalysis(logfile="train_surrogate.log", problem=problem)
         algo = FitterAlgo(problem=problem, surrogate=surrogate, analysis=analysis, seed=problem.seed)
+        self.analysis = analysis # keep a reference so the surrogate's r^2 can be queried later
 
         # dynamic number of samples based on dataset real evaluations
         nsamples = None
@@ -225,8 +226,11 @@ class ILSSurrogate(Algorithm):
 
                 self.progress()
 
+            # dynamic retraining criterion based on the surrogate's current r^2 (kept >= 1 to avoid a modulo by zero)
+            training_surrogate_every = max(1, int(self.analysis.coefficient_of_determination(self.surrogate) * self.ls_train_surrogate))
+
             # check if necessary or not to train again surrogate
-            if self.n_local_search % self.ls_train_surrogate == 0 and self.start_train_surrogate <= self.getGlobalEvaluation():
+            if self.n_local_search % training_surrogate_every == 0 and self.start_train_surrogate <= self.getGlobalEvaluation():
 
                 # train again surrogate on real evaluated solutions file
                 self.train_surrogate()
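
The net effect is that the retraining interval is no longer fixed: the higher the surrogate's current r^2 against the real evaluations, the closer the interval gets to the configured --every_ls value, while a poorly fitting surrogate is retrained more often. A minimal standalone sketch of that decision rule, using hypothetical names (r_squared, every_ls and n_local_search stand in for the values held by the ILSSurrogate instance) and adding a max(1, ...) guard so the interval can never reach zero:

    # sketch of the dynamic retraining criterion (hypothetical standalone version)
    def should_retrain(r_squared: float, every_ls: int, n_local_search: int) -> bool:
        # scale the configured interval by the surrogate quality, never letting it drop below 1
        interval = max(1, int(r_squared * every_ls))
        # retrain whenever the number of local searches is a multiple of the interval
        return n_local_search % interval == 0

    # with r_squared = 0.8 and every_ls = 50 the surrogate is retrained every 40 local searches
    print(should_retrain(0.8, 50, 40))   # True
    print(should_retrain(0.8, 50, 41))   # False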