mnist.cpp 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #include "mnist.hpp"
  2. int
  3. Mnist::reverse_int32(unsigned char* buffer){
  4. return (((int)buffer[0])<<24)|(((int)buffer[1])<<16)|(((int)buffer[2])<<8)|((int)buffer[3]);
  5. }
  6. Mnist::Mnist():Dataset(){
  7. x_size=784;
  8. y_size=10;
  9. x.resize(784);
  10. y.resize(10);
  11. train_size=load_labels("mnist/train-labels.idx1-ubyte",&train_labels);
  12. size_t temp=load_images("mnist/train-images.idx3-ubyte",&train_images);
  13. assert(train_size==temp);
  14. test_size=load_labels("mnist/t10k-labels.idx1-ubyte",&test_labels);
  15. temp=load_images("mnist/t10k-images.idx3-ubyte",&test_images);
  16. assert(test_size==temp);
  17. }
  18. size_t
  19. Mnist::load_labels(string filename,unsigned char** dst){
  20. ifstream file(filename,ios::in|ios::binary);
  21. if(not file.is_open()){
  22. cerr<<"[error] Could not open "<<filename<<endl;
  23. exit(-1);
  24. }
  25. unsigned char buffer[4];
  26. file.read((char*)buffer,4);
  27. assert(reverse_int32(buffer)==2049);
  28. file.read((char*)buffer,4);
  29. int size;
  30. size=reverse_int32(buffer);
  31. *dst=new unsigned char[size];
  32. file.read((char*)*dst,size);
  33. file.close();
  34. return size;
  35. }
  36. size_t
  37. Mnist::load_images(string filename,unsigned char** dst){
  38. ifstream file(filename,ios::in|ios::binary);
  39. if(not file.is_open()){
  40. cerr<<"[error] Could not open "<<filename<<endl;
  41. exit(-1);
  42. }
  43. unsigned char buffer[4];
  44. int size;
  45. file.read((char*)buffer,4);
  46. assert(reverse_int32(buffer)==2051);
  47. file.read((char*)buffer,4);
  48. size=reverse_int32(buffer);
  49. file.read((char*)buffer,4);
  50. assert(reverse_int32(buffer)==28);
  51. file.read((char*)buffer,4);
  52. assert(reverse_int32(buffer)==28);
  53. *dst=new unsigned char[784*size];
  54. file.read((char*)*dst,784*size);
  55. file.close();
  56. return size;
  57. }
  58. pair<const Vector&,const Vector&>
  59. Mnist::get(const size_t i,const unsigned char* const * labels,const unsigned char* const * images) const{
  60. size_t c=(size_t)(*labels)[i];
  61. for(size_t i=0;i<10;++i) y.data[i]=0;
  62. y.data[c]=1;
  63. const unsigned char* x_src=&(*images)[784*i];
  64. for(size_t i=0;i<784;++i){
  65. x.data[i]=double(x_src[i])/256.0;
  66. }
  67. return pair<const Vector&,const Vector&>(x,y);
  68. }