#include "../../src/utils/utils.h"
#include "../../src/clbf/clbf.h"
#include "../../src/clbf/calibration_manager.h"
#include "../../src/clbf/burger_config.h"
#include "../utils/str2uint64_hash.h"
#include <xgboost/c_api.h>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <unistd.h>

void construct_clbf(
    const std::vector<std::string>& all_pos_key_str,
    const std::vector<std::vector<float>>& all_pos_X,
    const std::vector<std::string>& X_val_key,
    const std::vector<std::vector<float>>& X_val,
    const std::vector<std::vector<float>>& y_val,
    std::string trained_xgboost_folder_path,
    std::string clbf_folder_path,
    float F, float lambda_, float mu_
) {
    // If clbf_folder_path does not exist, create it
    if (!std::filesystem::exists(clbf_folder_path)) {
        std::filesystem::create_directories(clbf_folder_path);
    }

    std::string xgboost_model_path = trained_xgboost_folder_path + "/xgboost_model.bin";
    std::string xgboost_info_path = trained_xgboost_folder_path + "/info.json";
    std::string clbf_model_path = clbf_folder_path + "/clbf_model.bin";
    std::string clbf_info_path = clbf_folder_path + "/info.json";
    std::string clbf_config_path = clbf_folder_path + "/config.json";

    // Construct CLBF
    CLBF clbf;
    BurgerConfig burger_config;
    float prepare_time_ms = 0;
    float calibration_time_ms = 0;
    float configuration_time_ms = 0;
    float add_keys_time_ms = 0;
    float model_construct_time_ms = 0;
    {
        auto construction_start = std::chrono::high_resolution_clock::now();
        // prepare positive samples
        auto prepare_start = std::chrono::high_resolution_clock::now();
        std::vector<uint64_t> all_pos_key;
        for (int i = 0; i < (int)all_pos_key_str.size(); i++) {
            all_pos_key.push_back(str2uint64_hash(all_pos_key_str[i]));
        }
        std::vector<uint64_t> X_val_key_pos;
        std::vector<uint64_t> X_val_key_neg;
        std::vector<std::vector<float>> X_val_pos;
        std::vector<std::vector<float>> X_val_neg;
        for (int i = 0; i < (int)X_val.size(); i++) {
            if (y_val[i][0] == 1) {
                X_val_key_pos.push_back(str2uint64_hash(X_val_key[i]));
                X_val_pos.push_back(X_val[i]);
            } else {
                X_val_key_neg.push_back(str2uint64_hash(X_val_key[i]));
                X_val_neg.push_back(X_val[i]);
            }
        }
        auto prepare_end = std::chrono::high_resolution_clock::now();

        // Calibrate and add keys
        clbf.load_trained_xgboost_model(xgboost_model_path);
        Calibration_Manager calibration_manager(clbf);
        auto calibration_start = std::chrono::high_resolution_clock::now();
        calibration_manager.calibrate(X_val, y_val);
        auto calibration_end = std::chrono::high_resolution_clock::now();
        auto configure_start = std::chrono::high_resolution_clock::now();
        clbf = calibration_manager.configure_burger_assembly(all_pos_X.size(), F, lambda_, mu_);
        auto configure_end = std::chrono::high_resolution_clock::now();

        // add keys
        std::cout << "[INFO] Adding keys to CLBF..." << std::endl;
        auto add_keys_start = std::chrono::high_resolution_clock::now();
        clbf.add_all(all_pos_X, all_pos_key);
        auto add_keys_end = std::chrono::high_resolution_clock::now();
        std::cout << "[INFO] Done adding keys to CLBF." << std::endl;

        auto construction_end = std::chrono::high_resolution_clock::now();
        model_construct_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(construction_end - construction_start).count();
        prepare_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(prepare_end - prepare_start).count();
        calibration_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(calibration_end - calibration_start).count();
        configuration_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(configure_end - configure_start).count();
        add_keys_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(add_keys_end - add_keys_start).count();

        // Check FNR == 0
        {
            std::vector<bloomfilter::Status> predictions_ = clbf.contains_all(X_val_pos, X_val_key_pos);
            int false_negatives = 0;
            for (int i = 0; i < (int)predictions_.size(); i++) {
                if (predictions_[i] == bloomfilter::Status::NotFound) {
                    false_negatives++;
                }
            }
            if (false_negatives > 0) {
                std::cerr << "Error: False negatives found in the positive samples.\n";
                exit(1);
            }
        }

        // Check FPR ~ F
        {
            std::vector<bloomfilter::Status> predictions_ = clbf.contains_all(X_val_neg, X_val_key_neg);
            int false_positives = 0;
            for (int i = 0; i < (int)predictions_.size(); i++) {
                if (predictions_[i] == bloomfilter::Status::Ok) {
                    false_positives++;
                }
            }
            float FPR = (float)false_positives / (float)X_val_neg.size();
            std::cout << "[INFO] FPR: " << FPR << std::endl;
            if (FPR > F + 0.01) {
                std::cerr << "[WARNING] FPR is greater than the expected FPR.\n";
                calibration_manager.get_burger_config().print();
                // exit(1);
            }
        }

        burger_config = calibration_manager.get_burger_config();
    }

    // Save CLBF
    float model_size_kb = 0;
    {
        clbf.save_model(clbf_model_path);
        model_size_kb = getBinFileSize(clbf_model_path);
    }

    // Save info.json
    {
        std::ofstream ofs(clbf_info_path);
        ofs << "{\n";
        ofs << "  \"model_type\": \"clbf\",\n";
        ofs << "  \"F\": " << F << ",\n";
        ofs << "  \"lambda\": " << lambda_ << ",\n";
        ofs << "  \"mu\": " << mu_ << ",\n";
        ofs << "  \"trained_xgboost_folder_path\": \"" << trained_xgboost_folder_path << "\",\n";
        ofs << "  \"model_construct_time_ms\": " << model_construct_time_ms << ",\n";
        ofs << "  \"prepare_time_ms\": " << prepare_time_ms << ",\n";
        ofs << "  \"calibration_time_ms\": " << calibration_time_ms << ",\n";
        ofs << "  \"configuration_time_ms\": " << configuration_time_ms << ",\n";
        ofs << "  \"add_keys_time_ms\": " << add_keys_time_ms << ",\n";
        ofs << "  \"model_size_kb\": " << model_size_kb << "\n";
        ofs << "}\n";
        ofs.close();
    }

    auto vector_to_json = [](const std::vector<float>& v, int precision = 6) {
        std::ostringstream s;
        s << "[";
        for (int i = 0; i < (int)v.size(); i++) {
            if (std::isnan(v[i])) {
                s << "null";
            } else if (std::isinf(v[i])) {
                s << (v[i] > 0 ? "Infinity" : "-Infinity");
            } else {
                s << std::fixed << std::setprecision(precision) << v[i];
            }
            if (i < (int)v.size() - 1) {
                s << ", ";
            }
        }
        s << "]";
        return s.str();
    };

    // Save config.json
    {
        std::ofstream ofs(clbf_config_path);
        ofs << "{\n";
        ofs << "  \"d\": " << burger_config.d << ",\n";
        ofs << "  \"k\": " << burger_config.k << ",\n";
        ofs << "  \"n\": " << burger_config.n << ",\n";
        ofs << "  \"b\": " << vector_to_json(burger_config.b) << ",\n";
        ofs << "  \"t\": " << vector_to_json(burger_config.t) << ",\n";
        ofs << "  \"f\": " << vector_to_json(burger_config.f) << ",\n";
        ofs << "  \"th_b\": " << vector_to_json(burger_config.th_b) << ",\n";
        ofs << "  \"th_f\": " << vector_to_json(burger_config.th_f) << ",\n";
        ofs << "  \"g_b\": " << vector_to_json(burger_config.g_b) << ",\n";
        ofs << "  \"g_t\": " << vector_to_json(burger_config.g_t) << ",\n";
        ofs << "  \"g_f\": " << vector_to_json(burger_config.g_f) << ",\n";
        // for memo
        ofs << "  \"h_b\": " << vector_to_json(burger_config.h_b) << ",\n";
        ofs << "  \"h_t\": " << vector_to_json(burger_config.h_t) << ",\n";
        ofs << "  \"h_f\": " << vector_to_json(burger_config.h_f) << ",\n";
        ofs << "  \"alpha\": " << burger_config.alpha << "\n";
        ofs << "}\n";
        ofs.close();
    }
}


