12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- #include "mnist.hpp"
- int
- Mnist::reverse_int32(unsigned char* buffer){
- return (((int)buffer[0])<<24)|(((int)buffer[1])<<16)|(((int)buffer[2])<<8)|((int)buffer[3]);
- }
- Mnist::Mnist():Dataset(){
- x_size=784;
- y_size=10;
- x.resize(784);
- y.resize(10);
- train_size=load_labels("mnist/train-labels.idx1-ubyte",&train_labels);
- size_t temp=load_images("mnist/train-images.idx3-ubyte",&train_images);
- assert(train_size==temp);
- test_size=load_labels("mnist/t10k-labels.idx1-ubyte",&test_labels);
- temp=load_images("mnist/t10k-images.idx3-ubyte",&test_images);
- assert(test_size==temp);
- }
- size_t
- Mnist::load_labels(string filename,unsigned char** dst){
- ifstream file(filename,ios::in|ios::binary);
- if(not file.is_open()){
- cerr<<"[error] Could not open "<<filename<<endl;
- exit(-1);
- }
- unsigned char buffer[4];
- file.read((char*)buffer,4);
- assert(reverse_int32(buffer)==2049);
- file.read((char*)buffer,4);
- int size;
- size=reverse_int32(buffer);
- *dst=new unsigned char[size];
- file.read((char*)*dst,size);
- file.close();
- return size;
- }
- size_t
- Mnist::load_images(string filename,unsigned char** dst){
- ifstream file(filename,ios::in|ios::binary);
- if(not file.is_open()){
- cerr<<"[error] Could not open "<<filename<<endl;
- exit(-1);
- }
- unsigned char buffer[4];
- int size;
- file.read((char*)buffer,4);
- assert(reverse_int32(buffer)==2051);
- file.read((char*)buffer,4);
- size=reverse_int32(buffer);
- file.read((char*)buffer,4);
- assert(reverse_int32(buffer)==28);
- file.read((char*)buffer,4);
- assert(reverse_int32(buffer)==28);
- *dst=new unsigned char[784*size];
- file.read((char*)*dst,784*size);
- file.close();
- return size;
- }
- pair<const Vector&,const Vector&>
- Mnist::get(const size_t i,const unsigned char* const * labels,const unsigned char* const * images) const{
- size_t c=(size_t)(*labels)[i];
- for(size_t i=0;i<10;++i) y.data[i]=0;
- y.data[c]=1;
- const unsigned char* x_src=&(*images)[784*i];
- for(size_t i=0;i<784;++i){
- x.data[i]=double(x_src[i])/256.0;
- }
- return pair<const Vector&,const Vector&>(x,y);
- }
|