Parcourir la source

Fix: compare operator for heap

Eric Ramat il y a 11 ans
Parent
commit
e9bfc438a8

+ 10 - 2
src/common/InternalEvent.hpp

@@ -37,7 +37,7 @@ template < class Time >
 class Model;
 
 template < class Time >
-class InternalEvent
+struct InternalEvent
 {
 public:
     InternalEvent(const typename Time::type& time, Model < Time >* model)
@@ -63,11 +63,19 @@ public:
         return _time > e._time;
     }
 
+    bool operator>=(InternalEvent const &e) const
+    {
+        return _time >= e._time;
+    }
+
     bool operator==(InternalEvent const &e) const
     {
         return _time == e._time;
     }
 
+    void set_time(typename Time::type time)
+    { _time = time; }
+
 private:
     typename Time::type _time;
     Model < Time >*     _model;
@@ -78,7 +86,7 @@ struct EventCompare
     : std::binary_function < Event, Event, bool >
 {
     bool operator()(const Event &left, const Event &right) const
-    { return left > right; }
+    { return left >= right; }
 };
 
 } } // namespace paradevs common

+ 2 - 2
src/common/Model.hpp

@@ -178,11 +178,11 @@ public:
                              Time > > > >::handle_type id)
     { _heap_id = id; }
 
-    typename boost::heap::fibonacci_heap <
+    const typename boost::heap::fibonacci_heap <
         InternalEvent < Time >,
         boost::heap::compare <
             EventCompare < InternalEvent <
-                               Time > > > >::handle_type heap_id()
+                               Time > > > >::handle_type& heap_id() const
                                { return _heap_id; }
 
 protected:

+ 25 - 16
src/common/scheduler/HeapScheduler.hpp

@@ -51,10 +51,14 @@ struct Heap
 };
 
 template < class Time >
-class HeapScheduler : public boost::heap::fibonacci_heap <
-    InternalEvent < Time >, boost::heap::compare <
-                                EventCompare < InternalEvent < Time > > > >
+class HeapScheduler
 {
+    typedef boost::heap::fibonacci_heap < InternalEvent < Time >,
+                                          boost::heap::compare <
+                                              EventCompare <
+                                                  InternalEvent < Time > > >
+                                          > Heap;
+
 public:
     HeapScheduler()
     { }
@@ -63,17 +67,15 @@ public:
 
     Model < Time >* get_current_model()
     {
-        return HeapScheduler < Time >::top().get_model();
+        return _heap.top().get_model();
     }
 
     Models < Time > get_current_models(typename Time::type time) const
     {
         Models < Time > models;
 
-        for (typename HeapScheduler < Time >::ordered_iterator it =
-                 HeapScheduler < Time >::ordered_begin();
-             it != HeapScheduler < Time >::ordered_end() and
-                 it->get_time() == time; ++it) {
+        for (typename Heap::ordered_iterator it = _heap.ordered_begin();
+             it != _heap.ordered_end() and it->get_time() == time; ++it) {
             models.push_back(it->get_model());
         }
         return models;
@@ -81,19 +83,24 @@ public:
 
     typename Time::type get_current_time() const
     {
-        return HeapScheduler < Time >::top().get_time();
+        return _heap.top().get_time();
     }
 
     void init(typename Time::type time, Model < Time >* model)
     {
-        model->heap_id(HeapScheduler < Time >::push(
-                           InternalEvent < Time >(time, model)));
+        model->heap_id(_heap.push(InternalEvent < Time >(time, model)));
     }
 
     void put(typename Time::type time, Model < Time >* model)
     {
-        HeapScheduler < Time >::update(
-            model->heap_id(), InternalEvent < Time >(time, model));
+        typename Time::type previous_time = (*model->heap_id()).get_time();
+
+        (*model->heap_id()).set_time(time);
+        if (previous_time < time) {
+            _heap.decrease(model->heap_id());
+        } else if (previous_time > time) {
+            _heap.increase(model->heap_id());
+        }
     }
 
     std::string to_string() const
@@ -101,15 +108,17 @@ public:
         std::stringstream ss;
 
         ss << "Scheduler = { ";
-        for (typename HeapScheduler < Time >::ordered_iterator it =
-                 HeapScheduler < Time >::ordered_begin();
-             it != HeapScheduler < Time >::ordered_end(); ++it) {
+        for (typename Heap::ordered_iterator it = _heap.ordered_begin();
+             it != _heap.ordered_end(); ++it) {
             ss << "(" << it->get_time() << " -> " << it->get_model()->get_name()
                << ") ";
         }
         ss << "}";
         return ss.str();
     }
+
+private:
+    Heap _heap;
 };
 
 } } } // namespace paradevs common scheduler

+ 72 - 1
src/tests/boost_graph/tests.cpp

