Parcourir la source

new version of surrogate optimization

Jérôme BUISINE il y a 3 ans
Parent
commit
0dab2b9257
1 fichier modifié avec 21 ajouts et 7 suppressions
  1. 21 7
      optimization/ILSPopSurrogate.py

+ 21 - 7
optimization/ILSPopSurrogate.py

@@ -6,6 +6,8 @@ import os
 import logging
 import joblib
 import time
+import pandas as pd
+from sklearn.utils import shuffle
 
 # module imports
 from macop.algorithms.base import Algorithm
@@ -116,18 +118,30 @@ class ILSPopSurrogate(Algorithm):
         analysis = FitterAnalysis(logfile="train_surrogate.log", problem=problem)
         algo = FitterAlgo(problem=problem, surrogate=surrogate, analysis=analysis, seed=problem.seed)
 
-        # dynamic number of samples based on dataset real evaluations
-        nsamples = None
-        with open(self._solutions_file, 'r') as f:
-            nsamples = len(f.readlines()) - 1 # avoid header
+        # data set
+        df = pd.read_csv(self._solutions_file, sep=';')
+        
+        # build the learning and test sets from at most the last 1000 samples
+        max_samples = 1000
 
-        training_samples = int(0.7 * nsamples) # 70% used for learning part at each iteration
+        if df.x.count() < max_samples:
+            max_samples = df.x.count()
+
+        ntraining_samples = max_samples * 0.80
         
+        # keep only the most recent samples, then shuffle them
+        reduced_df = df.tail(max_samples)
+        reduced_df = shuffle(reduced_df)
+
+        # split into learning/test sets — NOTE(review): iloc needs an int bound and drop must be called, not indexed
+        learn = reduced_df.iloc[0:ntraining_samples]
+        test = reduced_df.drop[learn.index]
+
         print("Start fitting again the surrogate model")
-        print(f'Using {training_samples} of {nsamples} samples for train dataset')
+        print(f'Using {ntraining_samples} samples of {max_samples} for train dataset')
         for r in range(10):
             print(f"Iteration n°{r}: for fitting surrogate")
-            algo.run(samplefile=self._solutions_file, sample=training_samples, step=10)
+            algo.run_samples(learn=learn, test=test, step=10)
 
         joblib.dump(algo, self._surrogate_file_path)