network.cpp

#include "network.hpp"

// Standard headers used directly in this file (network.hpp may already pull
// some of them in). Vector, Real, Dataset, Layer, argmax and init_vector are
// expected to be declared in network.hpp.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <random>
#include <utility>

using namespace std;
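
// The cost function defaults to the quadratic cost.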
Network::Network(){
  C=Quadratic;
}
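
// Appends a layer to the network; after the first layer, the new layer's
// input size must match the output size of the layer currently on top.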
void
Network::push_layer(Layer::Layer& l){
  if(layers.empty()){
    n_in=l.get_input_size();
    layers.push_back(&l);
    cout<<"In size = "<<n_in<<endl;
  }
  else{
    assert(l.get_input_size()==layers.back()->get_output_size());
    layers.push_back(&l);
  }
  n_out=l.get_output_size();
}
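
// Called once the architecture is complete: allocates the delta vector of
// the output layer (size n_out).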
void
Network::is_done(){
  last_delta=init_vector(n_out);
}
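
// Propagates an input vector through the layers in order; the final
// activation is cached in the member a and returned.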
Vector
Network::feed_forward(Vector x_in){
  Vector x=x_in;
  for(auto it=layers.begin();it!=layers.end();++it){
    //cout<<" - Try feed_forward on layer "<<(*it)->name<<endl;
    x=(*it)->feed_forward(x);
  }
  a=x;
  return a;
}
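
// Evaluates classification accuracy on the test set: a sample is counted as
// correct when the argmax of the network output matches the argmax of the
// expected output.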
Real
Network::eval(Dataset* dataset){
  size_t n=dataset->get_test_size();
  size_t nb=0;
  for(size_t i=0;i<n;++i){
    pair<Vector,Vector> t=dataset->get_test(i);
    Vector out=feed_forward(t.first);
    if(argmax(out,n_out)==argmax(t.second,n_out)) ++nb;
  }
  Real res=Real(nb)/Real(n)*100;
  cout<<"> Res = "<<res<<"%"<<endl;
  return res;
}
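
// Shuffles the index array in place by applying size random transpositions.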
void
Network::shuffle(size_t* tab,size_t size){
  // static: reuse the same engine across calls so each call draws a new
  // sequence of swaps instead of replaying the default-seeded one.
  static default_random_engine generator;
  uniform_int_distribution<size_t> distribution(0,size-1);
  for(size_t k=0;k<size;++k){
    size_t i=distribution(generator);
    size_t j=distribution(generator);
    swap(tab[i],tab[j]);
  }
}
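
// Mini-batch stochastic gradient descent: each epoch shuffles the training
// indices, processes them batch by batch, then evaluates on the test set.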
void
Network::train(Dataset* dataset,size_t nb_epochs,size_t batch_size,Real eta){
  size_t train_size=dataset->get_train_size();
  size_t nb_batchs=(train_size-1)/batch_size+1;
  size_t* indices=new size_t[train_size];
  for(size_t i=0;i<train_size;++i){
    indices[i]=i;
  }
  for(size_t epoch=0;epoch<nb_epochs;++epoch){
    cout<<"Epoch "<<epoch<<endl;
    shuffle(indices,train_size);
    for(size_t batch=0;batch<nb_batchs;++batch){
      size_t begin=batch*batch_size;
      size_t end=min(train_size,begin+batch_size);
      update_batch(dataset,indices,begin,end,eta);
    }
    eval(dataset);
  }
  delete[] indices;
}
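
// Processes one mini-batch: reset the accumulated gradients, backpropagate
// every sample of the batch, then apply the update with the learning rate
// divided by the batch size (an averaged gradient step).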
void
Network::update_batch(Dataset* dataset,size_t* indices,size_t begin,size_t end,Real eta){
  Real batch_size=end-begin;
  for(auto it=layers.begin();it!=layers.end();++it){
    (*it)->init_nabla();
  }
  for(size_t i=begin;i<end;++i){
    pair<Vector,Vector> data=dataset->get_train(indices[i]);
    //cout<<"Call back_propagation on batch data "<<i-begin<<"/"<<batch_size<<endl;
    back_propagation(data.first,data.second,eta);
  }
  Real eta_batch=eta/batch_size;
  for(auto it=layers.begin();it!=layers.end();++it){
    (*it)->update(eta_batch);
  }
}
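
// Backpropagation for a single example: forward pass, compute the output
// delta against the target y, then push the delta backwards through the
// layers so each one can accumulate its gradient.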
void
Network::back_propagation(Vector x,Vector y,Real eta){
  feed_forward(x); // the output activation is cached in the member a
  //cout<<" - Feed forward done"<<endl;
  compute_last_delta(y);
  Vector delta=last_delta;
  //cout<<" - Last_delta computed"<<endl;
  for(auto it=layers.rbegin();it!=layers.rend();++it){
    //cout<<" - Try back_propagation on layer "<<(*it)->name<<endl;
    delta=(*it)->back_propagation(delta);
    //cout<<" - Done"<<endl;
  }
}
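
// Computes the delta of the output layer; as written, both the quadratic and
// the cross-entropy cost use delta = a - y.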
void
Network::compute_last_delta(Vector y){
  switch(C){
    case Quadratic:
      // fall through: both costs use the same output delta here
    case CrossEntropy:
      for(size_t i=0;i<n_out;++i){
        last_delta[i]=a[i]-y[i];
      }
      break;
    default:
      assert(false);
      break;
  }
}
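
// A minimal usage sketch, not part of the original file: DenseLayer and
// MnistDataset below are hypothetical stand-ins for whatever concrete Layer
// and Dataset types the project provides; only the Network calls mirror the
// API defined above.
//
//   Network net;
//   DenseLayer l1(784,30), l2(30,10);    // hypothetical Layer subclass
//   net.push_layer(l1);
//   net.push_layer(l2);
//   net.is_done();                       // allocate last_delta
//   MnistDataset data;                   // hypothetical Dataset subclass
//   net.train(&data,30,10,0.5);          // 30 epochs, batch size 10, eta 0.5
//   net.eval(&data);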