#ifndef __CREATE_INPUT_VECTOR_H__
#define __CREATE_INPUT_VECTOR_H__

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <iterator>
#include <random>
#include <set>
#include <string>
#include <vector>
#include "./generators.h"

// Population most recently read by load_data(); sample_from_loaded() draws its
// samples from this vector. Because it is declared `static` in a header, every
// translation unit that includes this file gets its own independent copy.
static std::vector<double> loaded_data;

// Reports whether `path` can be opened for reading: returns true when a
// binary-mode input stream on it is in a good state right after opening.
bool file_exists(const std::string& path) {
    std::ifstream probe;
    probe.open(path, std::ios::binary);
    const bool usable = probe.good();
    return usable;
}

// Builds the cache file name for a sample of `n` elements drawn with `seed`
// by appending "_n<n>_seed<seed>" to `base`.
// Example: "data/books_800M_uint64" -> "data/books_800M_uint64_n1000_seed0".
std::string make_cache_filename(const std::string& base, long n, unsigned int seed) {
    std::string name = base;
    name += "_n";
    name += std::to_string(n);
    name += "_seed";
    name += std::to_string(seed);
    return name;
}

// Reads a vector written by save_double_binary(): a uint64_t element count
// followed by that many raw doubles, both in the machine's native layout.
//
// Returns the decoded values. If the file is damaged (no complete count
// header, or fewer payload bytes than the header announces) a warning is
// printed and an EMPTY vector is returned, so a caller that compares the
// result size against the expected size can treat it as a cache miss instead
// of consuming zero-filled data.
//
// Terminates the process with EXIT_FAILURE when the file cannot be opened.
std::vector<double> load_double_binary(const std::string& filename) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        std::cerr << "unable to open " << filename << std::endl;
        exit(EXIT_FAILURE);
    }

    uint64_t size = 0;
    if (!in.read(reinterpret_cast<char*>(&size), sizeof(uint64_t))) {
        std::cerr << "[warn] " << filename << " has no complete size header" << std::endl;
        return {};
    }

    // Measure the payload before allocating, so a corrupt header cannot
    // request an allocation far larger than the file itself.
    const std::streamoff payload_begin = in.tellg();
    in.seekg(0, std::ios::end);
    const std::streamoff file_end = in.tellg();
    in.seekg(payload_begin, std::ios::beg);
    const uint64_t available =
        (in && payload_begin >= 0 && file_end >= payload_begin)
            ? static_cast<uint64_t>(file_end - payload_begin)
            : 0;
    if (size > available / sizeof(double)) {
        std::cerr << "[warn] " << filename << " is truncated: header announces " << size
                  << " values but only " << available << " payload bytes are present" << std::endl;
        return {};
    }

    std::vector<double> v(size);
    if (size != 0 &&
        !in.read(reinterpret_cast<char*>(v.data()),
                 static_cast<std::streamsize>(sizeof(double) * size))) {
        std::cerr << "[warn] short read from " << filename << std::endl;
        return {};
    }
    return v;
}

// Writes `v` to `filename` in the layout load_double_binary() reads: a
// uint64_t element count followed by the raw doubles, native layout.
//
// Terminates the process with EXIT_FAILURE when the file cannot be opened or
// when any write/flush fails (e.g. disk full). In the latter case the
// partially written file is removed first so that it cannot later be mistaken
// for a valid cache.
void save_double_binary(const std::string& filename, const std::vector<double>& v) {
    std::ofstream out(filename, std::ios::binary);
    if (!out.is_open()) {
        std::cerr << "unable to open for write " << filename << std::endl;
        exit(EXIT_FAILURE);
    }

    const uint64_t size = static_cast<uint64_t>(v.size());
    out.write(reinterpret_cast<const char*>(&size), sizeof(uint64_t));
    if (size) {
        out.write(reinterpret_cast<const char*>(v.data()),
                  static_cast<std::streamsize>(sizeof(double) * size));
    }

    // Buffered output may only hit the disk on close(), so the stream state
    // is inspected after closing rather than after each write.
    out.close();
    if (out.fail()) {
        std::cerr << "write failed for " << filename << std::endl;
        std::remove(filename.c_str());
        exit(EXIT_FAILURE);
    }
}

