Parcourir la source

update use of MAE indicator

Jérôme BUISINE il y a 3 ans
Parent
commit
888091c57d

+ 1 - 1
optimization/ILSMultiSurrogate.py

@@ -397,11 +397,11 @@ class ILSMultiSurrogate(Algorithm):
             mae_score = sum(mae_scores) / len(mae_scores)
 
             training_surrogate_every = int(abs(r_squared) * self._ls_train_surrogates) # use of absolute value for r²
-            print(f"=> R² of surrogate is of {r_squared} | MAE is of {mae_score} -- [Retraining model after {self._n_local_search % training_surrogate_every} of {training_surrogate_every} LS]")
 
             # avoid issue when lauching every each local search
             if training_surrogate_every <= 0:
                 training_surrogate_every = 1
+            print(f"=> R² of surrogate is of {r_squared} | MAE is of {mae_score} -- [Retraining model after {self._n_local_search % training_surrogate_every} of {training_surrogate_every} LS]")
 
             # check if necessary or not to train again surrogate
             if self._n_local_search % training_surrogate_every == 0 and self._start_train_surrogates <= self.getGlobalEvaluation():

+ 1 - 1
optimization/callbacks/SurrogateCheckpoint.py

@@ -53,7 +53,7 @@ class SurrogateCheckpoint(Callback):
             mae_data = ' '.join(list(map(str, surrogate_analyser._mae_scores)))
 
             line = str(currentEvaluation) + ';' + str(surrogate_analyser._n_local_search) + ';' + str(surrogate_analyser._every_ls) + ';' + str(surrogate_analyser._time) + ';' + r2_data + ';' + str(surrogate_analyser._r2) \
-                + ';' + mae_data + ';' + surrogate_analyser._mae \
+                + ';' + mae_data + ';' + str(surrogate_analyser._mae) \
                 + ';' + solutionData + ';' + str(solution.fitness()) + ';\n'
 
             # check if file exists