main.cpp 759 B

12345678910111213141516171819202122232425262728293031323334
  1. #include <iomanip>
  2. #include "layers/layers.hpp"
  3. #include "network.hpp"
  4. #include "mnist/mnist.hpp"
  5. #include <chrono>
  6. using namespace Layer;
  7. int main(int argc,char** argv){
  8. Network N;
  9. size_t nf=4;
  10. ConvolutionLayer L1(1,28,28,5,5,nf);
  11. L1.init(0,1);
  12. ActivationLayer<Sigmoid> L2(nf*24*24);
  13. //Layer::Pooling L3(nf,24,24,2,2);
  14. FullConnectedLayer L4(nf*24*24,10);
  15. L4.init_standard();
  16. ActivationLayer<Sigmoid> L5(10);
  17. L1.name="[Convolutionnal]";
  18. L2.name="[Sigmoid of convolutionnal]";
  19. //L3.name="[Pooling]";
  20. L4.name="[Full connected]";
  21. L5.name="[Sigmoid of full]";
  22. N.push_layer(L1);
  23. N.push_layer(L2);
  24. // N.push_layer(&L3);
  25. N.push_layer(L4);
  26. N.push_layer(L5);
  27. N.is_done();
  28. Mnist dataset;
  29. N.train(&dataset,1,10,0.1);
  30. }