@@ -59,6 +59,77 @@ void hierarchical_test()
 int main()
 {
     // flat_test();
-    hierarchical_test();
+    // hierarchical_test();
+
+    class M;
+
+    struct A
+    {
+        double time;
+        M*     model;
+
+        A(double _time, M* _model)
+        {
+            time = _time;
+            model = _model;
+        }
+    };
+
+    struct ACompare
+        : std::binary_function < A, A, bool >
+    {
+        bool operator()(const A &left, const A &right) const
+        { return left.time > right.time; }
+    };
+
+    typedef boost::heap::fibonacci_heap < A, boost::heap::compare <
+        ACompare > > Heap;
+
+    typedef Heap::handle_type HeapHandle;
+
+    class M
+    {
+    public:
+        M(int a)
+        {
+            _a = a;
+        }
+
+        int a() const
+        { return _a; }
+
+        HeapHandle heap_id() const
+        { return _heap_id; }
+
+        void heap_id(HeapHandle id)
+        { _heap_id = id; }
+
+    private:
+        int _a;
+        HeapHandle _heap_id;
+    };
+
+    Heap heap;
+    M* m1 = new M(1);
+    M* m2 = new M(2);
+
+    m1->heap_id(heap.push(A(0, m1)));
+    m2->heap_id(heap.push(A(0, m2)));
+
+    (*m1->heap_id()).time = 1;
+    heap.decrease(m1->heap_id());
+    (*m2->heap_id()).time = 1;
+    heap.decrease(m2->heap_id());
+
+    std::cout << "Scheduler = { ";
+    while (not heap.empty()) {
+        std::cout << "(" << heap.top().time << "," << heap.top().model->a()
+                  << ") ";
+        heap.pop();
+    }
+    std::cout << "}" << std::endl;
+
+
+
     return 0;
 }

+ 3 - 3
src/tests/mixed/graph_manager.hpp

@@ -162,7 +162,7 @@ public:
 
 private:
     pdevs::Coordinator < MyTime,
-                         paradevs::common::scheduler::VectorScheduler <
+                         paradevs::common::scheduler::HeapScheduler <
                              MyTime >,
                          S1GraphManager,
                          paradevs::common::NoParameters,
@@ -322,13 +322,13 @@ public:
 
 private:
     pdevs::Coordinator < MyTime,
-                         paradevs::common::scheduler::VectorScheduler <
+                         paradevs::common::scheduler::HeapScheduler <
                              MyTime >,
                          Linear2GraphManager,
                          paradevs::common::NoParameters,
                          paradevs::common::NoParameters > S1;
     pdevs::Coordinator < MyTime,
-                         paradevs::common::scheduler::VectorScheduler <
+                         paradevs::common::scheduler::HeapScheduler <
                              MyTime >,
                          Linear2GraphManager,
                          paradevs::common::NoParameters,

+ 1 - 1
src/tests/mixed/tests.cpp

@@ -38,7 +38,7 @@ TEST_CASE("mixed/hierachical", "run")
 {
     paradevs::common::RootCoordinator <
         MyTime, paradevs::pdevs::Coordinator <
-            MyTime, paradevs::common::scheduler::VectorScheduler <
+            MyTime, paradevs::common::scheduler::HeapScheduler <
                 MyTime >, RootGraphManager > > rc(0, 100, "root");
 
     paradevs::common::Trace < MyTime >::trace().clear();

+ 3 - 2
src/tests/pdevs/graph_manager.hpp

@@ -29,6 +29,7 @@
 
 #include <tests/pdevs/models.hpp>
 
+#include <common/scheduler/HeapScheduler.hpp>
 #include <common/scheduler/VectorScheduler.hpp>
 #include <common/Trace.hpp>
 
@@ -119,10 +120,10 @@ public:
 
 private:
     paradevs::pdevs::Coordinator <
-    MyTime, paradevs::common::scheduler::VectorScheduler < MyTime >,
+    MyTime, paradevs::common::scheduler::HeapScheduler < MyTime >,
     S1GraphManager > S1;
     paradevs::pdevs::Coordinator <
-        MyTime, paradevs::common::scheduler::VectorScheduler < MyTime >,
+        MyTime, paradevs::common::scheduler::HeapScheduler < MyTime >,
         S2GraphManager > S2;
 };
 

+ 1 - 1
src/tests/pdevs/tests.cpp

@@ -188,7 +188,7 @@ TEST_CASE("pdevs/hierachical", "run")
 {
     paradevs::common::RootCoordinator <
         MyTime, paradevs::pdevs::Coordinator <
-            MyTime, paradevs::common::scheduler::VectorScheduler <
+            MyTime, paradevs::common::scheduler::HeapScheduler <
                 MyTime >, RootGraphManager >
         > rc(0, 10, "root", paradevs::common::NoParameters(),
              paradevs::common::NoParameters());