layer.hpp

#ifndef LAYER_HPP
#define LAYER_HPP

#include "debug.hpp"
#include "vector.hpp"
#include "shape.hpp"
#include <cmath>
#include <string>
namespace Layer{

  /**
   * An abstract class representing a layer of a deep network.
   * Such a layer can be seen as a map
   * \f[
   * \begin{array}{rcl}
   * f:\mathbb{R}^n&\to&\mathbb{R}^m\\
   * x&\mapsto&y
   * \end{array}
   * \f]
   */
  class Layer{
  public:
    /** The name of the layer. Used for debugging. */
    std::string name;
    /** Size of the input vector. */
    size_t n;
    /** Size of the output vector. */
    size_t m;
    /** A reference to the input vector. */
    Vector x;
    /** Computed output vector. Owned by the layer. */
    Vector y;
    /** Input delta vector computed by the back propagation algorithm. Owned by the layer. */
    Vector d;
    Layer(size_t n, size_t m);
    virtual ~Layer();
    /** Return the input size. */
    size_t get_input_size() const;
    /** Return the output size. */
    size_t get_output_size() const;
    /** Return a reference to the computed output vector. */
    Vector get_output() const;
    /** Apply the layer to the input vector \c x. The members \c x and \c y must be updated accordingly. Return a reference to \c y. */
    virtual Vector feed_forward(Vector x) = 0;
    /** Initialize the nabla vectors used during gradient descent. */
    virtual void init_nabla() = 0;
    /** Apply the back propagation algorithm to the output delta vector \c e, using the input vector stored in \c x during feed_forward.
        Return a reference to the computed (and stored) input delta vector \c d. The nabla vectors must be computed here. */
    virtual Vector back_propagation(Vector e) = 0;
    /** Update the layer parameters using gradient descent with learning rate \c eta. */
    virtual void update(Real eta) = 0;
  };
  inline
  Layer::Layer(size_t n_, size_t m_){
    n = n_;
    m = m_;
    y = init_vector(m);
    d = init_vector(n);
  }

  inline
  Layer::~Layer(){
    delete_vector(y);
    delete_vector(d);
  }

  inline size_t
  Layer::get_input_size() const{
    return n;
  }

  inline size_t
  Layer::get_output_size() const{
    return m;
  }

  inline Vector
  Layer::get_output() const{
    return y;
  }
}
#endif
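
For illustration, here is a minimal sketch of a concrete subclass: a hypothetical ScaleLayer that multiplies its input by a single learnable scalar. It assumes that Vector elements can be read and written through operator[], as with a plain array of Real; this assumption, and the ScaleLayer name itself, are illustrative and not part of layer.hpp or vector.hpp.

#include "layer.hpp"

namespace Layer{
  // Hypothetical example, not part of the library: y = a*x with one
  // learnable scalar a. Assumes Vector is indexable with operator[],
  // which vector.hpp is not guaranteed to provide.
  class ScaleLayer: public Layer{
  public:
    Real a;        // learnable parameter
    Real nabla_a;  // gradient of the cost with respect to a

    ScaleLayer(size_t n): Layer(n, n), a(1), nabla_a(0) {}

    Vector feed_forward(Vector x_in){
      x = x_in;                        // keep a reference for back propagation
      for(size_t i = 0; i < n; ++i) y[i] = a * x[i];
      return y;
    }

    void init_nabla(){ nabla_a = 0; }

    Vector back_propagation(Vector e){
      for(size_t i = 0; i < n; ++i){
        nabla_a += e[i] * x[i];        // dC/da, accumulated over the batch
        d[i] = a * e[i];               // dC/dx, passed to the previous layer
      }
      return d;
    }

    void update(Real eta){ a -= eta * nabla_a; }
  };
}

A training step would then chain the virtual methods in the order the interface suggests: init_nabla(), feed_forward(x), back_propagation(e), and finally update(eta), where e is the delta vector coming from the next layer or from the cost function.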