#include "../../src/utils/utils.h"
#include "../../src/plbf/plbf.h"
#include "../../src/plbf/calibration_manager.h"
#include "../../src/plbf/burger_config.h"
#include "../utils/str2uint64_hash.h"
#include <xgboost/c_api.h>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <unistd.h>

void construct_plbf(
    const std::vector<std::string>& all_pos_key_str,
    const std::vector<std::vector<float>>& all_pos_X,
    const std::vector<std::string>& X_val_key,
    const std::vector<std::vector<float>>& X_val,
    const std::vector<std::vector<float>>& y_val,
    std::string trained_xgboost_folder_path,
    std::string plbf_folder_path,
    float F
) {
    // If plbf_folder_path does not exist, create it
    if (!std::filesystem::exists(plbf_folder_path)) {
        std::filesystem::create_directories(plbf_folder_path);
    }

    std::string xgboost_model_path = trained_xgboost_folder_path + "/xgboost_model.bin";
    std::string xgboost_info_path = trained_xgboost_folder_path + "/info.json";
    std::string plbf_model_path = plbf_folder_path + "/plbf_model.bin";
    std::string plbf_info_path = plbf_folder_path + "/info.json";

    // Construct PLBF
    PLBF plbf;
    float prepare_time_ms = 0;
    float calibration_time_ms = 0;
    float configuration_time_ms = 0;
    float add_keys_time_ms = 0;
    float model_construct_time_ms = 0;
    {
        auto construction_start = std::chrono::high_resolution_clock::now();
        // prepare positive samples
        auto prepare_start = std::chrono::high_resolution_clock::now();
        std::vector<uint64_t> all_pos_key;
        for (int i = 0; i < (int)all_pos_key_str.size(); i++) {
            all_pos_key.push_back(str2uint64_hash(all_pos_key_str[i]));
        }
        std::vector<uint64_t> X_val_key_pos;
        std::vector<uint64_t> X_val_key_neg;
        std::vector<std::vector<float>> X_val_pos;
        std::vector<std::vector<float>> X_val_neg;
        for (int i = 0; i < (int)X_val.size(); i++) {
            if (y_val[i][0] == 1) {
                X_val_key_pos.push_back(str2uint64_hash(X_val_key[i]));
                X_val_pos.push_back(X_val[i]);
            } else {
                X_val_key_neg.push_back(str2uint64_hash(X_val_key[i]));
                X_val_neg.push_back(X_val[i]);
            }
        }
        auto prepare_end = std::chrono::high_resolution_clock::now();

        // Calibrate and add keys
        plbf.load_trained_xgboost_model(xgboost_model_path);
        Calibration_Manager calibration_manager(plbf);
        auto calibration_start = std::chrono::high_resolution_clock::now();
        calibration_manager.calibrate(X_val, y_val);
        auto calibration_end = std::chrono::high_resolution_clock::now();
        auto configure_start = std::chrono::high_resolution_clock::now();
        plbf = calibration_manager.configure_burger_assembly(all_pos_X.size(), F);
        auto configure_end = std::chrono::high_resolution_clock::now();

        // add keys
        std::cout << "[INFO] Adding keys to PLBF..." << std::endl;
        auto add_keys_start = std::chrono::high_resolution_clock::now();
        plbf.add_all(all_pos_X, all_pos_key);
        auto add_keys_end = std::chrono::high_resolution_clock::now();
        std::cout << "[INFO] Done adding keys to PLBF." << std::endl;

        auto construction_end = std::chrono::high_resolution_clock::now();
        model_construct_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(construction_end - construction_start).count();
        prepare_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(prepare_end - prepare_start).count();
        calibration_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(calibration_end - calibration_start).count();
        configuration_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(configure_end - configure_start).count();
        add_keys_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(add_keys_end - add_keys_start).count();

        // Check FNR == 0
        {
            std::vector<bloomfilter::Status> predictions_ = plbf.contains_all(X_val_pos, X_val_key_pos);
            int false_negatives = 0;
            for (int i = 0; i < (int)predictions_.size(); i++) {
                if (predictions_[i] == bloomfilter::Status::NotFound) {
                    false_negatives++;
                }
            }
            if (false_negatives > 0) {
                std::cerr << "Error: False negatives found in the positive samples.\n";
                exit(1);
            }
        }

        // Check FPR ~ F
        {
            std::vector<bloomfilter::Status> predictions_ = plbf.contains_all(X_val_neg, X_val_key_neg);
            int false_positives = 0;
            for (int i = 0; i < (int)predictions_.size(); i++) {
                if (predictions_[i] == bloomfilter::Status::Ok) {
                    false_positives++;
                }
            }
            float FPR = (float)false_positives / (float)X_val_neg.size();
            std::cout << "[INFO] FPR: " << FPR << std::endl;
            if (FPR > F + 0.01) {
                std::cerr << "[WARNING] FPR is greater than the expected FPR.\n";
                calibration_manager.get_burger_config().print();
                // exit(1);
            }
        }

        // burger_config = calibration_manager.get_burger_config();
    }

    // Save PLBF
    float model_size_kb = 0;
    {
        plbf.save_model(plbf_model_path);
        model_size_kb = getBinFileSize(plbf_model_path);
    }

    // Save info.json
    {
        std::ofstream ofs(plbf_info_path);
        ofs << "{\n";
        ofs << "  \"model_type\": \"plbf\",\n";
        ofs << "  \"F\": " << F << ",\n";
        ofs << "  \"trained_xgboost_folder_path\": \"" << trained_xgboost_folder_path << "\",\n";
        ofs << "  \"model_construct_time_ms\": " << model_construct_time_ms << ",\n";
        ofs << "  \"prepare_time_ms\": " << prepare_time_ms << ",\n";
        ofs << "  \"calibration_time_ms\": " << calibration_time_ms << ",\n";
        ofs << "  \"configuration_time_ms\": " << configuration_time_ms << ",\n";
        ofs << "  \"add_keys_time_ms\": " << add_keys_time_ms << ",\n";
        ofs << "  \"model_size_kb\": " << model_size_kb << "\n";
        ofs << "}\n";
        ofs.close();
    }
}


