convolution.hpp 812 B

1234567891011121314151617181920212223242526272829303132333435
  1. #ifndef CONVOLUTION_LAYER_HPP
  2. #define CONVOLUTION_LAYER_HPP
  3. #include <random>
  4. #include "layer.hpp"
  5. #include <cstdint>
  6. #include "avx.hpp"
  7. namespace Layer{
  8. /*****************************************
  9. * Implementation of a convolutionnal Layer
  10. */
  11. class ConvolutionLayer:public Layer{
  12. private:
  13. size_t nf,ni,nj;
  14. size_t mf,mi,mj;
  15. size_t p,q;
  16. Vector K;
  17. Vector b;
  18. Vector nabla_K;
  19. Vector nabla_b;
  20. v8i* vindex1;
  21. public:
  22. ConvolutionLayer(size_t nf,size_t ni,size_t nj,size_t p,size_t q,size_t mf);
  23. ~ConvolutionLayer();
  24. void init(Real mu,Real sigma);
  25. Vector feed_forward(Vector x) override;
  26. Vector avx_feed_forward(Vector x);
  27. void init_nabla() override;
  28. Vector back_propagation(Vector e) override;
  29. void update(Real eta) override;
  30. };
  31. }
  32. #endif