#include "dataset.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <cnpy.h>
//#define SIM_ENV_WINDOWS

/// @brief Index an N-MNIST dataset laid out as `<dataset_root>/<digit>/...`.
///
/// Each sub-directory "0".."9" is walked recursively. Every regular file in it
/// is recorded with that directory's digit as its label. The (path, label)
/// list is sorted afterwards, so sample ids do not depend on directory
/// iteration order.
///
/// @param dataset_root       Directory that contains the sub-directories "0".."9".
/// @param time_window        Bin width later passed to read_nmnist_event().
/// @param include_incomplete Flag later passed to read_nmnist_event().
///
/// Loading is best-effort:
/// - A missing category is reported on stdout and skipped.
/// - A std::filesystem error is reported on stderr. The index keeps whatever
///   was collected before the error.
NMNIST::NMNIST(const std::string& dataset_root,
               int time_window,
               bool include_incomplete)
    : dataset_root(dataset_root),
      time_window(time_window),
      include_incomplete(include_incomplete) {
    namespace fs = std::filesystem;
    try {
        for (int digit = 0; digit < 10; ++digit) {
            std::string category_dir(dataset_root);
            category_dir += '/';
            category_dir += static_cast<char>('0' + digit);
            // is_directory() is false for a non-existent path, so it also
            // covers the existence check.
            if (!fs::is_directory(category_dir)) {
                std::cout << "NMNIST Error: Missing Category !" << std::endl;
                continue;
            }
            for (const auto& entry :
                 fs::recursive_directory_iterator(category_dir)) {
                // The recursive walk also yields sub-directories. Only regular
                // files are samples that read_nmnist_event can decode.
                if (entry.is_regular_file()) {
                    path_label.emplace_back(entry.path().string(), digit);
                }
            }
        }
        std::sort(path_label.begin(), path_label.end());
    } catch (const fs::filesystem_error& e) {
        std::cerr << e.what() << std::endl;
    }
}

/// Decode sample `id` and return its events grouped into time bins.
/// `id` is not range-checked.
std::vector<std::vector<Event> > NMNIST::get_binned_input(int id) {
    const auto& sample = path_label[id];
    const auto& sample_path = sample.first;
    return read_nmnist_event(sample_path, time_window, include_incomplete);
}

/// Digit label (0-9) of sample `id`. `id` is not range-checked.
int NMNIST::get_label(int id) {
    const auto& sample = path_label[id];
    return sample.second;
}

std::vector<std::vector<Event> > NMNIST::read_nmnist_event(
    const std::string& filePath,
    int time_window,
    bool include_incomplete) {
    uint32_t min_t = 1000000000, max_t = 0;

    std::ifstream file(filePath, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Failed to open file\n";
        exit(0);
    }

    std::vector<uint8_t> buffer(std::istreambuf_iterator<char>(file), {});

    std::vector<Event> events;
    const uint32_t time_increment = 1 << 13;
    uint32_t last_overflow = 0;

    for (size_t i = 0; i < buffer.size(); i += 5) {
        uint32_t y = buffer[i + 1];
        uint32_t x = buffer[i];
        uint32_t p = (buffer[i + 2] & 128) >> 7;
        uint32_t ts = ((buffer[i + 2] & 127) << 16) | (buffer[i + 3] << 8) |
                      buffer[i + 4];

        if (y == 240) {  // Time overflow event
            last_overflow += time_increment;
        } else {
            events.push_back(Event(p, y, x, ts + last_overflow));
            min_t = std::min(min_t, ts + last_overflow);
            max_t = std::max(max_t, ts + last_overflow);
        }
    }

    std::vector<std::vector<Event> > ret;

    int rpos = 0, e_num = events.size();
    for (int i = 0; i < e_num; ++i)
        events[i].ts = (events[i].ts - min_t) / time_window;

    sort(events.begin(), events.end());

    for (int l = 0; l < e_num; l = rpos) {
        std::vector<Event> tmp;
        if (!include_incomplete &&
            events[l].ts == (max_t - min_t) / time_window)
            break;
        while (rpos < e_num && events[rpos].ts == events[l].ts) {
            tmp.push_back(events[rpos]);
            ++rpos;
        }
        ret.push_back(tmp);
    }

    return ret;
}

/// Number of indexed samples, i.e. the number of (path, label) entries.
int NMNIST::size() {
    const auto count = path_label.size();
    return static_cast<int>(count);
}

