#include "learned_sort_using_learned_index/learned_sort_using_learned_index.h"
#include "counter.h"
#include "create_input_vector.h"
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <exception>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <random>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

/**
 * Runs one sorting method on `x` and returns the operation counts it recorded.
 *
 * The method is selected by argv[1]:
 *   - "merge_sort": calls merge_sort::merge_sort on `x` as given (no shuffle).
 *   - "learned_sort_using_learned_index_<index>": shuffles `x` with a randomly
 *     seeded std::mt19937, then calls
 *     learned_sort_using_learned_index::learned_sort_using_learned_index with
 *     the matching index name. The "*_approx" variants additionally pass an
 *     epsilon parsed from argv[5].
 *
 * After sorting, the result is checked with std::is_sorted. If it is not
 * sorted, an error is printed and the process exits with status 1.
 *
 * @param x     Input values; taken by value because it is sorted in place.
 * @param argv  Command-line arguments. argv[1] is the method name. For the
 *              "*_approx" variants, argv[5] is a non-negative epsilon. The
 *              caller must ensure these entries exist.
 * @return      Map from counter name to its total, copied from
 *              counter::counter.get_total_counts() after the sort. Empty (after
 *              printing an error) when the method name is unknown.
 * @throws std::invalid_argument / std::out_of_range from std::stoi when the
 *         epsilon argument is not a valid int.
 */
template <typename T>
std::unordered_map<std::string, size_t> measure_count_dict(std::vector<T> x, char *argv[]) {
    // Learned-index variants: CLI method name, index name forwarded to the
    // sorter, and whether the variant takes an epsilon from argv[5].
    struct LearnedVariant {
        const char *method_name;
        const char *index_name;
        bool takes_epsilon;
    };
    static const LearnedVariant kLearnedVariants[] = {
        {"learned_sort_using_learned_index_binary_search", "binary_search", false},
        {"learned_sort_using_learned_index_btree", "btree", false},
        {"learned_sort_using_learned_index_btree_approx", "btree_approx", true},
        {"learned_sort_using_learned_index_espc", "espc", false},
        {"learned_sort_using_learned_index_pgm", "pgm", false},
        {"learned_sort_using_learned_index_pgm_approx", "pgm_approx", true},
        {"learned_sort_using_learned_index_rmi", "rmi", false},
    };

    std::unordered_map<std::string, size_t> count_dict;
    const std::string method_name = argv[1];
    counter::counter.clear();

    if (method_name == "merge_sort") {
        merge_sort::merge_sort(x.begin(), x.end());
    } else {
        const auto variant = std::find_if(
            std::begin(kLearnedVariants), std::end(kLearnedVariants),
            [&method_name](const LearnedVariant &v) { return method_name == v.method_name; });
        if (variant == std::end(kLearnedVariants)) {
            std::cerr << "Unknown method: " << method_name << std::endl;
            return count_dict;
        }

        size_t epsilon = 0;
        if (variant->takes_epsilon) {
            const int parsed_epsilon = std::stoi(argv[5]);
            // A negative value converted to size_t would silently wrap to a
            // huge epsilon, so reject it explicitly.
            if (parsed_epsilon < 0) {
                std::cerr << "Error: epsilon must be non-negative: " << argv[5] << std::endl;
                std::exit(1);
            }
            epsilon = static_cast<size_t>(parsed_epsilon);
        }

        // Learned-sort variants run on a randomly shuffled copy of the input;
        // merge_sort above runs on the input order as given.
        std::shuffle(x.begin(), x.end(), std::mt19937(std::random_device()()));
        if (variant->takes_epsilon) {
            learned_sort_using_learned_index::learned_sort_using_learned_index(
                x.begin(), x.end(), variant->index_name, epsilon);
        } else {
            learned_sort_using_learned_index::learned_sort_using_learned_index(
                x.begin(), x.end(), variant->index_name);
        }
    }

    if (!std::is_sorted(x.begin(), x.end())) {
        std::cerr << "Error: The result is not sorted." << std::endl;
        std::exit(1);
    }
    for (const auto& pair : counter::counter.get_total_counts()) {
        count_dict[pair.first] = pair.second;
    }
    return count_dict;
}

