mnist.hpp 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #ifndef MNIST_HPP
  2. #define MNIST_HPP
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include "../dataset.hpp"
// NOTE(review): a using-directive in a header leaks into every includer;
// kept for compatibility with code that relies on it.
using namespace std;
  9. class Mnist:public Dataset{
  10. private:
  11. unsigned char* train_labels;
  12. unsigned char* test_labels;
  13. unsigned char* train_images;
  14. unsigned char* test_images;
  15. mutable Vector x;
  16. mutable Vector y;
  17. size_t load_labels(string filename,unsigned char** dst);
  18. size_t load_images(string filename,unsigned char** dst);
  19. int reverse_int32(unsigned char* buffer);
  20. pair<Vector,Vector> get(const size_t i,const unsigned char* const* labels,const unsigned char* const * images) const;
  21. public:
  22. Mnist();
  23. pair<Vector,Vector> get_train(const size_t i) const;
  24. pair<Vector,Vector> get_test(const size_t i) const;
  25. };
  26. inline pair<Vector,Vector>
  27. Mnist::get_train(const size_t i) const{
  28. assert(i<train_size);
  29. return get(i,&train_labels,&train_images);
  30. }
  31. inline pair<Vector,Vector>
  32. Mnist::get_test(const size_t i) const{
  33. assert(i<test_size);
  34. return get(i,&test_labels,&test_images);
  35. }
  36. #endif