Parcourir la source

parallelisation

Florian il y a 6 ans
Parent
commit
6c2be2213e
3 fichiers modifiés avec 17 ajouts et 5 suppressions
  1. 12 1
      plan_gen/plan_gen_cli.py
  2. 2 3
      plan_gen/plan_visualizer.py
  3. 3 1
      setup.py

+ 12 - 1
plan_gen/plan_gen_cli.py

@@ -4,8 +4,17 @@
 import sys
 import lxml.etree as etree
 import numpy as np
+import joblib as jl
 import plan_gen.plan_gen as pg
 
+NB_CORES = 4
+
+def run_rand_person(i):
+    ''' parallelisation '''
+    np.random.seed(SEEDS[i])
+    return [pg.rand_person(NODES, CLUSTERS, H_DENSITIES, W_DENSITIES)
+                for _ in range(int(NB_PERSONS/NB_CORES))]
+
 if __name__ == '__main__':
 
     # command line arguments
@@ -29,7 +38,9 @@ if __name__ == '__main__':
     W_DENSITIES = pg.make_densities(NB_CLUSTERS, W_CENTERS, W_RADIUS)
 
     # make xml
-    PERSONS = [pg.rand_person(NODES, CLUSTERS, H_DENSITIES, W_DENSITIES) for _ in range(NB_PERSONS)]
+    SEEDS = np.random.random_integers(0, 1e9, NB_CORES)
+    PERSONS = jl.Parallel(n_jobs=NB_CORES)(jl.delayed(run_rand_person)(i) for i in range(NB_CORES))
+    PERSONS = sum(PERSONS,[])
     PLANS = pg.make_plans(PERSONS)
 
     # print XML

+ 2 - 3
plan_gen/plan_visualizer.py

@@ -37,14 +37,14 @@ if __name__ == '__main__':
     MIN_X, MAX_X = min(NODES_X), max(NODES_X)
     MIN_Y, MAX_Y = min(NODES_Y), max(NODES_Y)
     AX.scatter(NODES_X, NODES_Y,
-                marker='.', c='grey', linewidth=0.5, s=10, label='node')
+                marker='.', c='grey', linewidth=0.5, s=10)
 
     # plot persons
     PERSONS_X = [float(coord.split('|')[0]) for coord in P_XY_UNIQUE]
     PERSONS_Y = [float(coord.split('|')[1]) for coord in P_XY_UNIQUE]
     SC = AX.scatter(PERSONS_X, PERSONS_Y,
                      alpha=0.75, s=P_XY_COUNTS*10,
-                     c=P_XY_COUNTS, cmap='YlOrRd', label='number of agents')
+                     c=P_XY_COUNTS, cmap='rainbow')
 
     # plot map
     img = plt.imread('input/map.png')
@@ -63,5 +63,4 @@ if __name__ == '__main__':
     AX.set_yticklabels(TICK_LABELS)
     AX.grid(True)
     plt.colorbar(SC)
-    plt.legend()
     plt.show()

+ 3 - 1
setup.py

@@ -8,7 +8,9 @@ setup(
     packages = ['plan_gen'],
     install_requires = [
         'numpy', 
-        'lxml'
+        'lxml',
+        'matplotlib',
+        'joblib'
     ],
 )