#include "pcf_learned_sort/pcf_learned_sort.h"
#include "learned_sort_using_learned_index/learned_sort_using_learned_index.h"
#include "radix_sort/radix_sort.h"
#include "ips4o/ips4o.hpp"
#include "create_input_vector.h"
#include "learned-sort/include/learned_sort.h"
#include "balanced-learned-sort/include/bls.h"
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <exception>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <random>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include <boost/sort/spreadsort/float_sort.hpp>

namespace {

// Times one call of `sort_fn(x)` with std::chrono::steady_clock and returns
// the elapsed time in seconds. Only the sort call is inside the timed region;
// the std::is_sorted check that follows is not. An unsorted result is treated
// as a fatal correctness failure: an error is printed and the process exits
// with status 1.
template <class SortFn>
double time_sort_and_verify(std::vector<double> &x, SortFn sort_fn) {
    const std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
    sort_fn(x);
    const std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
    if (!std::is_sorted(x.begin(), x.end())) {
        std::cerr << "Error: The result is not sorted." << std::endl;
        std::exit(1);
    }
    return std::chrono::duration<double>(end - start).count();
}

}  // namespace

// Sorts `x` in place with the algorithm named by argv[1] and returns the
// wall-clock duration of the sort in seconds.
//
// Supported names: std, pcf, radix, boost, ips4o, learned_AniKristo, bls,
// uls, ls21.
//
// Exits the process with status 1 if the result is not sorted, or if the
// method name is unknown. An unknown name used to return 1.0, which looked
// like a genuine 1-second timing.
double measure_time(std::vector<double> &x, char *argv[]) {
    const std::string method_name = argv[1];
    if (method_name == "std") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { std::sort(v.begin(), v.end()); });
    }
    if (method_name == "pcf") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { pcf_learned_sort::pcf_learned_sort(v.begin(), v.end()); });
    }
    if (method_name == "radix") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { radix_sort::radix_sort(v.begin(), v.end()); });
    }
    if (method_name == "boost") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { boost::sort::spreadsort::float_sort(v.begin(), v.end()); });
    }
    if (method_name == "ips4o") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { ips4o::sort(v.begin(), v.end()); });
    }
    if (method_name == "learned_AniKristo") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { learned_sort::sort(v.begin(), v.end()); });
    }
    if (method_name == "bls") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { ls_framework::bls(v.begin(), v.end()); });
    }
    if (method_name == "uls") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { ls_framework::uls(v.begin(), v.end()); });
    }
    if (method_name == "ls21") {
        return time_sort_and_verify(x, [](std::vector<double> &v) { ls_framework::ls21(v.begin(), v.end()); });
    }
    // Fail loudly, as the other error paths do, rather than returning a
    // plausible-looking timing that would be silently recorded.
    std::cerr << "Unknown method: " << method_name << std::endl;
    std::exit(1);
}

// Sorts `x` (the function's own copy) with
// learned_sort_using_learned_index::learned_sort_using_learned_index, using the
// index type selected by argv[1]. Returns a map containing:
//   - "time": the total wall-clock seconds of the sort call;
//   - every entry reported by learned_sort_using_learned_index::timer.get_total_times(),
//     converted to double.
// If the timer reports a "time" key, it overwrites the total.
//
// An unknown method name prints an error and returns an empty map. An
// unsorted result prints an error and exits the process with status 1.
std::unordered_map<std::string, double> measure_time_dict(std::vector<double> x, char *argv[]) {
    // CLI method name -> index-type argument passed to learned_sort_using_learned_index.
    struct IndexMethod {
        const char *method_name;
        const char *index_type;
    };
    static const IndexMethod kIndexMethods[] = {
        {"learned_sort_using_learned_index_binary_search", "binary_search"},
        {"learned_sort_using_learned_index_btree", "btree"},
        {"learned_sort_using_learned_index_espc", "espc"},
        {"learned_sort_using_learned_index_pgm", "pgm"},
        {"learned_sort_using_learned_index_rmi", "rmi"},
    };

    std::unordered_map<std::string, double> time_dict;
    const std::string method_name = argv[1];
    // Clear the shared timer before the run (presumably so that
    // get_total_times() reflects only this run; confirm in the timer implementation).
    learned_sort_using_learned_index::timer.clear();

    const char *index_type = nullptr;
    for (const IndexMethod &entry : kIndexMethods) {
        if (method_name == entry.method_name) {
            index_type = entry.index_type;
            break;
        }
    }
    if (index_type == nullptr) {
        std::cerr << "Unknown method: " << method_name << std::endl;
        return time_dict;
    }

    const std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
    learned_sort_using_learned_index::learned_sort_using_learned_index(x.begin(), x.end(), index_type);
    const std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
    if (!std::is_sorted(x.begin(), x.end())) {
        std::cerr << "Error: The result is not sorted." << std::endl;
        std::exit(1);
    }

    time_dict["time"] = std::chrono::duration<double>(end - start).count();
    for (const auto &phase : learned_sort_using_learned_index::timer.get_total_times()) {
        time_dict[phase.first] = static_cast<double>(phase.second);
    }
    // NOTE: an element-preservation check (comparing std::multiset copies of
    // the input and the output) was sketched here previously but is disabled;
    // only sortedness is verified.
    return time_dict;
}