// Draws `n` distinct elements from loaded_data without replacement using
// `mt`, then shuffles them (std::sample alone keeps the population's relative
// order). Exits with EXIT_FAILURE when loaded_data is empty or holds fewer
// than `n` elements.
std::vector<double> sample_from_loaded(long n, std::mt19937& mt) {
    const size_t population = loaded_data.size();
    if (population == 0) {
        std::cerr << "loaded_data is empty. Call load_data(base_path) first." << std::endl;
        exit(EXIT_FAILURE);
    }

    const size_t wanted = static_cast<size_t>(n);
    if (wanted > population) {
        std::cerr << "requested n > population size in loaded_data" << std::endl;
        exit(EXIT_FAILURE);
    }

    std::vector<double> picked;
    picked.reserve(n);
    auto sink = std::back_inserter(picked);
    std::sample(loaded_data.begin(), loaded_data.end(), sink, n, mt);
    std::shuffle(picked.begin(), picked.end(), mt);
    return picked;
}

// Returns whether a sample of `n` elements fits within the hard-coded element
// count associated with the dataset path `filename` (i.e. n <= that count).
// Paths that are not in the table are not limited and yield true.
bool n_is_less_than_population_size(const std::string& filename, long n) {
    struct DatasetLimit {
        const char* path;
        long max_elements;
    };
    static const DatasetLimit limits[] = {
        {"data/books_800M_uint64", 800000000L},
        {"data/osm_cellids_800M_uint64", 800000000L},
        {"data/wiki_ts_200M_uint64", 200000000L},
        {"data/fb_200M_uint64", 200000000L},
        {"data/nyc_pickup", 100000000L},
        {"data/nyc_dist", 100000000L},
        {"data/nyc_tot", 100000000L},
        {"data/stks_vol", 20000000L},
        {"data/stks_open", 20000000L},
        {"data/stks_date", 20000000L},
        {"data/stks_low", 20000000L},
        {"data/sof_hum", 50000000L},
        {"data/sof_press", 50000000L},
        {"data/sof_temp", 50000000L},
        {"data/chic_start", 20000000L},
        {"data/chic_tot", 20000000L},
    };

    for (const DatasetLimit& entry : limits) {
        if (filename == entry.path) {
            return n <= entry.max_elements;
        }
    }
    return true;
}

// Reads a binary file laid out as a uint64_t element count followed by that
// many uint64_t values, and stores the values converted to double in `dest`.
// Exits with EXIT_FAILURE if the file cannot be opened, the count header is
// incomplete, or the payload is shorter than the header announces.
static void load_uint64_binary_as_double_(const std::string& filename, std::vector<double>& dest) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        std::cerr << "unable to open " << filename << std::endl;
        exit(EXIT_FAILURE);
    }

    uint64_t size = 0;
    if (!in.read(reinterpret_cast<char*>(&size), sizeof(uint64_t))) {
        std::cerr << "unable to read element count from " << filename << std::endl;
        exit(EXIT_FAILURE);
    }

    std::vector<uint64_t> raw(size);
    if (size != 0 &&
        !in.read(reinterpret_cast<char*>(raw.data()),
                 static_cast<std::streamsize>(size * sizeof(uint64_t)))) {
        // A short read would otherwise leave the tail silently zero-filled.
        std::cerr << "truncated data in " << filename << " (expected " << size << " values)" << std::endl;
        exit(EXIT_FAILURE);
    }

    dest.resize(size);
    std::transform(raw.begin(), raw.end(), dest.begin(),
                   [](uint64_t v) { return static_cast<double>(v); });
}

// Reads a text file holding one number per line into `dest` (replacing its
// contents). Blank lines are skipped; a line std::stod cannot parse is
// reported and ends the process with EXIT_FAILURE, as does a file that cannot
// be opened.
static void load_text_doubles_(const std::string& filename, std::vector<double>& dest) {
    std::ifstream in(filename);
    if (!in.is_open()) {
        std::cerr << "unable to open " << filename << std::endl;
        exit(EXIT_FAILURE);
    }

    dest.clear();
    dest.reserve(1 << 20);
    std::string line;
    size_t line_no = 0;
    while (std::getline(in, line)) {
        ++line_no;
        // std::stod throws on a line without digits (e.g. an extra blank line
        // at the end of the file, or a lone '\r'), so skip those explicitly.
        if (line.find_first_not_of(" \t\r\n\v\f") == std::string::npos) {
            continue;
        }
        try {
            dest.push_back(std::stod(line));
        } catch (const std::exception&) {
            std::cerr << "unable to parse line " << line_no << " of " << filename
                      << ": " << line << std::endl;
            exit(EXIT_FAILURE);
        }
    }
}

