network.hpp 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. #ifndef NETWORK_HPP
  2. #define NETWORK_HPP
#include <algorithm>
#include <cmath>
#include <iostream>
#include <random>
#include <utility>
#include <vector>
#include "dataset.hpp"
#include "vector.hpp"
#include "matrix.hpp"
using namespace std;
  11. class Trainer;
  12. double sigmoid(double x);
  13. double sigmoid_prime(double x);
  14. double cost_derivative(double a,double y);
  15. class Network{
  16. friend class Trainer;
  17. protected:
  18. size_t depth;
  19. vector<size_t> sizes;
  20. Matrix* weights;
  21. Vector *biais;
  22. Vector *a;
  23. Vector *z;
  24. Vector *nabla_b;
  25. Matrix *nabla_w;
  26. Vector *delta;
  27. void shuffle(size_t* tab,size_t size);
  28. void compute_z(size_t l);
  29. void compute_a(size_t l);
  30. void compute_last_delta(const Vector& y);
  31. void compute_delta(size_t l);
  32. void init_nabla_b(size_t l);
  33. void init_nabla_w(size_t l);
  34. void update_nabla_b(size_t l);
  35. void update_nabla_w(size_t l);
  36. void update_b(size_t l,double eta_batch);
  37. void update_w(size_t l,double eta_batch);
  38. public:
  39. template<typename ... Sizes> Network(Sizes ... _sizes);
  40. void init_normal_distribution(double m,double d);
  41. void init_standard();
  42. double* new_output_vector() const;
  43. const Vector& feed_forward(const Vector& x);
  44. double eval(Dataset* dataset);
  45. void train(Dataset* dataset,size_t nb_epochs,size_t batch_size,double eta);
  46. void update_batch(Dataset* dataset,size_t* indices,size_t begin,size_t end,double eta);
  47. void back_propagation(const Vector& x,const Vector& y,double eta);
  48. Vector hack(const Vector& x,const Vector& y,double eta,size_t nb_steps,void (*)(const Vector&));
  49. };
  50. inline double
  51. sigmoid(double x){
  52. return 1.0/(1.0+exp(-x));
  53. };
  54. inline double
  55. sigmoid_prime(double x){
  56. double t=sigmoid(x);
  57. return t*(1.0-t);
  58. };
  59. template<typename ... Sizes> inline
  60. Network::Network(Sizes ... _sizes):sizes({(size_t)_sizes ...}){
  61. depth=sizes.size();
  62. // Biais vectors
  63. biais=new Vector[depth];
  64. for(size_t l=0;l<depth;++l){
  65. biais[l].resize(sizes[l]);
  66. }
  67. // Weights vectors
  68. weights=new Matrix[depth];
  69. for(size_t l=1;l<depth;++l){
  70. weights[l].resize(sizes[l],sizes[l-1]);
  71. }
  72. // Activation vectors
  73. a=new Vector[depth];
  74. for(size_t l=0;l<depth;++l){
  75. a[l].resize(sizes[l]);
  76. }
  77. // Activation vectors
  78. z=new Vector[depth];
  79. for(size_t l=0;l<depth;++l){
  80. z[l].resize(sizes[l]);
  81. }
  82. nabla_b=new Vector[depth];
  83. for(size_t l=0;l<depth;++l){
  84. nabla_b[l].resize(sizes[l]);
  85. }
  86. nabla_w=new Matrix[depth];
  87. for(size_t l=1;l<depth;++l){
  88. nabla_w[l].resize(sizes[l],sizes[l-1]);
  89. }
  90. delta=new Vector[depth];
  91. for(size_t l=0;l<depth;++l){
  92. delta[l].resize(sizes[l]);
  93. }
  94. }
  95. inline double
  96. cost_derivative(double a,double y){
  97. return a-y;
  98. }
  99. #endif