network.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  #include "network.hpp"

  #include <algorithm>
  #include <numeric>
  #include <stdexcept>
  #include <vector>
  2. void Network::init_normal_distribution(double m,double d){
  3. default_random_engine generator;
  4. normal_distribution<double> distribution(m,d);
  5. for(size_t l=1;l<depth;++l){
  6. Vector& b=biais[l];
  7. for(size_t i=0;i<sizes[l];++i){
  8. b.data[i]=distribution(generator);
  9. }
  10. Matrix& w=weights[l];
  11. for(size_t i=0;i<sizes[l]*sizes[l-1];++i){
  12. w.get(i)=distribution(generator);
  13. }
  14. }
  15. }
  16. void Network::init_standard(){
  17. default_random_engine generator;
  18. normal_distribution<double> distribution(0,1);
  19. for(size_t l=1;l<depth;++l){
  20. Vector& b=biais[l];
  21. for(size_t i=0;i<sizes[l];++i){
  22. b.data[i]=distribution(generator);
  23. }
  24. Matrix& w=weights[l];
  25. for(size_t i=0;i<sizes[l]*sizes[l-1];++i){
  26. normal_distribution<double> distribution2(0,1/sqrt(sizes[l-1]));
  27. w.get(i)=distribution2(generator);
  28. }
  29. }
  30. }
  31. const Vector&
  32. Network::feed_forward(const Vector& x){
  33. a[0]=x;
  34. for(size_t l=1;l<depth;++l){
  35. compute_z(l);
  36. compute_a(l);
  37. }
  38. return a[depth-1];
  39. }
  40. double
  41. Network::eval(Dataset* dataset){
  42. size_t n=dataset->get_test_size();
  43. size_t nb=0;
  44. for(size_t i=0;i<n;++i){
  45. pair<const Vector&,const Vector&> t=dataset->get_test(i);
  46. const Vector& a=feed_forward(t.first);
  47. if(a.argmax()==t.second.argmax()) ++nb;
  48. }
  49. double res=double(nb)/double(n)*100;
  50. cout<<"> Res = "<<res<<"%"<<endl;
  51. return res;
  52. }
  53. void
  54. Network::train(Dataset* dataset,size_t nb_epochs,size_t batch_size,double eta){
  55. size_t train_size=dataset->get_train_size();
  56. size_t nb_batchs=(train_size-1)/batch_size+1;
  57. size_t* indices=new size_t[train_size];
  58. for(size_t i=0;i<train_size;++i){
  59. indices[i]=i;
  60. }
  61. for(size_t epoch=0;epoch<nb_epochs;++epoch){
  62. cout<<"Epoch "<<epoch<<endl;
  63. shuffle(indices,train_size);
  64. for(size_t batch=0;batch<nb_batchs;++batch){
  65. size_t begin=batch*batch_size;
  66. size_t end=min(train_size,begin+batch_size);
  67. update_batch(dataset,indices,begin,end,eta);
  68. }
  69. eval(dataset);
  70. }
  71. delete[] indices;
  72. }
  73. void
  74. Network::shuffle(size_t* tab,size_t size){
  75. default_random_engine generator;
  76. uniform_int_distribution<int> distribution(0,size-1);
  77. for(size_t k=0;k<size;++k){
  78. size_t i=distribution(generator);
  79. size_t j=distribution(generator);
  80. swap(tab[i],tab[j]);
  81. }
  82. }
  83. void
  84. Network::update_batch(Dataset* dataset,size_t* indices,size_t begin,size_t end,double eta){
  85. double batch_size=end-begin;
  86. for(size_t l=1;l<depth;++l){
  87. init_nabla_b(l);
  88. init_nabla_w(l);
  89. }
  90. for(size_t i=begin;i<end;++i){
  91. pair<const Vector&,const Vector&> data=dataset->get_train(indices[i]);
  92. back_propagation(data.first,data.second,eta);
  93. }
  94. double eta_batch=eta/batch_size;
  95. for(size_t l=1;l<depth;++l){
  96. update_b(l,eta_batch);
  97. update_w(l,eta_batch);
  98. }
  99. }
  100. void
  101. Network::back_propagation(const Vector&x, const Vector& y,double eta){
  102. a[0]=x;
  103. for(size_t l=1;l<depth;++l){
  104. compute_z(l);
  105. compute_a(l);
  106. }
  107. compute_last_delta(y);
  108. for(size_t l=depth-2;l>=1;--l){
  109. compute_delta(l);
  110. }
  111. for(size_t l=1;l<depth;++l){
  112. update_nabla_b(l);
  113. update_nabla_w(l);
  114. }
  115. }
  116. void Network::init_nabla_b(size_t l){
  117. Vector& V=nabla_b[l];
  118. for(size_t i=0;i<sizes[l];++i){
  119. V.data[i]=0;
  120. }
  121. }
  122. void Network::init_nabla_w(size_t l){
  123. Matrix& M=nabla_w[l];
  124. for(size_t i=0;i<sizes[l-1]*sizes[l];++i){
  125. M.get(i)=0;
  126. }
  127. }
  128. void Network::compute_a(size_t l){
  129. for(size_t i=0;i<sizes[l];++i){
  130. a[l].data[i]=sigmoid(z[l].data[i]);
  131. }
  132. }
  133. void Network::compute_z(size_t l){
  134. for(size_t i=0;i<sizes[l];++i){
  135. double temp=biais[l].data[i];
  136. for(size_t j=0;j<sizes[l-1];++j){
  137. temp+=weights[l].get(i,j)*a[l-1].data[j];
  138. }
  139. z[l].data[i]=temp;
  140. }
  141. }
  142. void
  143. Network::compute_last_delta(const Vector& y){
  144. size_t L=depth-1;
  145. for(size_t i=0;i<sizes[L];++i){
  146. delta[L].data[i]=cost_derivative(a[L].data[i],y.data[i])*sigmoid_prime(z[L].data[i]);
  147. }
  148. }
  149. void
  150. Network::compute_delta(size_t l){
  151. for(size_t i=0;i<sizes[l];++i){
  152. double temp=0;
  153. for(size_t j=0;j<sizes[l+1];++j){
  154. temp+=(weights[l+1].get(j,i)*delta[l+1].data[j]);
  155. }
  156. delta[l].data[i]=temp*sigmoid_prime(z[l].data[i]);
  157. }
  158. }
  159. void
  160. Network::update_nabla_b(size_t l){
  161. for(size_t i=0;i<sizes[l];++i){
  162. nabla_b[l].data[i]+=delta[l].data[i];
  163. }
  164. }
  165. void
  166. Network::update_nabla_w(size_t l){
  167. for(size_t i=0;i<sizes[l];++i){
  168. for(size_t j=0;j<sizes[l-1];++j){
  169. nabla_w[l].get(i,j)+=a[l-1].data[j]*delta[l].data[i];
  170. }
  171. }
  172. }
  173. void
  174. Network::update_b(size_t l,double eta_batch){
  175. Vector& U=biais[l];
  176. Vector& V=nabla_b[l];
  177. for(size_t i=0;i<sizes[l];++i){
  178. U.data[i]-=V.data[i]*eta_batch;
  179. }
  180. }
  181. void
  182. Network::update_w(size_t l,double eta_batch){
  183. Matrix& M=weights[l];
  184. Matrix& P=nabla_w[l];
  185. for(size_t i=0;i<sizes[l-1]*sizes[l];++i){
  186. M.get(i)-=P.get(i)*eta_batch;
  187. }
  188. }