// Removes from `data` every value that is not strictly below the value found
// at index floor(quantile * size) of the sorted data. The relative order of
// the surviving elements is kept. Does nothing for an empty vector.
static void drop_values_from_quantile_(std::vector<double>& data, double quantile) {
    if (data.empty()) {
        return;  // nothing to rank; indexing below would be out of bounds
    }

    std::vector<double> scratch = data;
    const size_t last = scratch.size() - 1;
    const size_t rank =
        std::min(static_cast<size_t>(std::floor(quantile * scratch.size())), last);
    // Only the element at `rank` is needed, so a full sort is unnecessary.
    std::nth_element(scratch.begin(), scratch.begin() + rank, scratch.end());
    const double threshold = scratch[rank];

    data.erase(std::remove_if(data.begin(), data.end(),
                              [threshold](double value) { return !(value < threshold); }),
               data.end());
}

// Loads the dataset at `filename` into the file-level `loaded_data` vector.
//
// `filename` must be one of the known dataset paths listed below: the
// "*_uint64" files are read as binary (uint64_t count + uint64_t values,
// converted to double); the remaining ones as text with one number per line.
// For "data/fb_200M_uint64", values at or above the 99.999th-percentile value
// are removed after loading.
//
// Exits with EXIT_FAILURE on an unknown filename or on any read/parse error.
void load_data(const std::string& filename) {
    const std::string sosd_files[] = {
        "data/books_800M_uint64",
        "data/osm_cellids_800M_uint64",
        "data/wiki_ts_200M_uint64",
        "data/fb_200M_uint64"
    };
    const std::string txt_files[] = {
        "data/nyc_pickup",
        "data/nyc_dist",
        "data/nyc_tot",
        "data/stks_vol",
        "data/stks_open",
        "data/stks_date",
        "data/stks_low",
        "data/sof_hum",
        "data/sof_press",
        "data/sof_temp",
        "data/chic_start",
        "data/chic_tot"
    };

    const bool is_binary =
        std::find(std::begin(sosd_files), std::end(sosd_files), filename) != std::end(sosd_files);
    const bool is_text =
        std::find(std::begin(txt_files), std::end(txt_files), filename) != std::end(txt_files);

    if (is_binary) {
        load_uint64_binary_as_double_(filename, loaded_data);
    } else if (is_text) {
        load_text_doubles_(filename, loaded_data);
    } else {
        std::cerr << "Unknown filename: " << filename << std::endl;
        exit(EXIT_FAILURE);
    }

    if (filename == "data/fb_200M_uint64") {
        // remove the outliers: everything from the 99.999th percentile upward
        drop_values_from_quantile_(loaded_data, 0.99999);
    }
}

// Maps a distribution name to the element type its generator produces:
// 0 for double, 1 for long. Any name starting with "data/" is double.
// An unrecognised name is reported on stderr and ends the process with
// exit code 1.
int get_data_type(std::string distribution) {
    static const char* const double_valued[] = {
        "uniform", "uniform_shift",
        "normal", "normal_shift",
        "exponential", "exponential_shift",
        "lognormal", "lognormal_shift",
        "chisquared",
    };
    static const char* const long_valued[] = {
        "eightdups", "modulo", "root_dups", "two_dups", "zipf",
    };
    const auto is_requested = [&distribution](const char* candidate) {
        return distribution == candidate;
    };

    if (std::any_of(std::begin(double_valued), std::end(double_valued), is_requested)) {
        return 0;
    }
    if (distribution.rfind("data/", 0) == 0) {
        return 0;
    }
    if (std::any_of(std::begin(long_valued), std::end(long_valued), is_requested)) {
        return 1;
    }

    std::cerr << "Unknown distribution: " << distribution << std::endl;
    exit(1);
}

