plan_visualizer.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #!/usr/bin/env python3
  2. ''' quick plan_gen visualizer '''
  3. import sys
  4. import numpy as np
  5. import lxml.etree as etree
  6. import matplotlib
  def get_nodes(input):
      """Return the network's ``<node>`` elements.

      Args:
          input: source handed straight to ``lxml.etree.parse`` (e.g. a path
              to the network XML file). The name shadows the builtin but is
              kept so existing keyword callers keep working.

      Returns:
          A list of the elements matched by ``/network/nodes/node``.
      """
      # A node-set XPath already returns a new list, so copying it element by
      # element through a comprehension is redundant.
      return etree.parse(input).xpath("/network/nodes/node")
  def get_persons(input):
      """Return the ``<person>`` elements of a plans file.

      Args:
          input: source handed straight to ``lxml.etree.parse`` (e.g. a path
              to the plans XML file). The name shadows the builtin but is
              kept so existing keyword callers keep working.

      Returns:
          A list of the elements matched by ``/plans/person``.
      """
      # A node-set XPath already returns a new list; no need to copy it.
      return etree.parse(input).xpath("/plans/person")
  def cluster_ticks(lo, hi, nb_clusters):
      """Return exactly ``nb_clusters`` evenly spaced tick positions in [lo, hi).

      ``np.arange(lo, hi, (hi - lo) / nb_clusters)`` can yield one extra value
      because the floating-point step is rounded, which makes the tick count
      disagree with the ``nb_clusters`` tick labels. ``linspace`` with
      ``endpoint=False`` produces the same positions with a guaranteed count.
      """
      return np.linspace(lo, hi, nb_clusters, endpoint=False)


  def count_positions(persons):
      """Count persons that share the same activity coordinates.

      For each person element, the first ``plan/act`` element returned by
      ``find`` is used, and its ``x``/``y`` attributes are parsed as floats.

      Args:
          persons: iterable of person elements.

      Returns:
          Three numpy arrays ``(xs, ys, counts)`` with one entry per distinct
          (x, y) position, sorted by x and then y.

      Raises:
          ValueError: if a person has no ``plan/act`` element.
      """
      from collections import Counter

      counter = Counter()
      for person in persons:
          act = person.find('plan/act')  # look it up once, not once per axis
          if act is None:
              raise ValueError(
                  'person {!r} has no plan/act element'.format(person.get('id')))
          # Counting on parsed floats (not raw strings) merges "1" and "1.0".
          counter[(float(act.get('x')), float(act.get('y')))] += 1
      positions = sorted(counter)
      xs = np.array([x for x, _ in positions], dtype=float)
      ys = np.array([y for _, y in positions], dtype=float)
      counts = np.array([counter[pos] for pos in positions], dtype=int)
      return xs, ys, counts


  def main(argv=None):
      """Plot agents' activity locations over the network nodes and a map.

      Command line: NB_CLUSTERS NETWORK_XML PLANS_XML [MAP_PNG]
      MAP_PNG defaults to 'input/map.png'; it is stretched over the bounding
      box of the network nodes, which is then divided into an
      NB_CLUSTERS x NB_CLUSTERS labelled grid.
      """
      args = sys.argv[1:] if argv is None else list(argv)
      if len(args) < 3:
          sys.exit('usage: {} NB_CLUSTERS NETWORK_XML PLANS_XML [MAP_PNG]'
                   .format(sys.argv[0]))
      try:
          nb_clusters = int(args[0])
      except ValueError:
          nb_clusters = 0
      if nb_clusters < 1:
          sys.exit('NB_CLUSTERS must be a positive integer, got {!r}'
                   .format(args[0]))
      map_path = args[3] if len(args) > 3 else 'input/map.png'

      # The backend has to be selected before pyplot is imported.
      matplotlib.use('TkAgg')
      import matplotlib.pyplot as plt

      # get data
      nodes = get_nodes(args[1])
      if not nodes:
          sys.exit('no /network/nodes/node elements found in {}'.format(args[1]))
      persons = get_persons(args[2])
      persons_x, persons_y, counts = count_positions(persons)

      # plot init
      fig = plt.figure()
      ax = fig.add_subplot(111)

      # plot nodes
      nodes_x = [float(n.get('x')) for n in nodes]
      nodes_y = [float(n.get('y')) for n in nodes]
      min_x, max_x = min(nodes_x), max(nodes_x)
      min_y, max_y = min(nodes_y), max(nodes_y)
      ax.scatter(nodes_x, nodes_y, marker='.', c='grey', linewidth=0.5, s=10)

      # plot persons: marker size and colour both scale with the count
      scatter = ax.scatter(persons_x, persons_y,
                           alpha=0.75, s=counts * 5,
                           c=counts, cmap='rainbow')

      # plot map, stretched over the node bounding box
      img = plt.imread(map_path)
      ax.imshow(img, alpha=1, extent=[min_x, max_x, min_y, max_y])

      # final plot: one labelled grid cell per cluster index on each axis
      tick_labels = list(range(nb_clusters))
      ax.set_title('Number of agents by nodes (total: {} agents)'.format(len(persons)))
      ax.set_xlim(min_x, max_x)
      ax.set_ylim(min_y, max_y)
      ax.set_xticks(cluster_ticks(min_x, max_x, nb_clusters))
      ax.set_yticks(cluster_ticks(min_y, max_y, nb_clusters))
      ax.set_xticklabels(tick_labels)
      ax.set_yticklabels(tick_labels)
      ax.tick_params(labelsize=5)
      ax.grid(True)
      plt.colorbar(scatter)
      plt.show()


  if __name__ == '__main__':
      main()