/**
 * Runs the method named in argv[1] repeatedly and collects its counts.
 *
 * Expected argv layout: argv[2] = n (input size), argv[3] = distribution name,
 * argv[4] = number of iterations. Iteration i builds its input with
 * create_input_vector(n, distribution, i). The iteration index is passed as
 * the third argument, presumably a seed; confirm in create_input_vector.h.
 *
 * @param argv  Command-line arguments; the caller must ensure argv[1..4] exist
 *              (argv[5] as well for the "*_approx" methods).
 * @return      Map from counter name to the list of its totals, with one entry
 *              appended per iteration in which that counter was reported.
 * @throws std::invalid_argument / std::out_of_range from std::stol when n or
 *         the iteration count is not a valid integer.
 */
std::unordered_map<std::string, std::vector<size_t>> measure_sort_count_dict(
    char *argv[]
) {
    const long n = std::stol(argv[2]);
    const std::string distribution = argv[3];
    const long iterations = std::stol(argv[4]);
    std::unordered_map<std::string, std::vector<size_t>> count_list_dict;
    for (long i = 0; i < iterations; i++) {
        std::vector<double> x = create_input_vector(n, distribution, static_cast<unsigned int>(i));
        // measure_count_dict takes its input by value and sorts it. x is not
        // used afterwards, so hand it over instead of copying n doubles.
        const std::unordered_map<std::string, size_t> count_dict = measure_count_dict(std::move(x), argv);
        for (const auto& pair : count_dict) {
            count_list_dict[pair.first].push_back(pair.second);
        }
    }
    return count_list_dict;
}

/**
 * Entry point: validates the command line, runs the measurement and prints the
 * collected counts to stdout as a JSON object:
 *   { "<counter>": [ <count>, ... ], ... }
 * Keys appear in unordered_map iteration order (unspecified).
 *
 * Usage: <prog> {method_name} {n} {distribution} {iteration} [{epsilon}]
 * The "*_approx" methods require exactly the trailing {epsilon}; every other
 * method requires exactly four arguments after the program name.
 *
 * @return 0 on success; 1 on a usage error or when the measurement throws
 *         (e.g. a malformed numeric argument rejected by std::stol/std::stoi).
 */
int main(int argc, char *argv[]) {
    // Supported methods and the exact argc each one expects. "merge_sort" is
    // listed because measure_count_dict implements it.
    struct MethodSpec {
        const char *name;
        int expected_argc;
    };
    static const MethodSpec kMethods[] = {
        {"merge_sort", 5},
        {"learned_sort_using_learned_index_binary_search", 5},
        {"learned_sort_using_learned_index_btree", 5},
        {"learned_sort_using_learned_index_btree_approx", 6},
        {"learned_sort_using_learned_index_espc", 5},
        {"learned_sort_using_learned_index_pgm", 5},
        {"learned_sort_using_learned_index_pgm_approx", 6},
        {"learned_sort_using_learned_index_rmi", 5},
    };

    // argv[0] is null when argc == 0; never stream a null char*.
    const char *program = (argc > 0 && argv[0] != nullptr) ? argv[0] : "program";
    if (argc < 2) {
        std::cerr << "Usage: " << program << " {method_name} ..." << std::endl;
        return 1;
    }

    const std::string method_name = argv[1];
    const auto spec = std::find_if(
        std::begin(kMethods), std::end(kMethods),
        [&method_name](const MethodSpec &m) { return method_name == m.name; });
    if (spec == std::end(kMethods)) {
        std::cerr << "Unknown method !: " << method_name << std::endl;
        return 1;
    }
    if (argc != spec->expected_argc) {
        std::cerr << "Usage: " << program << " " << method_name << " {n} {distribution} {iteration}"
                  << (spec->expected_argc == 6 ? " {epsilon}" : "") << std::endl;
        return 1;
    }

    std::unordered_map<std::string, std::vector<size_t>> count_list_dict;
    try {
        count_list_dict = measure_sort_count_dict(argv);
    } catch (const std::exception &e) {
        // e.g. std::invalid_argument / std::out_of_range from std::stol or
        // std::stoi on a malformed {n}, {iteration} or {epsilon}.
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    std::cout << "{\n";
    size_t keys_written = 0;
    for (const auto& entry : count_list_dict) {
        std::cout << "  \"" << entry.first << "\": [\n";
        const std::vector<size_t>& counts = entry.second;
        for (size_t i = 0; i < counts.size(); ++i) {
            std::cout << "    " << counts[i] << (i + 1 < counts.size() ? ",\n" : "\n");
        }
        ++keys_written;
        std::cout << (keys_written < count_list_dict.size() ? "  ],\n" : "  ]\n");
    }
    std::cout << "}" << std::endl;

    return 0;
}

