123456789101112131415161718192021222324252627282930313233343536373839 |
- #ifndef NETWORK_HPP
- #define NETWORK_HPP
- #include <random>
- #include <list>
- #include "layers/layer.hpp"
- #include "dataset.hpp"
/// Cost function that a Network can be configured with (see Network::set_cost).
enum CostFunction {
    CrossEntropy,  ///< First enumerator (underlying value 0).
    Quadratic      ///< Second enumerator (underlying value 1).
};
- class Network{
- public:
- list<Layer::Layer*> layers;
- size_t n_in;
- size_t n_out;
- Vector a;
- Vector last_delta;
- CostFunction C;
- void compute_last_delta(Vector y);
- protected:
- void shuffle(size_t* tab,size_t size);
- void update_batch(Dataset* dataset,size_t* indices,size_t begin,size_t end,Real eta);
- void back_propagation(Vector x,Vector y,Real eta);
- public:
- Network();
- void set_cost(CostFunction);
- void push_layer(Layer::Layer& l);
- void is_done();
- Vector feed_forward(Vector x_in);
- Real eval(Dataset *dataset);
- void train(Dataset* dataset,size_t nb_epochs,size_t batch_size,Real eta);
- };
- inline void
- Network::set_cost(CostFunction C_){
- C=C_;
- }
- #endif
|