convolution.cpp 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #include "convolution.hpp"
  2. namespace Layer{
  3. void
  4. Convolution::init(Real m,Real d){
  5. default_random_engine generator;
  6. normal_distribution<Real> distribution(m,d);
  7. for(size_t i=0;i<mf*nf*i*q;++i){
  8. K[i]=distribution(generator);
  9. }
  10. for(size_t i=0;i<mf*nf;++i){
  11. b[i]=distribution(generator);
  12. }
  13. }
  14. Vector
  15. Convolution::feed_forward(Vector x){
  16. x_in=x;
  17. for(size_t g=0;g<mf;++g){
  18. for(size_t k=0;k<mi;++k){
  19. for(size_t l=0;l<mj;++l){
  20. Real temp=0;
  21. for(size_t f=0;f<nf;++f){
  22. for(size_t r=0;r<p;++r){
  23. for(size_t s=0;s<q;++s){
  24. temp+=x[indice3(f,k+r,l+s,ni,nj)]*K[indice4(g,f,r,s,nf,p,q)];
  25. }
  26. }
  27. temp+=b[indice2(g,f,nf)];
  28. }
  29. x_out[indice3(g,k,l,mi,mj)]=temp;
  30. }
  31. }
  32. }
  33. return x_out;
  34. }
  35. void
  36. Convolution::init_nabla(){
  37. for(size_t i=0;i<mf*nf;++i){
  38. nabla_b[i]=0;
  39. }
  40. for(size_t i=0;i<mf*nf*p*q;++i){
  41. nabla_K[i]=0;
  42. }
  43. }
  44. Vector
  45. Convolution::back_propagation(Vector d){
  46. for(size_t f=0;f<nf;++f){
  47. for(size_t i=0;i<ni;++i){
  48. for(size_t j=0;j<nj;++j){
  49. Real temp=0;
  50. for(size_t g=0;g<mf;++g){
  51. size_t r=(i>=mi-1)?i-mi+1:0;
  52. for(;r<min(i,p);++r){
  53. size_t s=(j>=mj-1)?j-mj+1:0;
  54. for(;s<min(j,q);++s){
  55. temp+=K[indice4(g,f,r,s,nf,p,q)]*d[indice3(g,i-r,j-s,mi,mj)];
  56. }//s
  57. }//r
  58. }//g
  59. delta[indice3(f,i,j,ni,nj)]=temp;
  60. }//j
  61. }//i
  62. }//f
  63. //display(delta,nf*ni*nj);
  64. //char a;cin>>a;
  65. //cout<<" - Update nabla_b"<<endl;
  66. //Update nabla_b<<
  67. for(size_t g=0;g<mf;++g){
  68. for(size_t f=0;f<nf;++f){
  69. Real temp=0;
  70. for(size_t k=0;k<mi;++k){
  71. for(size_t l=0;l<mj;++l){
  72. temp+=d[indice3(g,k,l,mi,mj)];
  73. }//l
  74. }//k
  75. nabla_b[indice2(g,f,nf)]+=temp;
  76. }
  77. }
  78. //Update nabla_w
  79. for(size_t g=0;g<mf;++g){
  80. for(size_t f=0;f<nf;++f){
  81. for(size_t r=0;r<p;++r){
  82. for(size_t s=0;s<q;++s){
  83. Real temp=0;
  84. for(size_t k=0;k<mi;++k){
  85. for(size_t l=0;l<mj;++l){
  86. temp+=d[indice3(g,k,l,mi,mj)]*x_in[indice3(f,k+r,l+s,ni,nj)];
  87. }//l
  88. }//k
  89. nabla_K[indice4(g,f,r,s,nf,p,q)]+=temp;
  90. }
  91. }
  92. }
  93. }
  94. return delta;
  95. }
  96. void
  97. Convolution::update(Real eta){
  98. //Update b
  99. for(size_t i=0;i<mf*nf;++i){
  100. b[i]-=eta*nabla_b[i];
  101. }
  102. //Update K
  103. for(size_t i=0;i<mf*nf*p*q;++i){
  104. K[i]-=eta*nabla_K[i];
  105. }
  106. }
  107. }