pooling.hpp 744 B

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #ifndef POOLING_LAYER_HPP
  2. #define POOLING_LAYER_HPP
  3. #include "layer.hpp"
  4. namespace Layer{
  5. class Pooling:public Layer{
  6. public:
  7. size_t nf;
  8. size_t ni;
  9. size_t nj;
  10. size_t p;
  11. size_t q;
  12. size_t mi;
  13. size_t mj;
  14. public:
  15. Pooling(size_t nf,size_t ni,size_t nj,size_t p,size_t q);
  16. ~Pooling(){};
  17. Vector feed_forward(Vector x) override;
  18. void init_nabla() override {};
  19. Vector back_propagation(Vector e) override;
  20. void update(Real) override {};
  21. };
  22. inline
  23. Pooling::Pooling(size_t nf_,size_t ni_,size_t nj_,size_t p_,size_t q_):
  24. Layer(nf_*ni_*nj_,nf_*((ni_+p_-1)/p_)*((nj_+q_-1)/q_)){
  25. nf=nf_;
  26. ni=ni_;
  27. nj=nj_;
  28. p=p_;
  29. q=q_;
  30. mi=(ni+p-1)/p;
  31. mj=(nj+q-1)/p;
  32. }
  33. }
  34. #endif