full_connected.cpp 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. #include "full_connected.hpp"
  2. using namespace Layer;
  3. void
  4. FullConnected::init(Real m,Real d){
  5. default_random_engine generator;
  6. normal_distribution<Real> distribution(m,d);
  7. for(size_t i=0;i<n_out;++i){
  8. b[i]=distribution(generator);
  9. }
  10. for(size_t i=0;i<n_out*n_in;++i){
  11. w[i]=distribution(generator);
  12. }
  13. }
  14. void
  15. FullConnected::init_standard(){
  16. default_random_engine generator;
  17. normal_distribution<Real> distribution(0,1);
  18. for(size_t i=0;i<n_out;++i){
  19. b[i]=distribution(generator);
  20. }
  21. for(size_t i=0;i<n_out*n_in;++i){
  22. normal_distribution<Real> distribution2(0,1/sqrt(n_in));
  23. w[i]=distribution2(generator);
  24. }
  25. }
  26. Vector
  27. FullConnected::feed_forward(Vector x){
  28. x_in=x;
  29. for(size_t i=0;i<n_out;++i){
  30. Real temp=b[i];
  31. for(size_t j=0;j<n_in;++j){
  32. temp+=w[indice2(i,j,n_in)]*x[j];
  33. }
  34. x_out[i]=temp;
  35. }
  36. return x_out;
  37. }
  38. void
  39. FullConnected::init_nabla(){
  40. for(size_t i=0;i<n_out;++i){
  41. nabla_b[i]=0;
  42. }
  43. for(size_t i=0;i<n_out*n_in;++i){
  44. nabla_w[i]=0;
  45. }
  46. }
  47. Vector
  48. FullConnected::back_propagation(Vector d){
  49. for(size_t i=0;i<n_in;++i){
  50. Real temp=0;
  51. for(size_t j=0;j<n_out;++j){
  52. temp+=w[indice2(j,i,n_in)]*d[j];
  53. }
  54. delta[i]=temp;
  55. }
  56. //Update nabla_b
  57. for(size_t i=0;i<n_out;++i){
  58. nabla_b[i]+=d[i];
  59. }
  60. //Update nabla_w
  61. for(size_t i=0;i<n_out;++i){
  62. for(size_t j=0;j<n_in;++j){
  63. nabla_w[indice2(i,j,n_in)]+=d[i]*x_in[j];
  64. }
  65. }
  66. return delta;
  67. }
  68. void
  69. FullConnected::update(Real eta){
  70. //Update b
  71. for(size_t i=0;i<n_out;++i){
  72. b[i]-=eta*nabla_b[i];
  73. }
  74. //Update w
  75. for(size_t i=0;i<n_out*n_in;++i){
  76. w[i]-=eta*nabla_w[i];
  77. }
  78. }