123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- #include "full_connected.hpp"
- namespace Layer{
- FullConnectedLayer::FullConnectedLayer(size_t n,size_t m):Layer(n,m){
- b=init_vector(m); //b
- w=init_vector(m*n); //w
- nabla_b=init_vector(m); //nabla_b
- nabla_w=init_vector(m*n); //nabla_w
- }
- FullConnectedLayer::~FullConnectedLayer(){
- delete_vector(b);
- delete_vector(w);
- delete_vector(nabla_b);
- delete_vector(nabla_w);
- }
- void
- FullConnectedLayer::init(Real mu,Real sigma){
- default_random_engine generator;
- normal_distribution<Real> distribution(mu,sigma);
- for(size_t i=0;i<m;++i){
- b[i]=distribution(generator);
- }
- for(size_t i=0;i<m*n;++i){
- w[i]=distribution(generator);
- }
- }
- void
- FullConnectedLayer::init_standard(){
- default_random_engine generator;
- normal_distribution<Real> distribution(0,1);
- for(size_t i=0;i<m;++i){
- b[i]=distribution(generator);
- }
- for(size_t i=0;i<m*n;++i){
- normal_distribution<Real> distribution2(0,1/sqrt(n));
- w[i]=distribution2(generator);
- }
- }
- Vector
- FullConnectedLayer::feed_forward(Vector x_){
- x=x_;
- for(size_t i=0;i<m;++i){
- Real temp=b[i];
- for(size_t j=0;j<n;++j){
- temp+=w[indice2(i,j,n)]*x[j];
- }
- y[i]=temp;
- }
- return y;
- }
- void
- FullConnectedLayer::init_nabla(){
- for(size_t i=0;i<m;++i){
- nabla_b[i]=0;
- }
- for(size_t i=0;i<m*n;++i){
- nabla_w[i]=0;
- }
- }
- Vector
- FullConnectedLayer::back_propagation(Vector e){
- for(size_t i=0;i<n;++i){
- Real temp=0;
- for(size_t j=0;j<m;++j){
- temp+=w[indice2(j,i,n)]*e[j];
- }
- d[i]=temp;
- }
- //Update nabla_b
- for(size_t i=0;i<m;++i){
- nabla_b[i]+=e[i];
- }
- //Update nabla_w
- for(size_t i=0;i<m;++i){
- for(size_t j=0;j<n;++j){
- nabla_w[indice2(i,j,n)]+=e[i]*x[j];
- }
- }
- return d;
- }
- void
- FullConnectedLayer::update(Real eta){
- //Update b
- for(size_t i=0;i<m;++i){
- b[i]-=eta*nabla_b[i];
- }
- //Update w
- for(size_t i=0;i<m*n;++i){
- w[i]-=eta*nabla_w[i];
- }
- }
- }
|