main.cpp 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  #include <cstddef>
  #include <iomanip>

  #include "layers/layers.hpp"
  #include "network.hpp"
  #include "mnist/mnist.hpp"
  5. int main(int argc,char** argv){
  6. Network N;
  7. size_t nf=4;
  8. Layer::Convolution L1(1,28,28,5,5,nf);
  9. L1.init(0,1);
  10. Layer::Activation<Layer::Sigmoid> L2(nf*24*24);
  11. //Layer::Pooling L3(nf,24,24,2,2);
  12. Layer::FullConnected L4(nf*24*24,10);
  13. L4.init_standard();
  14. Layer::Activation<Layer::Sigmoid> L5(10);
  15. L1.name="[Convolutionnal]";
  16. L2.name="[Sigmoid of convolutionnal]";
  17. //L3.name="[Pooling]";
  18. L4.name="[Full connected]";
  19. L5.name="[Sigmoid of full]";
  20. N.push_layer(&L1);
  21. N.push_layer(&L2);
  22. // N.push_layer(&L3);
  23. N.push_layer(&L4);
  24. N.push_layer(&L5);
  25. N.is_done();
  26. Mnist dataset;
  27. N.train(&dataset,10,10,0.1);
  28. //exit(0);
  29. /* Network N;
  30. size_t nf=4;
  31. Layer::Convolution L1(1,28,28,5,5,nf);
  32. L1.init(0,1);
  33. Layer::Activation<Layer::Sigmoid> L2(nf*24*24);
  34. Layer::FullConnected L3(nf*24*24,10);
  35. L3.init_standard();
  36. Layer::Activation<Layer::Sigmoid> L4(10);
  37. L1.name="[Convolutionnal]";
  38. L2.name="[Sigmoid of convolutionnal]";
  39. L3.name="[Full connected]";
  40. L4.name="[Sigmoid of full]";
  41. N.push_layer(&L1);
  42. N.push_layer(&L2);
  43. N.push_layer(&L3);
  44. N.push_layer(&L4);
  45. N.is_done();
  46. cout<<"Network out size = "<<N.n_out<<endl;
  47. Mnist dataset;
  48. N.train(&dataset,20,10,0.1);
  49. */
  50. /* Network N;
  51. size_t n=20;
  52. Layer::FullConnected L0(28*28,n);
  53. L0.init_standard();
  54. Layer::Activation<Layer::Sigmoid> L1(n);
  55. Layer::FullConnected L2(n,10);
  56. L2.init_standard();
  57. Layer::Activation<Layer::Sigmoid> L3(10);
  58. L0.name="[Full 0]";
  59. L1.name="[Sigmoid 0]";
  60. L2.name="[Full 1]";
  61. L3.name="[Sigmoid 1]";
  62. N.push_layer(&L0);
  63. N.push_layer(&L1);
  64. N.push_layer(&L2);
  65. N.push_layer(&L3);
  66. N.is_done();
  67. cout<<"Network out size = "<<N.n_out<<endl;
  68. Mnist dataset;
  69. N.train(&dataset,20,10,3.0);*/
  70. }