// Produces `n` doubles following `distribution`, using a std::mt19937 seeded
// with `seed`.
//
// Synthetic distributions:
//   "uniform"      uniform real on [0, 1)
//   "normal"       normal(0, 1)
//   "exponential"  exponential(1)
//   "lognormal"    lognormal(0, 1)
//   "chisquared"   chi-squared with 4 degrees of freedom
// The "<name>_shift" variants of the first four draw the same values and then
// add i / n to element i, i.e. a linear drift from 0 up to just below 1.
//
// "data/..." names: a sample of n elements of that dataset. A cache file
// (see make_cache_filename) is used when present and of the right size;
// otherwise the sample is drawn via sample_from_loaded() and the cache is
// (re)written.
//
// Exits with code 1 for an unknown distribution or when n exceeds the known
// population of a dataset that has no cache file yet.
//
// NOTE(review): the dataset is only loaded when loaded_data is empty, so if
// loaded_data already holds a different dataset, that one is sampled instead.
std::vector<double> create_input_vector_(long n, std::string distribution, unsigned int seed) {
    std::vector<double> x(n);
    std::mt19937 mt{seed};

    // Overwrites every element of x with a fresh draw from `dist`.
    const auto fill_from = [&x, &mt](auto dist) {
        std::generate(x.begin(), x.end(), [&]() { return dist(mt); });
    };
    // Adds i / n to x[i]. The division is done in floating point: with
    // integer operands i / n is 0 for every i < n and the shift is a no-op.
    const auto add_linear_shift = [&x, n]() {
        const double count = static_cast<double>(n);
        for (long i = 0; i < n; ++i) {
            x[static_cast<size_t>(i)] += static_cast<double>(i) / count;
        }
    };

    if (distribution == "uniform") {
        fill_from(std::uniform_real_distribution<>(0, 1));
    } else if (distribution == "uniform_shift") {
        fill_from(std::uniform_real_distribution<>(0, 1));
        add_linear_shift();
    } else if (distribution == "normal") {
        fill_from(std::normal_distribution<>(0, 1));
    } else if (distribution == "normal_shift") {
        fill_from(std::normal_distribution<>(0, 1));
        add_linear_shift();
    } else if (distribution == "exponential") {
        fill_from(std::exponential_distribution<>(1));
    } else if (distribution == "exponential_shift") {
        fill_from(std::exponential_distribution<>(1));
        add_linear_shift();
    } else if (distribution == "lognormal") {
        fill_from(std::lognormal_distribution<>(0, 1));
    } else if (distribution == "lognormal_shift") {
        fill_from(std::lognormal_distribution<>(0, 1));
        add_linear_shift();
    } else if (distribution == "chisquared") {
        fill_from(std::chi_squared_distribution<>(4));
    } else if (distribution.find("data/") == 0) {
        // Try the cache file first.
        const std::string cache_path = make_cache_filename(distribution, n, seed);

        if (file_exists(cache_path)) {
            x = load_double_binary(cache_path);
            if (static_cast<long>(x.size()) == n) {
                return x;
            }
            // Size mismatch, regenerate
            std::cerr << "[warn] cache size mismatch ("
                        << x.size() << " vs n=" << n << "), regenerating..." << std::endl;
        } else if (!n_is_less_than_population_size(distribution, n)) {
            std::cerr << "requested n > population size" << std::endl;
            exit(1);
        }

        // Generate the sample and cache it for the next run.
        if (loaded_data.empty()) {
            load_data(distribution);
        }
        x = sample_from_loaded(n, mt);
        save_double_binary(cache_path, x);
    } else {
        std::cerr << "Unknown distribution: " << distribution << std::endl;
        exit(1);
    }
    return x;
}

// Returns true when `s` ends with `suf`. An empty `suf` matches every string.
bool ends_with(const std::string& s, const std::string& suf) {
    if (suf.size() > s.size()) {
        return false;
    }
    const std::string::size_type tail_start = s.size() - suf.size();
    return s.compare(tail_start, suf.size(), suf) == 0;
}

// Public entry point for double-valued inputs: forwards its arguments
// unchanged to create_input_vector_() and returns that result.
std::vector<double> create_input_vector(long n, std::string distribution, unsigned int seed) {
    return create_input_vector_(n, distribution, seed);
}

// Produces `n` long values for one of the integer-valued distributions
// ("eightdups", "modulo", "root_dups", "two_dups", "zipf"): the matching
// generator is called with `n` and its output is shuffled with a
// std::mt19937 seeded by `seed`. An unknown name is reported on stderr and
// ends the process with exit code 1.
std::vector<long> create_input_vector_long(long n, std::string distribution, unsigned int seed) {
    std::vector<long> x(n);
    std::mt19937 mt{seed};

    // Pick the generator; the shuffle is common to all of them.
    if (distribution == "eightdups") {
        x = eight_dups_distr<long>(n);
    } else if (distribution == "modulo") {
        x = modulo_distr<long>(n);
    } else if (distribution == "root_dups") {
        x = root_dups_distr<long>(n);
    } else if (distribution == "two_dups") {
        x = two_dups_distr<long>(n);
    } else if (distribution == "zipf") {
        x = zipf_distr<long>(n);
    } else {
        std::cerr << "Unknown distribution: " << distribution << std::endl;
        exit(1);
    }

    std::shuffle(x.begin(), x.end(), mt);
    return x;
}

#endif