#include "../../src/utils/utils.h"
#include "../../src/plbf/plbf.h"
#include "../../src/plbf/calibration_manager.h"
#include "../utils/str2uint64_hash.h"
#include <xgboost/c_api.h>
#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <numeric>
#include <optional>
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include <unistd.h>

void test_plbf(
    const std::vector<std::vector<float>>& X_test,
    const std::vector<std::vector<float>>& y_test,
    const std::vector<std::string>& X_test_key,
    std::string plbf_folder_path,
    std::string result_folder_path,
    float pos_query_ratio,
    int query_num
) {
    // If result_folder_path does not exist, create it
    if (!std::filesystem::exists(result_folder_path)) {
        std::filesystem::create_directories(result_folder_path);
    }

    std::string plbf_model_path = plbf_folder_path + "/plbf_model.bin";
    std::string result_json_path = result_folder_path + "/result.json";

    // load model
    PLBF plbf;
    {
        plbf.load_model(plbf_model_path);
    }

    // prepare test data
    std::vector<std::vector<float>> sampled_X_test;
    std::vector<std::vector<float>> sampled_y_test;
    std::vector<uint64_t> sampled_X_test_key;
    {
        int pos_query_num = pos_query_ratio * query_num;
        int neg_query_num = query_num - pos_query_num;

        std::vector<std::vector<float>> X_test_pos;
        std::vector<std::vector<float>> X_test_neg;
        std::vector<uint64_t> X_test_key_pos;
        std::vector<uint64_t> X_test_key_neg;

        for (int i = 0; i < (int)X_test.size(); i++) {
            if (y_test[i][0] == 1) {
                X_test_pos.push_back(X_test[i]);
                X_test_key_pos.push_back(str2uint64_hash(X_test_key[i]));
            } else {
                X_test_neg.push_back(X_test[i]);
                X_test_key_neg.push_back(str2uint64_hash(X_test_key[i]));
            }
        }

        if (pos_query_num > (int)X_test_pos.size() || neg_query_num > (int)X_test_neg.size()) {
            std::cerr << "Error: Not enough positive or negative queries in the test data.\n";
            return;
        }

        // randomly sample positive and negative queries (duplicate not allowed)
        std::mt19937 g(42);
        std::vector<int> pos_indices(X_test_pos.size());
        std::iota(pos_indices.begin(), pos_indices.end(), 0);
        std::shuffle(pos_indices.begin(), pos_indices.end(), g);
        std::vector<int> neg_indices(X_test_neg.size());
        std::iota(neg_indices.begin(), neg_indices.end(), 0);
        std::shuffle(neg_indices.begin(), neg_indices.end(), g);
        for (int i = 0; i < pos_query_num; i++) {
            sampled_X_test.push_back(X_test_pos[pos_indices[i]]);
            sampled_y_test.push_back({1});
            sampled_X_test_key.push_back(X_test_key_pos[pos_indices[i]]);
        }
        for (int i = 0; i < neg_query_num; i++) {
            sampled_X_test.push_back(X_test_neg[neg_indices[i]]);
            sampled_y_test.push_back({0});
            sampled_X_test_key.push_back(X_test_key_neg[neg_indices[i]]);
        }

        // shuffle the sampled data
        std::vector<int> indices(sampled_X_test.size());
        std::iota(indices.begin(), indices.end(), 0);
        std::shuffle(indices.begin(), indices.end(), g);
        std::vector<std::vector<float>> shuffled_sampled_X_test(sampled_X_test.size());
        std::vector<std::vector<float>> shuffled_sampled_y_test(sampled_y_test.size());
        std::vector<uint64_t> shuffled_sampled_X_test_key(sampled_X_test_key.size());

        for (int i = 0; i < (int)indices.size(); i++) {
            shuffled_sampled_X_test[i] = sampled_X_test[indices[i]];
            shuffled_sampled_y_test[i] = sampled_y_test[indices[i]];
            shuffled_sampled_X_test_key[i] = sampled_X_test_key[indices[i]];
        }

        sampled_X_test = shuffled_sampled_X_test;
        sampled_y_test = shuffled_sampled_y_test;
        sampled_X_test_key = shuffled_sampled_X_test_key;
    }

    // warm-up
    const int warmup_num = 10;
    const int test_num = 10;
    for (int i = 0; i < warmup_num; i++) {
        std::vector<bloomfilter::Status> predictions_ = plbf.contains_all(sampled_X_test, sampled_X_test_key);
    }

    // test
    int query_pos_num = 0;
    int query_neg_num = 0;
    int false_positives = 0;
    int false_negatives = 0;
    float fpr = 1.0;
    float fnr = 1.0;
    float total_test_time = 0.0;
    std::vector<float> test_time_list;
    float test_time = 0.0;
    for (int i = 0; i < test_num; i++) {
        auto start = std::chrono::high_resolution_clock::now();
        std::vector<bloomfilter::Status> predictions_ = plbf.contains_all(sampled_X_test, sampled_X_test_key);
        auto end = std::chrono::high_resolution_clock::now();
        std::vector<bool> predictions(predictions_.size());
        for (int i = 0; i < (int)predictions_.size(); i++) {
            predictions[i] = predictions_[i] == bloomfilter::Status::Ok;
        }

        fpr = calculate_false_positive_rate(predictions, sampled_y_test);
        fnr = calculate_false_negative_rate(predictions, sampled_y_test);
        query_pos_num = std::count(sampled_y_test.begin(), sampled_y_test.end(), std::vector<float>({1}));
        query_neg_num = std::count(sampled_y_test.begin(), sampled_y_test.end(), std::vector<float>({0}));
        false_positives = (int)(fpr * query_neg_num + 0.5);
        false_negatives = (int)(fnr * query_pos_num + 0.5);
        total_test_time += (float)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000;
        test_time_list.push_back((float)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000);
    }
    test_time = total_test_time / test_num;
    std::cout << "Total test time: " << total_test_time << " ms\n";
    std::cout << "Average test time: " << test_time << " ms\n";
    std::cout << "Test time per key: " << test_time / query_num * 1000000 << " ns\n";

    // save to result.json
    {
        // auto vector_to_json = [](const std::vector<int>& v) {
        //     std::ostringstream s;
        //     s << "[";
        //     for (int i = 0; i < (int)v.size(); i++) {
        //         if (std::isnan(v[i])) {
        //             s << "null";
        //         } else if (std::isinf(v[i])) {
        //             s << (v[i] > 0 ? "Infinity" : "-Infinity");
        //         } else {
        //             s << v[i];
        //         }
        //         if (i < (int)v.size() - 1) {
        //             s << ", ";
        //         }
        //     }
        //     s << "]";
        //     return s.str();
        // };

        auto vector_to_json_float = [](const std::vector<float>& v) {
            std::ostringstream s;
            s << "[";
            for (int i = 0; i < (int)v.size(); i++) {
                if (std::isnan(v[i])) {
                    s << "null";
                } else if (std::isinf(v[i])) {
                    s << (v[i] > 0 ? "Infinity" : "-Infinity");
                } else {
                    s << setprecision(20) << v[i];
                }
                if (i < (int)v.size() - 1) {
                    s << ", ";
                }
            }
            s << "]";
            return s.str();
        };

        std::ofstream ofs(result_json_path);
        ofs << "{\n";
        ofs << "  \"bf_model_folder_path\": \"" << plbf_folder_path << "\",\n";
        ofs << "  \"pos_query_ratio\": " << pos_query_ratio << ",\n";
        ofs << "  \"query_num\": " << query_num << ",\n";
        ofs << "  \"fpr\": " << fpr << ",\n";
        ofs << "  \"fnr\": " << fnr << ",\n";
        ofs << "  \"query_pos_num\": " << query_pos_num << ",\n";
        ofs << "  \"query_neg_num\": " << query_neg_num << ",\n";
        ofs << "  \"false_positives\": " << false_positives << ",\n";
        ofs << "  \"false_negatives\": " << false_negatives << ",\n";
        ofs << "  \"test_time_ms\": " << test_time << ",\n";
        ofs << "  \"test_time_list_ms\": " << vector_to_json_float(test_time_list) << "\n";
        ofs << "}\n";
        ofs.close();
    }
}