/// @brief Parse the integer encoded in a file name, e.g. ".../user01/7.npy" -> 7.
///
/// The last path component is taken and everything from its last '.' onward
/// is dropped. The remaining stem is parsed with std::stoi, which accepts a
/// leading integer prefix.
///
/// std::filesystem::path honours the host platform's separators (both '/'
/// and '\\' on Windows). This replaces a compile-time switch that
/// recognised only one separator, which made mixed-separator Windows paths
/// fail to parse.
///
/// @param path File path whose stem is an integer.
/// @return The parsed integer.
/// @throws std::invalid_argument / std::out_of_range (from std::stoi) if the
///         stem does not start with an integer that fits in int.
int extractXFromPath(const std::string& path) {
    const std::string stem = std::filesystem::path(path).stem().string();
    return std::stoi(stem);
}

/// @brief Index a DVS Gesture dataset stored as `<dataset_root>/<user dir>/.../<label>.<ext>`.
///
/// Every directory directly under `dataset_root` is treated as one user
/// folder. Each is walked recursively, and every regular file becomes a sample
/// whose label is the integer file stem (parsed by extractXFromPath). User
/// folders, and the files inside each, are sorted so sample ids are
/// deterministic. Non-directory entries at the root are skipped, as are
/// directory entries produced by the recursive walk.
///
/// @param dataset_root       Root directory of the dataset.
/// @param time_window        Bin width later passed to read_dvsgesture_event().
/// @param include_incomplete Flag later passed to read_dvsgesture_event().
/// @throws std::filesystem::filesystem_error if a directory cannot be iterated.
/// @throws std::invalid_argument / std::out_of_range if a file stem is not an
///         integer.
DVSGesture::DVSGesture(std::string dataset_root,
                       int time_window,
                       bool include_incomplete)
    : dataset_root(dataset_root),
      time_window(time_window),
      include_incomplete(include_incomplete) {
    namespace fs = std::filesystem;

    std::vector<std::string> user_dirs;
    for (const auto& user : fs::directory_iterator(dataset_root)) {
        // A stray file at the root would make recursive_directory_iterator throw.
        if (user.is_directory()) {
            user_dirs.push_back(user.path().string());
        }
    }
    std::sort(user_dirs.begin(), user_dirs.end());

    for (const auto& user_dir : user_dirs) {
        std::vector<std::string> sample_files;
        for (const auto& entry : fs::recursive_directory_iterator(user_dir)) {
            // Directory entries are not samples, and their names do not parse
            // as labels.
            if (entry.is_regular_file()) {
                sample_files.push_back(entry.path().string());
            }
        }
        std::sort(sample_files.begin(), sample_files.end());
        for (auto& sample_file : sample_files) {
            // Parse the label first so id_label and id_path stay the same
            // length even if parsing throws.
            id_label.push_back(extractXFromPath(sample_file));
            id_path.push_back(std::move(sample_file));
        }
    }
    std::cout << "DVSGesture test_size: " << id_path.size() << std::endl;
}

/// Load sample `id` from its .npy file and return its events grouped into
/// time bins. `id` is not range-checked.
std::vector<std::vector<Event> > DVSGesture::get_binned_input(int id) {
    const auto& sample_path = id_path[id];
    return read_dvsgesture_event(sample_path, time_window, include_incomplete);
}

int DVSGesture::get_label(int id) {
    return id_label[id];
}

std::vector<std::vector<Event> > DVSGesture::read_dvsgesture_event(
    std::string filePath,
    int time_window,
    bool include_incomplete) {
    uint32_t min_t = 1000000000, max_t = 0;

    cnpy::NpyArray arr = cnpy::npy_load(filePath);

    std::vector<Event> events;

    // 访问数组数据
    double* data = arr.data<double>();
    size_t numElements = arr.num_vals;
    for (size_t i = 0; i < numElements; i += 4) {
        events.push_back(
            Event(data[i + 2], data[i + 1], data[i + 0], data[i + 3] * 1000));
        min_t = std::min(min_t, (uint32_t)data[i + 3] * 1000);
        max_t = std::max(max_t, (uint32_t)data[i + 3] * 1000);
    }

    std::vector<std::vector<Event> > ret;

    int rpos = 0, e_num = events.size();
    for (int i = 0; i < e_num; ++i)
        events[i].ts = (events[i].ts - min_t) / time_window;

    sort(events.begin(), events.end());

    for (int l = 0; l < e_num; l = rpos) {
        std::vector<Event> tmp;
        if (!include_incomplete &&
            events[l].ts == (max_t - min_t) / time_window)
            break;
        while (rpos < e_num && events[rpos].ts == events[l].ts) {
            tmp.push_back(events[rpos]);
            ++rpos;
        }
        ret.push_back(tmp);
    }

    return ret;
}

