policies.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. """UCB policy Checkpoint class implementation
  2. """
  3. # main imports
  4. import os
  5. import logging
  6. import numpy as np
  7. # module imports
  8. from .base import Callback
  9. from ..utils.progress import macop_text, macop_line
  10. class UCBCheckpoint(Callback):
  11. """
  12. UCB checkpoint is used for loading previous Upper Confidence Bound data and start again after loading checkpoint
  13. Need to be the same operators used during previous run (see `macop.policies.reinforcement.UCBPolicy` for more details)
  14. Attributes:
  15. algo: {Algorithm} -- main algorithm instance reference
  16. every: {int} -- checkpoint frequency used (based on number of evaluations)
  17. filepath: {str} -- file path where checkpoints will be saved
  18. """
  19. def run(self):
  20. """
  21. Check if necessary to do backup based on `every` variable
  22. """
  23. # get current population
  24. currentEvaluation = self._algo.getGlobalEvaluation()
  25. # backup if necessary
  26. if currentEvaluation % self._every == 0:
  27. logging.info("UCB Checkpoint is done into " + self._filepath)
  28. with open(self._filepath, 'w') as f:
  29. rewardsLine = ''
  30. for i, r in enumerate(self._algo._policy._rewards):
  31. rewardsLine += str(r)
  32. if i != len(self._algo._policy._rewards) - 1:
  33. rewardsLine += ';'
  34. f.write(rewardsLine + '\n')
  35. occurrencesLine = ''
  36. for i, o in enumerate(self._algo._policy._occurences):
  37. occurrencesLine += str(o)
  38. if i != len(self._algo._policy._occurences) - 1:
  39. occurrencesLine += ';'
  40. f.write(occurrencesLine + '\n')
  41. def load(self):
  42. """
  43. Load backup lines as rewards and occurrences for UCB
  44. """
  45. if os.path.exists(self._filepath):
  46. logging.info('Load UCB data')
  47. with open(self._filepath) as f:
  48. lines = f.readlines()
  49. # read data for each line
  50. rewardsLine = lines[0].replace('\n', '')
  51. occurrencesLine = lines[1].replace('\n', '')
  52. self._algo._policy._rewards = [
  53. float(f) for f in rewardsLine.split(';')
  54. ]
  55. self._algo._policy._occurences = [
  56. float(f) for f in occurrencesLine.split(';')
  57. ]
  58. macop_text(self._algo, f'Load of available UCB policy data from `{self._filepath}`')
  59. else:
  60. macop_text(self._algo, 'No UCB data found, use default UCB policy')
  61. logging.info("No UCB data found...")
  62. macop_line(self._algo)