UCBCheckpoint.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """UCB policy Checkpoint class implementation
  2. """
  3. # main imports
  4. import os
  5. import logging
  6. import numpy as np
  7. import pkgutil
  8. # module imports
  9. from .Callback import Callback
  10. from ..utils.color import macop_text, macop_line
  class UCBCheckpoint(Callback):
      """
      UCB checkpoint is used for loading previous UCB data and start again after loading checkpoint.
      The same operators as in the previous run must be used, because values are restored by position.

      Checkpoint file format: two lines of `;`-separated values,
      rewards on line 1 and occurrences on line 2.

      Attributes:
          algo: {Algorithm} -- main algorithm instance reference
          every: {int} -- checkpoint frequency used (based on number of evaluations)
          filepath: {str} -- file path where checkpoints will be saved
      """

      @staticmethod
      def _format_values(values):
          """Serialize a sequence as `;`-separated `str()` values (empty sequence -> '')."""
          return ';'.join(str(value) for value in values)

      @staticmethod
      def _parse_values(line):
          """Parse one `;`-separated checkpoint line into a list of floats.

          A blank line yields an empty list, so an empty sequence written by
          `_format_values` round-trips instead of failing on `float('')`.
          """
          line = line.strip()
          if not line:
              return []
          return [float(value) for value in line.split(';')]

      def run(self):
          """
          Save UCB policy data to `filepath` when the current global evaluation
          count is a multiple of `every`.

          The checkpoint file is overwritten on each save.
          """
          # get current number of global evaluations
          currentEvaluation = self.algo.getGlobalEvaluation()

          # backup only every `every` evaluations
          if currentEvaluation % self.every != 0:
              return

          logging.info("UCB Checkpoint is done into " + self.filepath)

          policy = self.algo.policy
          with open(self.filepath, 'w') as f:
              f.write(self._format_values(policy.rewards) + '\n')
              f.write(self._format_values(policy.occurrences) + '\n')

      def load(self):
          """
          Load backup lines as rewards and occurrences for UCB.

          If `filepath` exists, line 1 is parsed into `algo.policy.rewards` and
          line 2 into `algo.policy.occurrences` (both as lists of floats).
          Otherwise the policy is left untouched and a notice is printed.

          Raises:
              IndexError: if the checkpoint file has fewer than two lines.
              ValueError: if a value cannot be parsed as a float.
          """
          if os.path.exists(self.filepath):
              logging.info('Load UCB data')
              with open(self.filepath) as f:
                  lines = f.readlines()

              # parse both lines before assigning, so a malformed file
              # does not leave the policy partially restored
              rewards = self._parse_values(lines[0])
              occurrences = self._parse_values(lines[1])

              self.algo.policy.rewards = rewards
              # occurrences are restored as floats (same as previous behaviour)
              self.algo.policy.occurrences = occurrences

              print(
                  macop_text(
                      'Load of available UCB policy data from `{}`'.format(
                          self.filepath)))
          else:
              print(macop_text('No UCB data found, use default UCB policy'))
              logging.info("No UCB data found...")

          print(macop_line())