// Command-line entry point. It parses the input/output paths and the target FPR,
// skips the work if plbf_model.bin already exists, then loads the data and calls
// construct_plbf.
//
// Options (all required):
//   -p  positive keys file
//   -q  positive features file
//   -k  validation keys file
//   -x  validation features file
//   -y  validation labels file
//   -t  trained XGBoost folder
//   -o  output PLBF folder
//   -f  target false-positive rate F, in (0, 1]
//
// Returns 0 on success or skip, and 1 on any argument error.
int main(int argc, char* argv[]) {
    std::string all_pos_key_path;
    std::string all_pos_X_path;
    std::string X_val_key_path;
    std::string X_val_path;
    std::string y_val_path;
    std::string trained_xgboost_folder_path;
    std::string plbf_folder_path;
    float F = 0.0f;
    // Explicit flag instead of a -1.0 sentinel, so "missing" and "invalid" are distinguishable.
    bool F_given = false;

    // Shared by every argument-error path.
    const auto print_usage = [argv]() {
        std::cerr << "Usage: " << argv[0] << " -p <all_pos_key_path> -q <all_pos_X_path> -k <X_val_key_path> -x <X_val_path> -y <y_val_path> -t <trained_xgboost_folder_path> -o <plbf_folder_path> -f <F>\n";
    };

    int opt;
    while ((opt = getopt(argc, argv, "p:q:k:x:y:t:o:f:")) != -1) {
        switch (opt) {
            case 'p':
                all_pos_key_path = optarg;
                break;
            case 'q':
                all_pos_X_path = optarg;
                break;
            case 'k':
                X_val_key_path = optarg;
                break;
            case 'x':
                X_val_path = optarg;
                break;
            case 'y':
                y_val_path = optarg;
                break;
            case 't':
                trained_xgboost_folder_path = optarg;
                break;
            case 'o':
                plbf_folder_path = optarg;
                break;
            case 'f': {
                // strtof + end-pointer check rejects non-numeric input, which atof
                // would silently turn into 0. The range test also rejects NaN and inf.
                char* end = nullptr;
                F = std::strtof(optarg, &end);
                if (end == optarg || *end != '\0' || !(F > 0.0f && F <= 1.0f)) {
                    std::cerr << "Error: -f must be a number in (0, 1], got '" << optarg << "'.\n";
                    print_usage();
                    return 1;
                }
                F_given = true;
                break;
            }
            default:
                print_usage();
                return 1;
        }
    }

    // Check if all required arguments are provided
    if (all_pos_key_path.empty() || all_pos_X_path.empty() || X_val_key_path.empty() || X_val_path.empty() ||
        y_val_path.empty() || trained_xgboost_folder_path.empty() || plbf_folder_path.empty() || !F_given) {
        print_usage();
        return 1;
    }

    // If plbf_model_path already exists, skip constructing the plbf
    const std::string plbf_model_path = plbf_folder_path + "/plbf_model.bin";
    if (std::filesystem::exists(plbf_model_path)) {
        std::cout << "PLBF model already exists at " << plbf_model_path << ". Skipping constructing the PLBF.\n";
        return 0;
    }

    // Load inputs and build the filter.
    std::vector<std::string> all_pos_key = load_key(all_pos_key_path);
    std::vector<std::vector<float>> all_pos_X = load_data(all_pos_X_path);
    std::vector<std::string> X_val_key = load_key(X_val_key_path);
    std::vector<std::vector<float>> X_val = load_data(X_val_path);
    std::vector<std::vector<float>> y_val = load_data(y_val_path);

    construct_plbf(all_pos_key, all_pos_X, X_val_key, X_val, y_val, trained_xgboost_folder_path, plbf_folder_path, F);

    return 0;
}