/**
 * Command-line entry point: parse the options, skip the run if the output
 * model file already exists, otherwise load the inputs and call
 * construct_clbf().
 *
 * Options (all required):
 *   -p all_pos_key_path   -q all_pos_X_path   -k X_val_key_path
 *   -x X_val_path         -y y_val_path       -t trained_xgboost_folder_path
 *   -o clbf_folder_path   -f F                -l lambda        -m mu
 *
 * Returns 1 after printing the usage line when an option is unknown or a
 * required one is missing, 0 otherwise.
 */
int main(int argc, char* argv[]) {
    // Prints the one-line usage message to stderr.
    const auto print_usage = [&argv]() {
        std::cerr << "Usage: " << argv[0] << " -p <all_pos_key_path> -q <all_pos_X_path> -k <X_val_key_path> -x <X_val_path> -y <y_val_path> -t <trained_xgboost_folder_path> -o <clbf_folder_path> -f <F> -l <lambda> -m <mu>\n";
    };

    // Paths stay empty until their option is seen.
    std::string all_pos_key_path;
    std::string all_pos_X_path;
    std::string X_val_key_path;
    std::string X_val_path;
    std::string y_val_path;
    std::string trained_xgboost_folder_path;
    std::string clbf_folder_path;
    // Numeric options start negative; negative values are rejected below, so
    // an omitted option is caught by the same check.
    float F = -1.0f;
    float lambda_ = -1.0f;
    float mu_ = -1.0f;

    const char* const option_spec = "p:q:k:x:y:t:o:f:l:m:";
    for (int flag = getopt(argc, argv, option_spec); flag != -1; flag = getopt(argc, argv, option_spec)) {
        switch (flag) {
            case 'p': all_pos_key_path = optarg; break;
            case 'q': all_pos_X_path = optarg; break;
            case 'k': X_val_key_path = optarg; break;
            case 'x': X_val_path = optarg; break;
            case 'y': y_val_path = optarg; break;
            case 't': trained_xgboost_folder_path = optarg; break;
            case 'o': clbf_folder_path = optarg; break;
            case 'f': F = std::atof(optarg); break;
            case 'l': lambda_ = std::atof(optarg); break;
            case 'm': mu_ = std::atof(optarg); break;
            default:
                print_usage();
                return 1;
        }
    }

    // Check if all required arguments are provided
    const bool path_missing = all_pos_key_path.empty()
        || all_pos_X_path.empty()
        || X_val_key_path.empty()
        || X_val_path.empty()
        || y_val_path.empty()
        || trained_xgboost_folder_path.empty()
        || clbf_folder_path.empty();
    const bool number_missing = F < 0 || lambda_ < 0 || mu_ < 0;
    if (path_missing || number_missing) {
        print_usage();
        return 1;
    }

    // An existing model file means this configuration was already built: do nothing.
    const std::string clbf_model_path = clbf_folder_path + "/clbf_model.bin";
    if (std::filesystem::exists(clbf_model_path)) {
        std::cout << "CLBF model already exists at " << clbf_model_path << ". Skipping constructing the CLBF.\n";
        return 0;
    }

    // Load every input from disk, then build and save the filter.
    std::vector<std::string> positive_keys = load_key(all_pos_key_path);
    std::vector<std::vector<float>> positive_rows = load_data(all_pos_X_path);
    std::vector<std::string> validation_keys = load_key(X_val_key_path);
    std::vector<std::vector<float>> validation_rows = load_data(X_val_path);
    std::vector<std::vector<float>> validation_labels = load_data(y_val_path);

    construct_clbf(
        positive_keys, positive_rows,
        validation_keys, validation_rows, validation_labels,
        trained_xgboost_folder_path, clbf_folder_path,
        F, lambda_, mu_);

    return 0;
}
