#include "../../src/utils/utils.h"
#include "../../src/disjointadabf/disjointadabf.h"
#include "../../src/disjointadabf/calibration_manager.h"
#include "../../src/disjointadabf/burger_config.h"
#include "../utils/str2uint64_hash.h"

#include <xgboost/c_api.h>

#include <unistd.h>

#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

void construct_disjointadabf(
    const std::vector<std::string>& all_pos_key_str,
    const std::vector<std::vector<float>>& all_pos_X,
    const std::vector<std::string>& X_val_key,
    const std::vector<std::vector<float>>& X_val,
    const std::vector<std::vector<float>>& y_val,
    std::string trained_xgboost_folder_path,
    std::string disjointadabf_folder_path,
    float bit_size_of_Ada_BF
) {
    // If disjointadabf_folder_path does not exist, create it
    if (!std::filesystem::exists(disjointadabf_folder_path)) {
        std::filesystem::create_directories(disjointadabf_folder_path);
    }

    std::string xgboost_model_path = trained_xgboost_folder_path + "/xgboost_model.bin";
    std::string xgboost_info_path = trained_xgboost_folder_path + "/info.json";
    std::string disjointadabf_model_path = disjointadabf_folder_path + "/disjointadabf_model.bin";
    std::string disjointadabf_info_path = disjointadabf_folder_path + "/info.json";
    std::string disjointadabf_config_path = disjointadabf_folder_path + "/config.json";

    // Construct DisjointAdaBF
    DisjointAdaBF disjointadabf;
    BurgerConfig burger_config;
    float prepare_time_ms = 0;
    float calibration_time_ms = 0;
    float configuration_time_ms = 0;
    float add_keys_time_ms = 0;
    float model_construct_time_ms = 0;
    {
        auto construction_start = std::chrono::high_resolution_clock::now();
        // prepare positive samples
        auto prepare_start = std::chrono::high_resolution_clock::now();
        std::vector<uint64_t> all_pos_key;
        for (int i = 0; i < (int)all_pos_key_str.size(); i++) {
            all_pos_key.push_back(str2uint64_hash(all_pos_key_str[i]));
        }
        std::vector<uint64_t> X_val_key_pos;
        std::vector<uint64_t> X_val_key_neg;
        std::vector<std::vector<float>> X_val_pos;
        std::vector<std::vector<float>> X_val_neg;
        for (int i = 0; i < (int)X_val.size(); i++) {
            if (y_val[i][0] == 1) {
                X_val_key_pos.push_back(str2uint64_hash(X_val_key[i]));
                X_val_pos.push_back(X_val[i]);
            } else {
                X_val_key_neg.push_back(str2uint64_hash(X_val_key[i]));
                X_val_neg.push_back(X_val[i]);
            }
        }
        auto prepare_end = std::chrono::high_resolution_clock::now();

        // Calibrate and add keys
        disjointadabf.load_trained_xgboost_model(xgboost_model_path);

        if (bit_size_of_Ada_BF <= disjointadabf.xgboost_model_size_kb * 1024 * 8) {
            std::cerr << "Error: Bit size of AdaBF must be greater than the size of the XGBoost model.\n";
            exit(2);
        }

        Calibration_Manager calibration_manager(disjointadabf);
        auto calibration_start = std::chrono::high_resolution_clock::now();
        calibration_manager.calibrate(X_val, y_val);
        auto calibration_end = std::chrono::high_resolution_clock::now();
        auto configure_start = std::chrono::high_resolution_clock::now();
        disjointadabf = calibration_manager.configure_burger_assembly(all_pos_X.size(), bit_size_of_Ada_BF);
        auto configure_end = std::chrono::high_resolution_clock::now();

        // add keys
        std::cout << "[INFO] Adding keys to DisjointAdaBF..." << std::endl;
        auto add_keys_start = std::chrono::high_resolution_clock::now();
        disjointadabf.add_all(all_pos_X, all_pos_key);
        auto add_keys_end = std::chrono::high_resolution_clock::now();
        std::cout << "[INFO] Done adding keys to DisjointAdaBF." << std::endl;

        auto construction_end = std::chrono::high_resolution_clock::now();
        model_construct_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(construction_end - construction_start).count();
        prepare_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(prepare_end - prepare_start).count();
        calibration_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(calibration_end - calibration_start).count();
        configuration_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(configure_end - configure_start).count();
        add_keys_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(add_keys_end - add_keys_start).count();

        // Check FNR == 0
        {
            std::vector<bloomfilter::Status> predictions_ = disjointadabf.contains_all(X_val_pos, X_val_key_pos);
            int false_negatives = 0;
            for (int i = 0; i < (int)predictions_.size(); i++) {
                if (predictions_[i] == bloomfilter::Status::NotFound) {
                    false_negatives++;
                }
            }
            if (false_negatives > 0) {
                std::cerr << "Error: False negatives found in the positive samples.\n";
                exit(1);
            }
        }

        // Check FPR ~ F
        {
            std::vector<bloomfilter::Status> predictions_ = disjointadabf.contains_all(X_val_neg, X_val_key_neg);
            int false_positives = 0;
            for (int i = 0; i < (int)predictions_.size(); i++) {
                if (predictions_[i] == bloomfilter::Status::Ok) {
                    false_positives++;
                }
            }
            float FPR = (float)false_positives / (float)X_val_neg.size();
            std::cout << "[INFO] FPR: " << FPR << std::endl;
        }

        burger_config = calibration_manager.get_burger_config();
    }

    // Save DisjointAdaBF
    float model_size_kb = 0;
    {
        disjointadabf.save_model(disjointadabf_model_path);
        model_size_kb = getBinFileSize(disjointadabf_model_path);
    }

    // Save info.json
    {
        std::ofstream ofs(disjointadabf_info_path);
        ofs << "{\n";
        ofs << "  \"model_type\": \"disjointadabf\",\n";
        ofs << "  \"bit_size_of_Ada_BF\": " << bit_size_of_Ada_BF << ",\n";
        ofs << "  \"trained_xgboost_folder_path\": \"" << trained_xgboost_folder_path << "\",\n";
        ofs << "  \"model_construct_time_ms\": " << model_construct_time_ms << ",\n";
        ofs << "  \"prepare_time_ms\": " << prepare_time_ms << ",\n";
        ofs << "  \"calibration_time_ms\": " << calibration_time_ms << ",\n";
        ofs << "  \"configuration_time_ms\": " << configuration_time_ms << ",\n";
        ofs << "  \"add_keys_time_ms\": " << add_keys_time_ms << ",\n";
        ofs << "  \"model_size_kb\": " << model_size_kb << "\n";
        ofs << "}\n";
        ofs.close();
    }

    auto vector_to_json = [](const std::vector<float>& v, int precision = 6) {
        std::ostringstream s;
        s << "[";
        for (int i = 0; i < (int)v.size(); i++) {
            if (std::isnan(v[i])) {
                s << "null";
            } else if (std::isinf(v[i])) {
                s << (v[i] > 0 ? "Infinity" : "-Infinity");
            } else {
                s << std::fixed << std::setprecision(precision) << v[i];
            }
            if (i < (int)v.size() - 1) {
                s << ", ";
            }
        }
        s << "]";
        return s.str();
    };

    // Save config.json
    {
        std::ofstream ofs(disjointadabf_config_path);
        ofs << "{\n";
        ofs << "  \"d\": " << burger_config.d << ",\n";
        ofs << "  \"k\": " << burger_config.k << ",\n";
        ofs << "  \"n\": " << burger_config.n << ",\n";
        ofs << "  \"f\": " << vector_to_json(burger_config.f) << ",\n";
        ofs << "  \"th_f\": " << vector_to_json(burger_config.th_f) << ",\n";
        ofs << "  \"g_f\": " << vector_to_json(burger_config.g_f) << "\n";
        ofs << "}\n";
    }
}


