Parcourir la source

pdevs/mpi: generalize model

Eric Ramat il y a 8 ans
Parent
commit
dbb5924086
4 fichiers modifiés avec 70 ajouts et 41 suppressions
  1. 2 0
      .gitignore
  2. 30 14
      src/tests/mpi/graph_manager.hpp
  3. 17 6
      src/tests/mpi/main.cpp
  4. 21 21
      src/tests/pdevs/models.hpp

+ 2 - 0
.gitignore

@@ -1,2 +1,4 @@
 *~
 build/
+output/
+src/tests/corsen/

+ 30 - 14
src/tests/mpi/graph_manager.hpp

@@ -35,6 +35,8 @@
 #include <paradevs/kernel/pdevs/mpi/GraphManager.hpp>
 #include <paradevs/kernel/pdevs/Simulator.hpp>
 
+#include <sstream>
+
 // #include <tests/boost_graph/graph_defs.hpp>
 
 namespace paradevs { namespace tests { namespace mpi {
@@ -125,10 +127,12 @@ private:
 
 struct RootGraphManagerParameters
 {
-    int S1_rank;
-    int S2_rank;
+    std::vector < int > ranks;
 };
 
+typedef paradevs::pdevs::mpi::ModelProxy < common::DoubleTime > ModelProxy;
+typedef std::vector < ModelProxy* > ModelProxies;
+
 class RootGraphManager :
         public paradevs::pdevs::GraphManager < common::DoubleTime,
                                                RootGraphManagerParameters >
@@ -139,24 +143,36 @@ public:
         const RootGraphManagerParameters& parameters) :
         paradevs::pdevs::GraphManager < common::DoubleTime,
                                         RootGraphManagerParameters >(
-                                            coordinator, parameters),
-        S1("a", parameters.S1_rank, false),
-        S2("b", parameters.S2_rank, false)
+                                            coordinator, parameters)
     {
-        add_child(&S1);
-        add_child(&S2);
-
-        S1.add_out_port("out");
-        S2.add_in_port("in");
-        add_link(&S1, "out", &S2, "in");
+        ModelProxy* previous = 0;
+
+        for (std::vector < int >::const_iterator it = parameters.ranks.begin();
+             it != parameters.ranks.end(); ++it) {
+            std::stringstream ss;
+            ModelProxy* model = 0;
+
+            ss << "S" << *it;
+            model = new ModelProxy(ss.str(), *it, false);
+            models.push_back(model);
+            add_child(model);
+            model->add_out_port("out");
+            model->add_in_port("in");
+            if (it != parameters.ranks.begin()) {
+                add_link(previous, "out", model, "in");
+            }
+            previous = model;
+        }
     }
 
     virtual ~RootGraphManager()
-    { }
+    {
+        std::for_each(models.begin(), models.end(),
+                      std::default_delete < ModelProxy >());
+    }
 
 private:
-    paradevs::pdevs::mpi::ModelProxy < common::DoubleTime > S1;
-    paradevs::pdevs::mpi::ModelProxy < common::DoubleTime > S2;
+    ModelProxies models;
 };
 
 // struct MPIHierarchicalGraphManagerParameters

+ 17 - 6
src/tests/mpi/main.cpp

