12345678910111213141516171819202122232425262728293031323334353637383940414243 |
- #ifndef MNIST_HPP
- #define MNIST_HPP
- #include <iostream>
- #include <string>
- #include <fstream>
- #include <cstdint>
- #include <cassert>
- #include "../dataset.hpp"
- using namespace std;
- class Mnist:public Dataset{
- private:
- unsigned char* train_labels;
- unsigned char* test_labels;
- unsigned char* train_images;
- unsigned char* test_images;
- mutable Vector x;
- mutable Vector y;
- size_t load_labels(string filename,unsigned char** dst);
- size_t load_images(string filename,unsigned char** dst);
- int reverse_int32(unsigned char* buffer);
- pair<const Vector&,const Vector&> get(const size_t i,const unsigned char* const* labels,const unsigned char* const * images) const;
- public:
- Mnist();
- pair<const Vector&,const Vector&> get_train(const size_t i) const;
- pair<const Vector&,const Vector&> get_test(const size_t i) const;
- };
- inline pair<const Vector&,const Vector&>
- Mnist::get_train(const size_t i) const{
- assert(i<train_size);
- return get(i,&train_labels,&train_images);
- }
- inline pair<const Vector&,const Vector&>
- Mnist::get_test(const size_t i) const{
- assert(i<test_size);
- return get(i,&test_labels,&test_images);
- }
- #endif
|