std::vector<double> measure_sort_time(
    char *argv[]
) {
    // Positional CLI arguments:
    //   argv[2] = element count
    //   argv[3] = distribution name
    //   argv[4] = number of repetitions
    const long element_count = std::stol(argv[2]);
    const std::string dist_name = argv[3];
    const long repetitions = std::stol(argv[4]);

    // One timing (in seconds, from measure_time) per repetition. Each
    // repetition builds a fresh input; the repetition index is forwarded as
    // create_input_vector's third argument (presumably an RNG seed; confirm
    // in create_input_vector).
    std::vector<double> timings;
    long iteration = 0;
    while (iteration < repetitions) {
        std::vector<double> input = create_input_vector(element_count, dist_name, static_cast<unsigned int>(iteration));
        timings.push_back(measure_time(input, argv));
        ++iteration;
    }
    return timings;
}

// Runs measure_time_dict argv[4] times on freshly generated inputs and groups
// the results by key. Each key maps to one value per repetition that reported it.
//
// Positional CLI arguments:
//   argv[2] = element count
//   argv[3] = distribution name
//   argv[4] = number of repetitions
// std::stol throws std::invalid_argument / std::out_of_range on malformed counts.
std::unordered_map<std::string, std::vector<double>> measure_sort_time_dict(
    char *argv[]
) {
    const long n = std::stol(argv[2]);
    const std::string distribution = argv[3];
    const long iterations = std::stol(argv[4]);
    std::unordered_map<std::string, std::vector<double>> time_list_dict;
    for (long i = 0; i < iterations; i++) {
        // Pass the generated input straight into measure_time_dict's by-value
        // parameter. The old named local was copied (all n doubles) on every
        // iteration even though it was never used again.
        const std::unordered_map<std::string, double> time_dict =
            measure_time_dict(create_input_vector(n, distribution, static_cast<unsigned int>(i)), argv);
        for (const auto &entry : time_dict) {
            time_list_dict[entry.first].push_back(entry.second);
        }
    }
    return time_list_dict;
}

namespace {

// Writes one member of the JSON object printed by main, in this form:
//     "key": [
//       v0,
//       ...
//       vN
//     ]
// Each value is printed with 20 significant digits. A trailing "," follows the
// closing bracket unless `is_last` is true.
void print_json_series(const std::string &key, const std::vector<double> &values, bool is_last) {
    std::cout << "  \"" << key << "\": [" << std::endl;
    for (std::size_t i = 0; i < values.size(); i++) {
        std::cout << "    " << std::setprecision(20) << values[i];
        if (i + 1 != values.size()) {
            std::cout << ",";
        }
        std::cout << std::endl;
    }
    std::cout << (is_last ? "  ]" : "  ],") << std::endl;
}

}  // namespace

// Usage: <program> {method_name} {n} {distribution} {iteration}
//
// Benchmarks the named sorting method and prints the timings as a JSON object
// on stdout:
//   - "learned_sort_using_learned_index*" methods print every key produced by
//     measure_sort_time_dict;
//   - all other methods print a single "time" array.
//
// Returns 1 on a usage error, an unknown method, or a malformed numeric
// argument. Returns 0 otherwise.
int main(int argc, char *argv[]) {
    if (argc < 2) {
        // argc may be 0, in which case argv[0] is a null pointer.
        const char *program = (argc > 0 && argv[0] != nullptr) ? argv[0] : "<program>";
        std::cerr << "Usage: " << program << " {method_name} ..." << std::endl;
        return 1;
    }

    const std::string method_name = argv[1];

    // Every supported method takes the same three positional arguments.
    const std::set<std::string> known_methods = {
        "std",
        "pcf",
        "learned_sort_using_learned_index_binary_search",
        "learned_sort_using_learned_index_btree",
        "learned_sort_using_learned_index_espc",
        "learned_sort_using_learned_index_pgm",
        "learned_sort_using_learned_index_rmi",
        "radix",
        "boost",
        "ips4o",
        "learned_AniKristo",
        "bls",
        "uls",
        "ls21",
    };
    if (known_methods.count(method_name) == 0) {
        std::cerr << "Unknown method: " << method_name << std::endl;
        return 1;
    }
    if (argc != 5) {
        std::cerr << "Usage: " << argv[0] << " " << method_name << " {n} {distribution} {iteration}" << std::endl;
        return 1;
    }

    // std::stol inside the measure_* helpers throws on non-numeric or
    // out-of-range n/iteration. Report that instead of letting the exception
    // escape main and call std::terminate.
    try {
        if (method_name.find("learned_sort_using_learned_index") != std::string::npos) {
            const std::unordered_map<std::string, std::vector<double>> time_list_dict = measure_sort_time_dict(argv);
            std::cout << "{" << std::endl;
            for (auto it = time_list_dict.begin(); it != time_list_dict.end(); ++it) {
                print_json_series(it->first, it->second, std::next(it) == time_list_dict.end());
            }
            std::cout << "}" << std::endl;
        } else {
            const std::vector<double> time_list = measure_sort_time(argv);
            std::cout << "{" << std::endl;
            print_json_series("time", time_list, true);
            std::cout << "}" << std::endl;
        }
    } catch (const std::exception &e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}

