full_connected.cpp 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #include "full_connected.hpp"
  2. namespace Layer{
  3. FullConnectedLayer::FullConnectedLayer(size_t n,size_t m):Layer(n,m){
  4. b=init_vector(m); //b
  5. w=init_vector(m*n); //w
  6. nabla_b=init_vector(m); //nabla_b
  7. nabla_w=init_vector(m*n); //nabla_w
  8. }
  9. FullConnectedLayer::~FullConnectedLayer(){
  10. delete_vector(b);
  11. delete_vector(w);
  12. delete_vector(nabla_b);
  13. delete_vector(nabla_w);
  14. }
  15. void
  16. FullConnectedLayer::init(Real mu,Real sigma){
  17. default_random_engine generator;
  18. normal_distribution<Real> distribution(mu,sigma);
  19. for(size_t i=0;i<m;++i){
  20. b[i]=distribution(generator);
  21. }
  22. for(size_t i=0;i<m*n;++i){
  23. w[i]=distribution(generator);
  24. }
  25. }
  26. void
  27. FullConnectedLayer::init_standard(){
  28. default_random_engine generator;
  29. normal_distribution<Real> distribution(0,1);
  30. for(size_t i=0;i<m;++i){
  31. b[i]=distribution(generator);
  32. }
  33. for(size_t i=0;i<m*n;++i){
  34. normal_distribution<Real> distribution2(0,1/sqrt(n));
  35. w[i]=distribution2(generator);
  36. }
  37. }
  38. Vector
  39. FullConnectedLayer::feed_forward(Vector x_){
  40. x=x_;
  41. for(size_t i=0;i<m;++i){
  42. Real temp=b[i];
  43. for(size_t j=0;j<n;++j){
  44. temp+=w[indice2(i,j,n)]*x[j];
  45. }
  46. y[i]=temp;
  47. }
  48. return y;
  49. }
  50. void
  51. FullConnectedLayer::init_nabla(){
  52. for(size_t i=0;i<m;++i){
  53. nabla_b[i]=0;
  54. }
  55. for(size_t i=0;i<m*n;++i){
  56. nabla_w[i]=0;
  57. }
  58. }
  59. Vector
  60. FullConnectedLayer::back_propagation(Vector e){
  61. for(size_t i=0;i<n;++i){
  62. Real temp=0;
  63. for(size_t j=0;j<m;++j){
  64. temp+=w[indice2(j,i,n)]*e[j];
  65. }
  66. d[i]=temp;
  67. }
  68. //Update nabla_b
  69. for(size_t i=0;i<m;++i){
  70. nabla_b[i]+=e[i];
  71. }
  72. //Update nabla_w
  73. for(size_t i=0;i<m;++i){
  74. for(size_t j=0;j<n;++j){
  75. nabla_w[indice2(i,j,n)]+=e[i]*x[j];
  76. }
  77. }
  78. return d;
  79. }
  80. void
  81. FullConnectedLayer::update(Real eta){
  82. //Update b
  83. for(size_t i=0;i<m;++i){
  84. b[i]-=eta*nabla_b[i];
  85. }
  86. //Update w
  87. for(size_t i=0;i<m*n;++i){
  88. w[i]-=eta*nabla_w[i];
  89. }
  90. }
  91. }