// Command-line entry point.
//
// Options (all required):
//   -p <all_pos_key_path>            -q <all_pos_X_path>
//   -k <X_val_key_path>              -x <X_val_path>
//   -y <y_val_path>                  -t <trained_xgboost_folder_path>
//   -o <disjointadabf_folder_path>   -b <bit_size_of_Ada_BF>
//
// Returns 1 after printing usage when an option is unknown, missing or (for -b)
// not a finite non-negative number. Returns 0 without doing any work when
// <disjointadabf_folder_path>/disjointadabf_model.bin already exists, otherwise
// loads the inputs and runs construct_disjointadabf.
int main(int argc, char* argv[]) {
    std::string all_pos_key_path;
    std::string all_pos_X_path;
    std::string X_val_key_path;
    std::string X_val_path;
    std::string y_val_path;
    std::string trained_xgboost_folder_path;
    std::string disjointadabf_folder_path;
    float bit_size_of_Ada_BF = -1.0f;  // negative means "-b not supplied"

    const auto print_usage = [&argv]() {
        std::cerr << "Usage: " << argv[0] << " -p <all_pos_key_path> -q <all_pos_X_path> -k <X_val_key_path> -x <X_val_path> -y <y_val_path> -t <trained_xgboost_folder_path> -o <disjointadabf_folder_path> -b <bit_size_of_Ada_BF>\n";
    };

    int opt;
    while ((opt = getopt(argc, argv, "p:q:k:x:y:t:o:b:")) != -1) {
        switch (opt) {
            case 'p':
                all_pos_key_path = optarg;
                break;
            case 'q':
                all_pos_X_path = optarg;
                break;
            case 'k':
                X_val_key_path = optarg;
                break;
            case 'x':
                X_val_path = optarg;
                break;
            case 'y':
                y_val_path = optarg;
                break;
            case 't':
                trained_xgboost_folder_path = optarg;
                break;
            case 'o':
                disjointadabf_folder_path = optarg;
                break;
            case 'b':
                // std::stof throws std::invalid_argument / std::out_of_range on
                // bad input; report it as a usage error instead of terminating.
                try {
                    bit_size_of_Ada_BF = std::stof(optarg);
                } catch (const std::exception&) {
                    std::cerr << "Error: Invalid value for -b: '" << optarg << "'.\n";
                    print_usage();
                    return 1;
                }
                break;
            default:
                print_usage();
                return 1;
        }
    }

    // Check if all required arguments are provided. NaN and infinity parse
    // successfully with std::stof but are not meaningful bit budgets.
    const bool bit_size_valid = std::isfinite(bit_size_of_Ada_BF) && bit_size_of_Ada_BF >= 0;
    if (all_pos_key_path.empty() || all_pos_X_path.empty() || X_val_key_path.empty() || X_val_path.empty() || y_val_path.empty() || trained_xgboost_folder_path.empty() || disjointadabf_folder_path.empty() || !bit_size_valid) {
        print_usage();
        return 1;
    }

    // If disjointadabf_model_path already exists, skip constructing the disjointadabf
    const std::string disjointadabf_model_path = disjointadabf_folder_path + "/disjointadabf_model.bin";
    if (std::filesystem::exists(disjointadabf_model_path)) {
        std::cout << "DisjointAdaBF model already exists at " << disjointadabf_model_path << ". Skipping constructing the DisjointAdaBF.\n";
        return 0;
    }

    // Load every input only after the cheap checks above have passed.
    const std::vector<std::string> all_pos_key = load_key(all_pos_key_path);
    const std::vector<std::vector<float>> all_pos_X = load_data(all_pos_X_path);
    const std::vector<std::string> X_val_key = load_key(X_val_key_path);
    const std::vector<std::vector<float>> X_val = load_data(X_val_path);
    const std::vector<std::vector<float>> y_val = load_data(y_val_path);

    construct_disjointadabf(all_pos_key, all_pos_X, X_val_key, X_val, y_val, trained_xgboost_folder_path, disjointadabf_folder_path, bit_size_of_Ada_BF);

    return 0;
}
