pooling.cpp 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #include "pooling.hpp"
  2. namespace Layer{
  3. Vector
  4. Pooling::feed_forward(Vector x){
  5. x_in=x;
  6. for(size_t f=0;f<nf;++f){
  7. for(size_t k=0;k<mi;++k){
  8. for(size_t l=0;l<mi;++l){
  9. size_t i=p*k;
  10. size_t j=q*l;
  11. Real temp=x[indice3(f,i,j,ni,nj)];
  12. for(;i<min(p*k+p,ni);++i){
  13. for(;j<min(q*l+q,nj);++j){
  14. temp=max(temp,x[indice3(f,i,j,ni,nj)]);
  15. }
  16. }
  17. x_out[indice3(f,k,l,mi,mj)]=temp;
  18. }
  19. }
  20. }
  21. return x_out;
  22. }
  23. Vector
  24. Pooling::back_propagation(Vector e){
  25. for(size_t f=0;f<nf;++f){
  26. for(size_t i=0;i<ni;++i){
  27. size_t k=i/p;
  28. for(size_t j=0;j<nj;++j){
  29. size_t l=j/q;
  30. bool is_max=true;
  31. Real val=x_in[indice3(f,i,j,ni,nj)];
  32. for(size_t i2=k*p;i2<min(k*p+p,ni) and is_max;++i2){
  33. for(size_t j2=l*q;j2<min(l*q+q,nj) and is_max;++j2){
  34. if(x_in[indice3(f,i2,j2,ni,nj)]>val) is_max=false;
  35. }
  36. }
  37. d[indice3(f,i,j,ni,nj)]=(is_max)?e[indice3(f,k,l,mi,mj)]:0;
  38. }
  39. }
  40. }
  41. return d;
  42. }
  43. }