12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- #ifndef LAYER_HPP
- #define LAYER_HPP
- #include "debug.hpp"
- #include "vector.hpp"
- #include "shape.hpp"
- #include <cmath>
- namespace Layer{
- /**
- * An abstract class representing a layer of a deep network.
- * Such a layer can be seen as a map
- \f[
- \begin{array}{rcl}
- f:\mathbb{R}^n&\mapsto&\mathbb{R}^m\\
- x&\to&y
- \end{array}
- \f].
- */
- class Layer{
- public:
- /** The name of the layer. Used for debugging.*/
- string name;
- /** Size of the input vector.*/
- size_t n;
- /** Size of the output vector.*/
- size_t m;
- /** A reference to the input vector.*/
- Vector x;
- /** Computed output vector. Owned by the layer.*/
- Vector y;
- /** Computed input delta vector computed by back propagation algorithm. Owned by the layer.*/
- Vector d;
- Layer(size_t n,size_t m);
- ~Layer();
- /** Return the input size.*/
- size_t get_input_size() const;
- /** Return the output size.*/
- size_t get_output_size() const;
- /** Return a reference to the computed output vector.*/
- Vector get_output() const;
- /** Apply the layer to the input vector \c x. Vectors \c x_in_ref and \c x_out must be updated in consequence. Return a reference to x_out.*/
- virtual Vector feed_forward(Vector x)=0;
- /** Initialize nabla vectors which are used during gradient descent.*/
- virtual void init_nabla()=0;
- /** Apply back propagation algorithm on the delta output vector d. Used the input vector stored in x_in_ref during feedforward.
- Return a reference to the computed (and stored) input delta vector. Nabla vectors must be computed here.*/
- virtual Vector back_propagation(Vector e)=0;
- /** Update layer parameters using gradient descent algorithm with learning rate eta. */
- virtual void update(Real eta)=0;
- };
- inline
- Layer::Layer(size_t n_,size_t m_){
- n=n_;
- m=m_;
- y=init_vector(m);
- d=init_vector(n);
- }
- inline
- Layer::~Layer(){
- delete_vector(y);
- delete_vector(d);
- }
- inline size_t
- Layer::get_input_size() const{
- return n;
- }
- inline size_t
- Layer::get_output_size() const{
- return m;
- }
- inline Vector
- Layer::get_output() const{
- return y;
- }
- }
- #endif
|