Deep network
mnist.hpp
Go to the documentation of this file.
1#ifndef MNIST_HPP
2#define MNIST_HPP
3
4#include <iostream>
5#include <string>
6#include <fstream>
7#include <cstdint>
8
9#include "../dataset.hpp"
10
11using namespace std;
12
13class Mnist:public Dataset{
14private:
15 unsigned char* train_labels;
16 unsigned char* test_labels;
17 unsigned char* train_images;
18 unsigned char* test_images;
19 mutable Vector x;
20 mutable Vector y;
21 size_t load_labels(string filename,unsigned char** dst);
22 size_t load_images(string filename,unsigned char** dst);
23 int reverse_int32(unsigned char* buffer);
24 pair<Vector,Vector> get(const size_t i,const unsigned char* const* labels,const unsigned char* const * images) const;
25public:
26 Mnist();
27 pair<Vector,Vector> get_train(const size_t i) const;
28 pair<Vector,Vector> get_test(const size_t i) const;
29};
30
31inline pair<Vector,Vector>
32Mnist::get_train(const size_t i) const{
34 return get(i,&train_labels,&train_images);
35}
36
37inline pair<Vector,Vector>
38Mnist::get_test(const size_t i) const{
40 return get(i,&test_labels,&test_images);
41}
42#endif
Definition: dataset.hpp:9
size_t train_size
Definition: dataset.hpp:11
size_t test_size
Definition: dataset.hpp:12
Definition: mnist.hpp:13
pair< Vector, Vector > get_train(const size_t i) const
Definition: mnist.hpp:32
Mnist()
Definition: mnist.cpp:8
pair< Vector, Vector > get_test(const size_t i) const
Definition: mnist.hpp:38
#define assert(cond)
Definition: debug.hpp:11
Real * Vector
Definition: vector.hpp:45