ソースを参照

add of mono surrogate performance checkpoint

Jérôme BUISINE 4 年 前
コミット
d495a8ce8e

+ 3 - 1
find_best_attributes_surrogate.py

@@ -41,7 +41,7 @@ from macop.policies.reinforcement import UCBPolicy
 from macop.callbacks.classicals import BasicCheckpoint
 from macop.callbacks.policies import UCBCheckpoint
 from optimization.callbacks.MultiPopCheckpoint import MultiPopCheckpoint
-
+from optimization.callbacks.SurrogateMonoCheckpoint import SurrogateMonoCheckpoint
 #from sklearn.ensemble import RandomForestClassifier
 
 # variables and parameters
@@ -204,6 +204,7 @@ def main():
 
     backup_file_path = os.path.join(backup_model_folder, p_output + '.csv')
     ucb_backup_file_path = os.path.join(backup_model_folder, p_output + '_ucbPolicy.csv')
+    surrogate_performanche_file_path = os.path.join(cfg.output_surrogates_data_folder, p_output + '_performance.csv')
 
     # prepare optimization algorithm (only use of mutation as only ILS are used here, and local search need only local permutation)
     operators = [SimpleBinaryMutation(), SimpleMutation(), RandomPopCrossover(), SimplePopCrossover()]
@@ -231,6 +232,7 @@ def main():
     
     algo.addCallback(MultiPopCheckpoint(every=1, filepath=backup_file_path))
     algo.addCallback(UCBCheckpoint(every=1, filepath=ucb_backup_file_path))
+    algo.addCallback(SurrogateMonoCheckpoint(every=1, filepath=surrogate_performanche_file_path))
 
     bestSol = algo.run(p_ils_iteration, p_ls_iteration)
 

+ 1 - 1
optimization/ILSPopSurrogate.py

@@ -295,7 +295,7 @@ class ILSPopSurrogate(Algorithm):
                     self.add_to_surrogate(newSolution)
 
                     self.progress()
-
+                    
                 self.increaseEvaluation()
 
                 print(f'Best solution found so far: {self.result.fitness}')

+ 91 - 0
optimization/callbacks/SurrogateMonoCheckpoint.py

@@ -0,0 +1,91 @@
+"""Basic Checkpoint class implementation
+"""
+
+# main imports
+import os
+import logging
+import numpy as np
+
+# module imports
+from macop.callbacks.Callback import Callback
+from macop.utils.color import macop_text, macop_line
+
+
+class SurrogateMonoCheckpoint(Callback):
+    """
+    SurrogateCheckpoint is used for logging training data information about surrogate
+
+    Attributes:
+        algo: {Algorithm} -- main algorithm instance reference
+        every: {int} -- checkpoint frequency used (based on number of evaluations)
+        filepath: {str} -- file path where checkpoints will be saved
+    """
+    def run(self):
+        """
+        Check if necessary to do backup based on `every` variable
+        """
+        # get current best solution
+        solution = self._algo._bestSolution
+        surrogate_analyser = self._algo._surrogate_analyser
+
+        # Do nothing is surrogate analyser does not exist
+        if surrogate_analyser is None:
+            return
+
+        currentEvaluation = self._algo.getGlobalEvaluation()
+
+        # backup if necessary
+        if currentEvaluation % self._every == 0:
+
+            logging.info(f"Surrogate analysis checkpoint is done into {self._filepath}")
+
+            solutionData = ""
+            solutionSize = len(solution._data)
+
+            for index, val in enumerate(solution._data):
+                solutionData += str(val)
+
+                if index < solutionSize - 1:
+                    solutionData += ' '
+
+            # get score of r² and mae
+
+            line = str(currentEvaluation) + ';' + str(surrogate_analyser._n_local_search) + ';' + str(surrogate_analyser._every_ls) + ';' + str(surrogate_analyser._time)  + ';' + str(surrogate_analyser._r2) \
+                + ';' + str(surrogate_analyser._mae) \
+                + ';' + solutionData + ';' + str(solution.fitness) + ';\n'
+
+            # check if file exists
+            if not os.path.exists(self._filepath):
+                with open(self._filepath, 'w') as f:
+                    f.write(line)
+            else:
+                with open(self._filepath, 'a') as f:
+                    f.write(line)
+
+    def load(self):
+        """
+        only load global n local search
+        """
+
+        if os.path.exists(self._filepath):
+
+            logging.info('Load n local search')
+            with open(self._filepath) as f:
+
+                # get last line and read data
+                lastline = f.readlines()[-1].replace(';\n', '')
+                data = lastline.split(';')
+
+                n_local_search = int(data[1])
+
+                # set k_indices into main algorithm
+                self._algo._total_n_local_search = n_local_search
+
+            print(macop_line())
+            print(macop_text(f'SurrogateMonoCheckpoint found from `{self._filepath}` file.'))
+
+        else:
+            print(macop_text('No backup found...'))
+            logging.info("Can't load Surrogate backup... Backup filepath not valid in SurrogateCheckpoint")
+
+        print(macop_line())