/// Number of indexed samples (one per recorded file path).
int DVSGesture::size() {
    const auto count = id_path.size();
    return static_cast<int>(count);
}


/// @brief Index a CIFAR10-DVS dataset laid out as `<dataset_root>/<class>/...`.
///
/// Each sub-directory "0".."9" is walked recursively. Every regular file in it
/// is recorded with that directory's digit as its class label. The
/// (path, label) list is sorted afterwards, so sample ids do not depend on
/// directory iteration order.
///
/// @param dataset_root       Directory that contains the sub-directories "0".."9".
/// @param time_window        Bin width later passed to read_cifar10dvs_event().
/// @param include_incomplete Flag later passed to read_cifar10dvs_event().
///
/// Loading is best-effort:
/// - A missing category is reported on stdout and skipped.
/// - A std::filesystem error is reported on stderr. The index keeps whatever
///   was collected before the error.
CIFAR10DVS::CIFAR10DVS(const std::string& dataset_root,
                       int time_window,
                       bool include_incomplete)
    : dataset_root(dataset_root),
      time_window(time_window),
      include_incomplete(include_incomplete) {
    namespace fs = std::filesystem;
    try {
        for (int category = 0; category < 10; ++category) {
            std::string category_dir(dataset_root);
            category_dir += '/';
            category_dir += static_cast<char>('0' + category);
            // is_directory() is false for a non-existent path, so it also
            // covers the existence check.
            if (!fs::is_directory(category_dir)) {
                std::cout << "CIFAR10DVS Error: Missing Category !" << std::endl;
                continue;
            }
            for (const auto& entry :
                 fs::recursive_directory_iterator(category_dir)) {
                // The recursive walk also yields sub-directories. Only regular
                // files are samples that read_cifar10dvs_event can load.
                if (entry.is_regular_file()) {
                    path_label.emplace_back(entry.path().string(), category);
                }
            }
        }
        std::sort(path_label.begin(), path_label.end());
    } catch (const fs::filesystem_error& e) {
        std::cerr << e.what() << std::endl;
    }
}

/// Load sample `id` from its .npy file and return its events grouped into
/// time bins. `id` is not range-checked.
std::vector<std::vector<Event> > CIFAR10DVS::get_binned_input(int id) {
    const auto& sample = path_label[id];
    const auto& sample_path = sample.first;
    return read_cifar10dvs_event(sample_path, time_window, include_incomplete);
}

/// Class label (0-9) of sample `id`. `id` is not range-checked.
int CIFAR10DVS::get_label(int id) {
    const auto& sample = path_label[id];
    return sample.second;
}

std::vector<std::vector<Event> > CIFAR10DVS::read_cifar10dvs_event(const std::string &filePath, int time_window, bool include_incomplete) {
    uint32_t min_t = 1000000000, max_t = 0;

    cnpy::NpyArray arr = cnpy::npy_load(filePath);

    std::vector<Event> events;


    // // 访问数组数据
    // double* data = arr.data<double>();
    // size_t numElements = arr.num_vals;
    // for (size_t i = 0; i < numElements; i += 4) {
    //     events.push_back(
    //         Event(data[i + 2], data[i + 1], data[i + 0], data[i + 3] * 1000));
    //     min_t = std::min(min_t, (uint32_t)data[i + 3] * 1000);
    //     max_t = std::max(max_t, (uint32_t)data[i + 3] * 1000);
    // }

    int* data = arr.data<int>();
    size_t numElements = arr.num_vals;
    for (size_t i = 0; i < numElements; i += 4) {
        events.push_back(
            Event((int)data[i + 3], (int)data[i + 2], (int)data[i + 1], (int)data[i + 0]));
        min_t = std::min(min_t, (uint32_t)data[i + 0]);
        max_t = std::max(max_t, (uint32_t)data[i + 0]);
    }

    std::vector<std::vector<Event> > ret;

    int rpos = 0, e_num = events.size();
    for (int i = 0; i < e_num; ++i)
        events[i].ts = (events[i].ts - min_t) / time_window;

    sort(events.begin(), events.end());

    for (int l = 0; l < e_num; l = rpos) {
        std::vector<Event> tmp;
        if (!include_incomplete &&
            events[l].ts == (max_t - min_t) / time_window)
            break;
        while (rpos < e_num && events[rpos].ts == events[l].ts) {
            tmp.push_back(events[rpos]);
            ++rpos;
        }
        ret.push_back(tmp);
    }

    return ret;
}

/// Number of indexed samples, i.e. the number of (path, label) entries.
int CIFAR10DVS::size() {
    const auto count = path_label.size();
    return static_cast<int>(count);
}
