network.hpp 855 B

123456789101112131415161718192021222324252627282930313233343536373839
  1. #ifndef NETWORK_HPP
  2. #define NETWORK_HPP
  3. #include <random>
  4. #include <list>
  5. #include "layers/layer.hpp"
  6. #include "dataset.hpp"
  7. enum CostFunction{CrossEntropy,Quadratic};
  8. class Network{
  9. public:
  10. list<Layer::Layer*> layers;
  11. size_t n_in;
  12. size_t n_out;
  13. Vector a;
  14. Vector last_delta;
  15. CostFunction C;
  16. void compute_last_delta(Vector y);
  17. protected:
  18. void shuffle(size_t* tab,size_t size);
  19. void update_batch(Dataset* dataset,size_t* indices,size_t begin,size_t end,Real eta);
  20. void back_propagation(Vector x,Vector y,Real eta);
  21. public:
  22. Network();
  23. void set_cost(CostFunction);
  24. void push_layer(Layer::Layer& l);
  25. void is_done();
  26. Vector feed_forward(Vector x_in);
  27. Real eval(Dataset *dataset);
  28. void train(Dataset* dataset,size_t nb_epochs,size_t batch_size,Real eta);
  29. };
  30. inline void
  31. Network::set_cost(CostFunction C_){
  32. C=C_;
  33. }
  34. #endif