/// Command-line entry point: loads the test set, then benchmarks a trained PLBF
/// via test_plbf().
///
/// Required options:
///   -k <X_test_key_path>     key strings of the test set
///   -x <X_test_path>         feature rows of the test set
///   -y <y_test_path>         labels of the test set
///   -m <plbf_folder_path>    folder containing plbf_model.bin
///   -o <result_folder_path>  output folder for result.json
///   -p <pos_query_ratio>     fraction of positive queries, in [0, 1]
///   -q <query_num>           total number of queries, > 0
///
/// Returns 0 without testing when <result_folder_path>/result.json already exists.
/// Returns 1 on a usage error.
int main(int argc, char* argv[]) {
    const char* program = (argc > 0 && argv[0] != nullptr) ? argv[0] : "test_plbf";
    const auto print_usage = [program]() {
        std::cerr << "Usage: " << program
                  << " -k <X_test_key_path> -x <X_test_path> -y <y_test_path> -m <plbf_folder_path>"
                     " -o <result_folder_path> -p <pos_query_ratio> -q <query_num>\n";
    };

    // Strict numeric parsers. They reject empty input, trailing garbage and
    // out-of-range values, which std::atof / std::atoi would silently turn into 0.
    const auto parse_float = [](const char* text, float& out) {
        char* end = nullptr;
        errno = 0;
        const float value = std::strtof(text, &end);
        if (end == text || *end != '\0' || errno == ERANGE) {
            return false;
        }
        out = value;
        return true;
    };
    const auto parse_int = [](const char* text, int& out) {
        char* end = nullptr;
        errno = 0;
        const long value = std::strtol(text, &end, 10);
        if (end == text || *end != '\0' || errno == ERANGE ||
            value < std::numeric_limits<int>::min() || value > std::numeric_limits<int>::max()) {
            return false;
        }
        out = static_cast<int>(value);
        return true;
    };

    std::string X_test_key_path;
    std::string X_test_path;
    std::string y_test_path;
    std::string plbf_folder_path;
    std::string result_folder_path;
    std::optional<float> pos_query_ratio;  // empty until -p is given
    std::optional<int> query_num;          // empty until -q is given

    int opt;
    while ((opt = getopt(argc, argv, "k:x:y:m:o:p:q:")) != -1) {
        switch (opt) {
            case 'k':
                X_test_key_path = optarg;
                break;
            case 'x':
                X_test_path = optarg;
                break;
            case 'y':
                y_test_path = optarg;
                break;
            case 'm':
                plbf_folder_path = optarg;
                break;
            case 'o':
                result_folder_path = optarg;
                break;
            case 'p': {
                float value = 0.0f;
                if (!parse_float(optarg, value)) {
                    std::cerr << "Error: Invalid value for -p: " << optarg << "\n";
                    print_usage();
                    return 1;
                }
                pos_query_ratio = value;
                break;
            }
            case 'q': {
                int value = 0;
                if (!parse_int(optarg, value)) {
                    std::cerr << "Error: Invalid value for -q: " << optarg << "\n";
                    print_usage();
                    return 1;
                }
                query_num = value;
                break;
            }
            default:
                print_usage();
                return 1;
        }
    }

    if (X_test_key_path.empty() || X_test_path.empty() || y_test_path.empty() ||
        plbf_folder_path.empty() || result_folder_path.empty() || !pos_query_ratio || !query_num) {
        std::cerr << "Error: Missing required arguments.\n";
        print_usage();
        return 1;
    }
    // Written so that NaN is rejected as well.
    if (!(*pos_query_ratio >= 0.0f && *pos_query_ratio <= 1.0f)) {
        std::cerr << "Error: pos_query_ratio (-p) must be in [0, 1].\n";
        return 1;
    }
    if (*query_num <= 0) {
        std::cerr << "Error: query_num (-q) must be positive.\n";
        return 1;
    }

    // An existing result.json means this configuration was already tested.
    const std::string result_json_path = result_folder_path + "/result.json";
    if (std::filesystem::exists(result_json_path)) {
        std::cout << "Result JSON already exists at " << result_json_path << ". Skipping testing the PLBF.\n";
        return 0;
    }

    std::vector<std::string> X_test_key = load_key(X_test_key_path);
    std::vector<std::vector<float>> X_test = load_data(X_test_path);
    std::vector<std::vector<float>> y_test = load_data(y_test_path);
    test_plbf(X_test, y_test, X_test_key, plbf_folder_path, result_folder_path, *pos_query_ratio, *query_num);

    return 0;
}
