File size: 1,451 Bytes
db3c893
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include "data_loader.hpp"

#include <cstddef>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>

#include "optical_model.hpp" // For IMG_SIZE

namespace {

// Opens `path` for binary reading into `stream` and returns the file size in
// bytes, leaving the read position at the start of the file.
// Throws std::runtime_error if the file cannot be opened or its size cannot
// be determined.
std::size_t open_binary_and_measure(std::ifstream& stream, const std::string& path) {
    stream.open(path, std::ios::binary);
    if (!stream) throw std::runtime_error("Cannot open: " + path);
    stream.seekg(0, std::ios::end);
    const std::streamoff size = stream.tellg();
    // tellg() reports -1 on failure; unchecked, that would wrap to a huge size_t.
    if (!stream || size < 0) throw std::runtime_error("Cannot determine size of: " + path);
    stream.seekg(0, std::ios::beg);
    return static_cast<std::size_t>(size);
}

// Reads exactly `bytes` raw bytes from `stream` into `dst`.
// Throws std::runtime_error on a short or failed read (e.g. the file changed
// after it was measured) instead of silently leaving the buffer partly unfilled.
void read_exact(std::ifstream& stream, void* dst, std::size_t bytes, const std::string& path) {
    if (bytes == 0) return;  // nothing to read; dst may be null for an empty vector
    stream.read(static_cast<char*>(dst), static_cast<std::streamsize>(bytes));
    if (!stream) throw std::runtime_error("Failed to read: " + path);
}

}  // namespace

/// @brief Loads one Fashion-MNIST split stored as raw binary dumps.
///
/// Reads `<data_dir>/<prefix>-images.bin` and `<data_dir>/<prefix>-labels.bin`,
/// where prefix is "train" or "test". The image file is treated as a flat array
/// of float values, IMG_SIZE per sample. The label file is treated as a flat
/// array with one element of `FashionMNISTSet::labels` per sample. Both are
/// copied byte-for-byte into memory, so they must already be in the host's
/// byte order.
///
/// @param data_dir  Directory containing the .bin files.
/// @param is_train  true loads the "train" split, false loads "test".
/// @return The populated set; `N` is the number of samples loaded.
/// @throws std::runtime_error if a file cannot be opened, measured or fully
///         read, if the image file is not a whole number of images, or if the
///         label file does not hold exactly one label per image.
FashionMNISTSet load_fashion_mnist_data(const std::string& data_dir, bool is_train) {
    FashionMNISTSet set;
    const std::string prefix = is_train ? "train" : "test";
    const std::string images_path = data_dir + "/" + prefix + "-images.bin";
    const std::string labels_path = data_dir + "/" + prefix + "-labels.bin";

    // Load images.
    std::ifstream f_images;
    const std::size_t image_bytes = open_binary_and_measure(f_images, images_path);
    const std::size_t bytes_per_image = IMG_SIZE * sizeof(float);
    // The buffer below holds whole images only; a trailing partial image would
    // otherwise be read past its end (heap overflow).
    if (image_bytes % bytes_per_image != 0)
        throw std::runtime_error("Image file size is not a whole number of images: " + images_path);
    set.N = image_bytes / bytes_per_image;
    set.images.resize(set.N * IMG_SIZE);
    read_exact(f_images, set.images.data(), image_bytes, images_path);

    // Load labels.
    std::ifstream f_labels;
    const std::size_t label_bytes = open_binary_and_measure(f_labels, labels_path);
    const std::size_t bytes_per_label = sizeof(set.labels[0]);
    // Compare bytes with bytes (not samples with bytes) so the check, and the
    // raw read that follows, stay in bounds whatever the label element size is.
    if (label_bytes != static_cast<std::size_t>(set.N) * bytes_per_label)
        throw std::runtime_error("Image and label count mismatch!");
    set.labels.resize(set.N);
    read_exact(f_labels, set.labels.data(), label_bytes, labels_path);

    std::cout << "[INFO] Loaded " << set.N << " " << prefix << " samples.\n";
    return set;
}