Florian il y a 6 ans
Parent
commit
23a66d0390
3 fichiers modifiés avec 20 ajouts et 14 suppressions
  1. 11 4
      plan_gen/plan_gen.py
  2. 6 6
      plan_gen/plan_gen_cli.py
  3. 3 4
      plan_gen/plan_visualizer.py

+ 11 - 4
plan_gen/plan_gen.py

@@ -30,7 +30,13 @@ def parse_params(param_str):
     if param_str:
         for key_value_str in param_str.split(','):
             key, value = key_value_str.split('=')
-            dict_params[key] = parse_value(value)
+            if key in ['hc', 'wc']:
+                coords = value.split('|')
+                dict_params[key] = [np.fromstring(str(x), dtype=int, sep=':') for x in coords]
+            elif key in ['hr', 'wr']:
+                dict_params[key] = np.fromstring(str(value), dtype=int, sep='|')
+            else:
+                dict_params[key] = parse_value(value)
     return dict_params
 
 def get_seconds(time_str):
@@ -91,13 +97,14 @@ def rand_time(low, high):
 
 def rand_node_xy(nodes, clusters, densities):
     ''' returns a random node coordinates from a random cluster '''
+    node = None
     clusters = clusters.flatten()
     densities = densities.flatten()
     cluster = np.random.choice(clusters, p=densities/sum(densities))
-    if cluster is not None:
-        node = cluster[np.random.randint(len(cluster))]
-    else:
+    if cluster is None:
         node = nodes[np.random.randint(len(nodes))]
+    else:
+        node = cluster[np.random.randint(len(cluster))]
     return (node.get('x'), node.get('y'))
 
 def rand_person(nodes, clusters, h_dens, w_dens):

+ 6 - 6
plan_gen/plan_gen_cli.py

@@ -6,10 +6,6 @@ import lxml.etree as etree
 import numpy as np
 import plan_gen.plan_gen as pg
 
-# TODO: make these constants as user params
-D_CENTERS = [(27, 40), (32, 8)]
-D_RADIUS = [10, 10]
-
 if __name__ == '__main__':
 
     # command line arguments
@@ -21,12 +17,16 @@ if __name__ == '__main__':
     NB_CLUSTERS = DICT_PARAMS['nc']
     NB_PERSONS = DICT_PARAMS['np']
     INPUT_NETWORK = DICT_PARAMS['nw']
+    H_CENTERS = DICT_PARAMS['hc'] if 'hc' in DICT_PARAMS else None
+    W_CENTERS = DICT_PARAMS['wc'] if 'wc' in DICT_PARAMS else None
+    H_RADIUS = DICT_PARAMS['hr'] if 'hr' in DICT_PARAMS else None
+    W_RADIUS = DICT_PARAMS['wr'] if 'wr' in DICT_PARAMS else None
 
     # prepare data
     NODES = pg.get_nodes(INPUT_NETWORK)
     CLUSTERS = pg.make_clusters(NB_CLUSTERS, NODES)
-    H_DENSITIES = pg.make_densities(NB_CLUSTERS, D_CENTERS, D_RADIUS)
-    W_DENSITIES = pg.make_densities(NB_CLUSTERS)
+    H_DENSITIES = pg.make_densities(NB_CLUSTERS, H_CENTERS, H_RADIUS)
+    W_DENSITIES = pg.make_densities(NB_CLUSTERS, W_CENTERS, W_RADIUS)
 
     # make xml
     PERSONS = [pg.rand_person(NODES, CLUSTERS, H_DENSITIES, W_DENSITIES) for _ in range(NB_PERSONS)]

+ 3 - 4
plan_gen/plan_visualizer.py

@@ -18,11 +18,10 @@ if __name__ == '__main__':
     matplotlib.use('TkAgg')
     import matplotlib.pyplot as plt
 
-    NB_CLUSTERS = 50
-
     # get data
-    NODES = get_nodes(sys.argv[1])
-    PERSONS = get_persons(sys.argv[2])
+    NB_CLUSTERS = int(sys.argv[1])
+    NODES = get_nodes(sys.argv[2])
+    PERSONS = get_persons(sys.argv[3])
     PERSONS_XY = ['{}|{}'.format(
                   p.find('plan/act').get('x'),
                   p.find('plan/act').get('y')) for p in PERSONS]