|
@@ -168,6 +168,7 @@ def main():
|
|
|
parser.add_argument('--length', type=int, help='max data length (need to be specify for evaluator)', required=True)
|
|
|
parser.add_argument('--ils', type=int, help='number of total iteration for ils algorithm', required=True)
|
|
|
parser.add_argument('--ls', type=int, help='number of iteration for Local Search algorithm', required=True)
|
|
|
+ parser.add_argument('--every_ls', type=int, help='number of max iteration for retraining surrogate model', required=True)
|
|
|
parser.add_argument('--output', type=str, help='output surrogate model name')
|
|
|
|
|
|
args = parser.parse_args()
|
|
@@ -177,6 +178,7 @@ def main():
|
|
|
p_start = args.start_surrogate
|
|
|
p_ils_iteration = args.ils
|
|
|
p_ls_iteration = args.ls
|
|
|
+ p_every_ls = args.every_ls
|
|
|
p_output = args.output
|
|
|
|
|
|
print(p_data_file)
|
|
@@ -279,7 +281,7 @@ def main():
|
|
|
_surrogate_file_path=surrogate_output_model,
|
|
|
_start_train_surrogate=p_start,
|
|
|
_solutions_file=surrogate_output_data,
|
|
|
- _ls_train_surrogate=1,
|
|
|
+ _ls_train_surrogate=p_every_ls,
|
|
|
_maximise=True)
|
|
|
|
|
|
algo.addCallback(BasicCheckpoint(_every=1, _filepath=backup_file_path))
|