mnist.hpp 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. #ifndef MNIST_HPP
  2. #define MNIST_HPP
  3. #include <iostream>
  4. #include <string>
  5. #include <fstream>
  6. #include <cstdint>
  7. #include <cassert>
  8. #include "../dataset.hpp"
  9. using namespace std;
  10. class Mnist:public Dataset{
  11. private:
  12. unsigned char* train_labels;
  13. unsigned char* test_labels;
  14. unsigned char* train_images;
  15. unsigned char* test_images;
  16. mutable Vector x;
  17. mutable Vector y;
  18. size_t load_labels(string filename,unsigned char** dst);
  19. size_t load_images(string filename,unsigned char** dst);
  20. int reverse_int32(unsigned char* buffer);
  21. pair<const Vector&,const Vector&> get(const size_t i,const unsigned char* const* labels,const unsigned char* const * images) const;
  22. public:
  23. Mnist();
  24. pair<const Vector&,const Vector&> get_train(const size_t i) const;
  25. pair<const Vector&,const Vector&> get_test(const size_t i) const;
  26. };
  27. inline pair<const Vector&,const Vector&>
  28. Mnist::get_train(const size_t i) const{
  29. assert(i<train_size);
  30. return get(i,&train_labels,&train_images);
  31. }
  32. inline pair<const Vector&,const Vector&>
  33. Mnist::get_test(const size_t i) const{
  34. assert(i<test_size);
  35. return get(i,&test_labels,&test_images);
  36. }
  37. #endif