@@ -53,8 +53,16 @@ void example_simple(int argc, char *argv[])
     if (world.rank() == 0) {
         paradevs::tests::mpi::RootGraphManagerParameters parameters;
 
-        parameters.S1_rank = 1;
-        parameters.S2_rank = 2;
+        parameters.ranks.push_back(1);
+        parameters.ranks.push_back(2);
+        parameters.ranks.push_back(3);
+        parameters.ranks.push_back(4);
+        parameters.ranks.push_back(5);
+        parameters.ranks.push_back(6);
+        parameters.ranks.push_back(7);
+        parameters.ranks.push_back(8);
+        parameters.ranks.push_back(9);
+        parameters.ranks.push_back(10);
 
         paradevs::common::RootCoordinator <
             DoubleTime,
@@ -63,26 +71,29 @@ void example_simple(int argc, char *argv[])
                 paradevs::tests::mpi::RootGraphManager,
                 paradevs::common::NoParameters,
                 paradevs::tests::mpi::RootGraphManagerParameters >
-            > rc(0, 10, "root", paradevs::common::NoParameters(), parameters);
+            > rc(0, 20, "root", paradevs::common::NoParameters(), parameters);
 
         rc.run();
     } else {
+        std::stringstream ss;
+
+        ss << "S" << world.rank();
         if (world.rank() == 1) {
             paradevs::pdevs::mpi::Coordinator <
                 DoubleTime,
                 paradevs::tests::mpi::S1GraphManager > model(
-                    "S1", paradevs::common::NoParameters(),
+                    ss.str(), paradevs::common::NoParameters(),
                     paradevs::common::NoParameters());
             paradevs::pdevs::mpi::LogicalProcessor <
                 DoubleTime > LP(&model, world.rank(), 0);
 
             model.set_logical_processor(&LP);
             LP.loop();
-        } else if (world.rank() == 2) {
+        } else {
             paradevs::pdevs::mpi::Coordinator <
                 DoubleTime,
                 paradevs::tests::mpi::S2GraphManager > model(
-                    "S2", paradevs::common::NoParameters(),
+                    ss.str(), paradevs::common::NoParameters(),
                     paradevs::common::NoParameters());
             paradevs::pdevs::mpi::LogicalProcessor <
                 DoubleTime > LP(&model, world.rank(), 0);

+ 21 - 21
src/tests/pdevs/models.hpp

@@ -91,8 +91,8 @@ public:
         ++_value.x;
         --_value.y;
 
-        std::cout << t << ": " << get_name() << " => dint -> "
-                  << _value.x << " " << _value.y << std::endl;
+        // std::cout << t << ": " << get_name() << " => dint -> "
+        //           << _value.x << " " << _value.y << std::endl;
 
         if (_phase == SEND) {
             _phase = WAIT;
@@ -109,12 +109,12 @@ public:
         (void)msgs;
 #endif
 
-        for (common::Bag < common::DoubleTime >::const_iterator it =
-                 msgs.begin(); it != msgs.end(); ++it) {
-            std::cout << t << ": " << get_name()
-                      << " => " << it->get_content().get_content < double >()
-                      << std::endl;
-        }
+        // for (common::Bag < common::DoubleTime >::const_iterator it =
+        //          msgs.begin(); it != msgs.end(); ++it) {
+        //     std::cout << t << ": " << get_name()
+        //               << " => " << it->get_content().get_content < double >()
+        //               << std::endl;
+        // }
 
 #ifdef WITH_TRACE
         common::Trace < common::DoubleTime >::trace()
@@ -138,8 +138,8 @@ public:
         (void)msgs;
 #endif
 
-        std::cout << t << ": " << get_name() << " => " << msgs.to_string()
-                  << std::endl;
+        // std::cout << t << ": " << get_name() << " => " << msgs.to_string()
+        //           << std::endl;
 
 #ifdef WITH_TRACE
         common::Trace < common::DoubleTime >::trace()
@@ -206,7 +206,7 @@ public:
             common::ExternalEvent < common::DoubleTime >(
                 "out", common::Value(_value)));
 
-        std::cout << t << ": " << get_name() << " => lambda" << std::endl;
+        // std::cout << t << ": " << get_name() << " => lambda" << std::endl;
 
 #ifdef WITH_TRACE
         common::Trace < common::DoubleTime >::trace()
@@ -259,8 +259,8 @@ public:
         delay();
         ++_value;
 
-        std::cout << t << ": " << get_name() << " => dint -> "
-                  << _value << std::endl;
+        // std::cout << t << ": " << get_name() << " => dint -> "
+        //           << _value << std::endl;
 
         if (_phase == SEND) {
             _phase = WAIT;
@@ -277,13 +277,13 @@ public:
         (void)msgs;
 #endif
 
-        for (common::Bag < common::DoubleTime >::const_iterator it =
-                 msgs.begin(); it != msgs.end(); ++it) {
-            std::cout << t << ": " << get_name()
-                      << " => " << it->get_content().get_content < data >().x
-                      << " " << it->get_content().get_content < data >().y
-                      << std::endl;
-        }
+        // for (common::Bag < common::DoubleTime >::const_iterator it =
+        //          msgs.begin(); it != msgs.end(); ++it) {
+        //     std::cout << t << ": " << get_name()
+        //               << " => " << it->get_content().get_content < data >().x
+        //               << " " << it->get_content().get_content < data >().y
+        //               << std::endl;
+        // }
 
 #ifdef WITH_TRACE
         common::Trace < common::DoubleTime >::trace()
@@ -368,7 +368,7 @@ public:
 #endif
         common::Bag < common::DoubleTime > msgs;
 
-        std::cout << t << ": " << get_name() << " => lambda" << std::endl;
+        // std::cout << t << ": " << get_name() << " => lambda" << std::endl;
 
         msgs.push_back(common::ExternalEvent <
                            common::DoubleTime >(