#ifndef CLBF_H
#define CLBF_H

#include "burger_config.h"
#include "../utils/tree.h"
#include "../utils/bloom_io.h"
#include "../utils/xgboost_io.h"
#include "../utils/reduce_xgboost_model.h"
#include <xgboost/c_api.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>  // Include a JSON library like nlohmann/json

#include "bloom.h"

/**
 * Cascaded learned Bloom filter (CLBF) built on top of an XGBoost model.
 *
 * Keys are routed through d stages. At stage j the key is stored in / checked
 * against bf_t[j], the output of tree j is added to a running score, and then:
 *   - for j < d - 1, if th_b[j] < score the key exits into bf_b[j];
 *   - for j == d - 1, the key goes to bf_f[i], where i is the lower_bound
 *     position of the score in th_f.
 * With d == 0 only bf_t[0] is used and no model is evaluated.
 *
 * The member functions defined outside the class below live in this header, so
 * they are declared `inline` here; otherwise including the header from more
 * than one translation unit would produce multiple-definition link errors.
 */
class CLBF {
public:
    CLBF() = default;

    /// Single-key insertion. Not implemented: always throws std::runtime_error.
    inline void add(const std::vector<float> &x, const uint64_t& key);    // Not implemented
    /// Inserts keys[i], using the features in X[i] to route it through the cascade.
    /// @throws std::runtime_error if the model is not trained or not configured.
    inline void add_all(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys);
    /// Single-key lookup. Not implemented: always throws std::runtime_error.
    inline bloomfilter::Status contains(const std::vector<float> &x, const uint64_t& key);   // Not implemented
    /// Looks up keys[i] using the features in X[i]; returns one status per row of X.
    /// @throws std::runtime_error if the model is not trained or not configured.
    inline std::vector<bloomfilter::Status> contains_all(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys);
    /// Loads a length-prefixed serialized XGBoost booster from `path`.
    inline void load_trained_xgboost_model(const std::string &path);
    /// For each row of X, returns the running score after every tree, plus the
    /// elapsed time reported in microseconds.
    inline std::pair<std::vector<std::vector<float>>, float> calibrate(const std::vector<std::vector<float>> &X);
    /// Returns a new, configured CLBF sized for n keys; this object is only read.
    inline CLBF configure_burger_assembly(long int n, const BurgerConfig &config);
    bool is_ml_model_trained() const { return this->ml_model_trained; }
    /// TODO: training is not implemented; currently a no-op.
    inline void train_xgboost(const std::vector<std::vector<float>> &X, const std::vector<std::vector<float>> &y, int num_rounds); // TODO
    /// Serializes d, k, the booster (when d > 0), flags, thresholds and all filters.
    inline void save_model(const std::string &path);
    /// Reads back a file produced by save_model().
    inline void load_model(const std::string &path);
    /// Model size as reported by get_xgboost_model_size_kb(); set by load_trained_xgboost_model().
    float xgboost_model_size_kb = 0.0f;

    // for experiment
    /// Reports, per row, the filter a lookup of keys[i] ends at, as ("bf_t" | "bf_b" | "bf_f", index).
    inline std::vector<std::pair<std::string, int>> get_last_node(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys);
    size_t get_d() const { return this->d; }
    size_t get_k() const { return this->k; }

private:
    // Flags
    bool ml_model_trained = false;  // a booster (or the d == 0 fallback) is available
    bool configured = false;        // the Bloom filters below have been built
    // XGBoost model
    std::vector<Tree> trees;        // per-tree data extracted from `booster`
    float base_logit = 0.0f;        // log(base_score / (1 - base_score))
    BoosterHandle booster = nullptr;
    // Bloom Filters and parameters
    size_t d = 0;                   // number of cascade stages
    size_t k = 0;                   // number of final filters in bf_f
    std::vector<float> th_b;        // early-exit thresholds, one per stage except the last
    std::vector<float> th_f;        // final-stage thresholds (assumed sorted; searched with lower_bound)
    std::vector<bloomfilter::BloomFilter<uint64_t, false>> bf_t;
    std::vector<bloomfilter::BloomFilter<uint64_t, false>> bf_b;
    std::vector<bloomfilter::BloomFilter<uint64_t, false>> bf_f;

    /// Re-derives `trees` and `base_logit` from `booster`.
    inline void prepare_tree_data_for_inference();
};

void CLBF::add(const std::vector<float> &x, const uint64_t& key) {
    // Single-key insertion is a placeholder: it always throws. The message
    // reports the first unmet precondition (trained, then configured) and
    // otherwise states that the operation is not implemented.
    static_cast<void>(x);
    static_cast<void>(key);

    const char *reason = "Not implemented.";
    if (!this->ml_model_trained) {
        reason = "ML model not trained. Please train the model first before adding.";
    } else if (!this->configured) {
        reason = "Burger assembly not configured. Please configure the burger assembly first before adding.";
    }
    throw std::runtime_error(reason);
}

void CLBF::add_all(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys) {
    /*
    Insert keys[i] for every row i of X, routing it through the cascade with
    the features X[i].

    At each stage j the key is added to bf_t[j] and tree j's output is added to
    the running score. Before the last stage, a score above th_b[j] stops the
    walk and stores the key in bf_b[j]. At the last stage, the key is stored in
    the bf_f filter chosen by lower_bound over th_f. With d == 0 only bf_t[0]
    is used.

    Throws std::runtime_error if the model is not trained, the assembly is not
    configured, or keys has fewer entries than X has rows.
    */
    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before adding.");
    }
    if (!this->configured) {
        throw std::runtime_error("Burger assembly not configured. Please configure the burger assembly first before adding.");
    }
    // keys[i] is read for every row of X; a shorter keys vector would be an out-of-bounds read.
    if (keys.size() < X.size()) {
        throw std::runtime_error("Not enough keys: keys must provide one key per row of X.");
    }

    const std::size_t num_rows = X.size();

    if (this->d == 0) {
        for (std::size_t i = 0; i < num_rows; ++i) {
            this->bf_t[0].Add(keys[i]);
        }
        return;
    }

    for (std::size_t i = 0; i < num_rows; ++i) {
        const auto &row = X[i];
        const uint64_t key = keys[i];
        float weight_sum = this->base_logit;
        for (std::size_t j = 0; j < this->d; ++j) {
            this->bf_t[j].Add(key);
            weight_sum += this->trees[j].get_leaf_weight(row);
            if (j == this->d - 1) {
                // Final stage: th_f is searched with lower_bound (assumed sorted ascending).
                const auto final_bf_idx = static_cast<std::size_t>(
                    std::lower_bound(this->th_f.begin(), this->th_f.end(), weight_sum) - this->th_f.begin());
                this->bf_f[final_bf_idx].Add(key);
            } else if (this->th_b[j] < weight_sum) {
                // Early exit: the key is stored in this stage's bf_b filter.
                this->bf_b[j].Add(key);
                break;
            }
        }
    }
}

bloomfilter::Status CLBF::contains(const std::vector<float> &x, const uint64_t& key) {
    // Single-key lookup is a placeholder that always throws; batch lookups go
    // through contains_all(). Unlike add(), no trained/configured checks run first.
    static_cast<void>(x);
    static_cast<void>(key);
    const std::string reason = "Not implemented.";
    throw std::runtime_error(reason);
}

std::vector<bloomfilter::Status> CLBF::contains_all(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys) {
    /*
    Look up keys[i] for every row i of X and return one status per row.

    The walk mirrors add_all(). At each stage j, a NotFound answer from bf_t[j]
    ends the lookup with NotFound. Otherwise tree j's output is added to the
    running score. Before the last stage, a score above th_b[j] makes bf_b[j]
    answer. At the last stage, the bf_f filter chosen by lower_bound over th_f
    answers. With d == 0 only bf_t[0] is consulted.

    Throws std::runtime_error if the model is not trained, the assembly is not
    configured, or keys has fewer entries than X has rows.
    */
    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before querying.");
    }
    if (!this->configured) {
        throw std::runtime_error("Burger assembly not configured. Please configure the burger assembly first before querying.");
    }
    // keys[i] is read for every row of X; a shorter keys vector would be an out-of-bounds read.
    if (keys.size() < X.size()) {
        throw std::runtime_error("Not enough keys: keys must provide one key per row of X.");
    }

    const std::size_t num_rows = X.size();
    std::vector<bloomfilter::Status> predictions(num_rows, bloomfilter::Status::Ok);

    if (this->d == 0) {
        for (std::size_t i = 0; i < num_rows; ++i) {
            predictions[i] = this->bf_t[0].Contain(keys[i]);
        }
        return predictions;
    }

    for (std::size_t i = 0; i < num_rows; ++i) {
        const auto &row = X[i];
        const uint64_t key = keys[i];
        float weight_sum = this->base_logit;
        for (std::size_t j = 0; j < this->d; ++j) {
            if (this->bf_t[j].Contain(key) == bloomfilter::Status::NotFound) {
                predictions[i] = bloomfilter::Status::NotFound;
                break;
            }
            weight_sum += this->trees[j].get_leaf_weight(row);
            if (j == this->d - 1) {
                // Final stage: th_f is searched with lower_bound (assumed sorted ascending).
                const auto final_bf_idx = static_cast<std::size_t>(
                    std::lower_bound(this->th_f.begin(), this->th_f.end(), weight_sum) - this->th_f.begin());
                predictions[i] = this->bf_f[final_bf_idx].Contain(key);
                break;
            }
            if (this->th_b[j] < weight_sum) {
                predictions[i] = this->bf_b[j].Contain(key);
                break;
            }
        }
    }
    return predictions;
}

std::vector<std::pair<std::string, int>> CLBF::get_last_node(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys) {
    /*
    Get the last node visited by a lookup of keys[i] for each sample X[i].
    The last node is represented as a pair of (node_type, node_index).
    - node_type: "bf_t", "bf_b", "bf_f"
    ex) ("bf_t", 0) means the lookup ended at bf_t[0].
    The walk is the same as in contains_all(); only the stopping filter is reported.

    Throws std::runtime_error if the model is not trained, the assembly is not
    configured, or keys has fewer entries than X has rows.
    */

    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before querying.");
    }
    if (!this->configured) {
        throw std::runtime_error("Burger assembly not configured. Please configure the burger assembly first before querying.");
    }
    // keys[i] is read for every row of X; a shorter keys vector would be an out-of-bounds read.
    if (keys.size() < X.size()) {
        throw std::runtime_error("Not enough keys: keys must provide one key per row of X.");
    }

    const std::size_t num_rows = X.size();
    std::vector<std::pair<std::string, int>> last_nodes(num_rows, std::pair<std::string, int>("", -1));

    if (this->d == 0) {
        for (std::size_t i = 0; i < num_rows; ++i) {
            last_nodes[i] = std::pair<std::string, int>("bf_t", 0);
        }
        return last_nodes;
    }

    for (std::size_t i = 0; i < num_rows; ++i) {
        const auto &row = X[i];
        const uint64_t key = keys[i];
        float weight_sum = this->base_logit;
        for (std::size_t j = 0; j < this->d; ++j) {
            if (this->bf_t[j].Contain(key) == bloomfilter::Status::NotFound) {
                last_nodes[i] = std::pair<std::string, int>("bf_t", static_cast<int>(j));
                break;
            }
            weight_sum += this->trees[j].get_leaf_weight(row);
            if (j == this->d - 1) {
                const auto final_bf_idx = static_cast<int>(
                    std::lower_bound(this->th_f.begin(), this->th_f.end(), weight_sum) - this->th_f.begin());
                last_nodes[i] = std::pair<std::string, int>("bf_f", final_bf_idx);
                break;
            }
            if (this->th_b[j] < weight_sum) {
                last_nodes[i] = std::pair<std::string, int>("bf_b", static_cast<int>(j));
                break;
            }
        }
    }
    return last_nodes;
}

void CLBF::prepare_tree_data_for_inference() {
    // Save the model into a JSON file
    std::string model_json_path = "/tmp/xgboost_model.json";
    XGBoosterSaveModel(this->booster, model_json_path.c_str());

    // Load the JSON file
    nlohmann::json xgboost_dict;
    std::ifstream(model_json_path) >> xgboost_dict;

    // Remove the temporary file
    if (std::remove(model_json_path.c_str()) != 0) {
        throw std::runtime_error("Failed to delete the temporary file.");
    }

    // Prepare the trees and base_logit
    this->trees.clear();
    for (const auto &tree_dict : xgboost_dict["learner"]["gradient_booster"]["model"]["trees"]) {
        Tree t(
            tree_dict["base_weights"].get<std::vector<float>>(),
            tree_dict["left_children"].get<std::vector<int>>(),
            tree_dict["right_children"].get<std::vector<int>>(),
            tree_dict["split_conditions"].get<std::vector<float>>(),
            tree_dict["split_indices"].get<std::vector<int>>()
        );
        this->trees.push_back(t);
    }
    std::string base_score_str = xgboost_dict["learner"]["learner_model_param"]["base_score"];
    float base_score = std::stof(base_score_str);
    this->base_logit = std::log(base_score / (1 - base_score));
}

void CLBF::train_xgboost(const std::vector<std::vector<float>> &X_train, const std::vector<std::vector<float>> &y_train, int num_rounds) {
    // TODO: training is disabled, so this call is a no-op. It does NOT set
    // ml_model_trained; a model must still come from load_trained_xgboost_model()
    // or load_model() before add_all()/contains_all() can be used.
    static_cast<void>(X_train);
    static_cast<void>(y_train);
    static_cast<void>(num_rounds);

    /*
    Draft implementation kept for reference (not compiled).
    NOTE(review): the draft declares a local `BoosterHandle booster;` that shadows
    this->booster, so prepare_tree_data_for_inference() would not see the freshly
    trained booster. Assign to this->booster when re-enabling this code.

    int num_rows = X_train.size();
    int num_cols = X_train[0].size();
    std::vector<float> X_train_flat;
    std::vector<float> y_train_flat;
    for (const auto &row : X_train) {
        X_train_flat.insert(X_train_flat.end(), row.begin(), row.end());
    }
    for (const auto &row : y_train) {
        y_train_flat.insert(y_train_flat.end(), row.begin(), row.end());
    }
    DMatrixHandle dtrain;
    XGDMatrixCreateFromMat(X_train_flat.data(), num_rows, num_cols, -1, &dtrain);
    XGDMatrixSetFloatInfo(dtrain, "label", y_train_flat.data(), num_rows);

    BoosterHandle booster;
    XGBoosterCreate(&dtrain, 1, &booster);
    XGBoosterSetParam(booster, "objective", "binary:logistic");
    XGBoosterSetParam(booster, "max_depth", "4");
    XGBoosterSetParam(booster, "eta", "0.1");
    XGBoosterSetParam(booster, "verbosity", "0");
    for (int i = 0; i < num_rounds; ++i) {
        XGBoosterUpdateOneIter(booster, i, dtrain);
    }

    this->ml_model_trained = true;
    XGDMatrixFree(dtrain);
    this->prepare_tree_data_for_inference();
    */
}

void CLBF::load_trained_xgboost_model(const std::string &path) {
    /*
    Load an XGBoost booster from `path` and prepare it for inference.

    The file layout is a bst_ulong byte count followed by that many bytes of
    serialized model. The bytes are copied to a temporary file so that
    XGBoosterLoadModel can read them by name. On success, this sets booster,
    trees, base_logit, ml_model_trained and xgboost_model_size_kb.

    Throws std::runtime_error on any I/O or XGBoost failure. The temporary file
    is removed on every path that created it.
    */
    std::ifstream ifs(path, std::ios::binary);
    if (!ifs) {
        throw std::runtime_error("Failed to open file for loading XGBoost model.");
    }

    // Read the XGBoost model size
    bst_ulong out_len = 0;
    ifs.read(reinterpret_cast<char*>(&out_len), sizeof(out_len));
    if (!ifs) {
        throw std::runtime_error("Failed to read the model size from the file.");
    }

    // Read the XGBoost model data
    std::vector<char> model_data(out_len);
    ifs.read(model_data.data(), static_cast<std::streamsize>(out_len));
    if (!ifs) {
        throw std::runtime_error("Failed to read the model data from the file.");
    }

    // Save the model data to a temporary file
    const std::string tmp_path = "/tmp/xgboost_model.bin";
    {
        std::ofstream ofs(tmp_path, std::ios::binary);
        if (!ofs) {
            throw std::runtime_error("Failed to open temporary file for writing model.");
        }
        ofs.write(model_data.data(), static_cast<std::streamsize>(out_len));
        // Close (and thereby flush) before XGBoost re-opens the file by name;
        // otherwise buffered bytes may not be on disk yet and the model is read truncated.
        ofs.close();
        if (!ofs) {
            std::remove(tmp_path.c_str());
            throw std::runtime_error("Failed to write model to the temporary file.");
        }
    }

    // Load the XGBoost model from the temporary file
    BoosterHandle tmp_booster;
    if (XGBoosterCreate(nullptr, 0, &tmp_booster) != 0) {
        std::remove(tmp_path.c_str());
        throw std::runtime_error("Failed to create an XGBoost booster.");
    }
    const int load_result = XGBoosterLoadModel(tmp_booster, tmp_path.c_str());
    // Remove the temporary file whether or not loading succeeded.
    const bool removed = std::remove(tmp_path.c_str()) == 0;
    if (load_result != 0) {
        throw std::runtime_error("Failed to load model from the temporary file.");
    }
    if (!removed) {
        throw std::runtime_error("Failed to delete the temporary file.");
    }
    this->booster = tmp_booster;

    this->prepare_tree_data_for_inference();

    this->ml_model_trained = true;
    this->xgboost_model_size_kb = get_xgboost_model_size_kb(this->booster);
}

std::pair<std::vector<std::vector<float>>, float> CLBF::calibrate(const std::vector<std::vector<float>> &X) {
    /*
    Compute, for every row of X, the running score after each tree.

    Result element [i][j] is base_logit plus the outputs of trees 0..j on X[i].
    The second pair member is the elapsed time of this computation in
    microseconds, measured with high_resolution_clock.
    */
    const auto start = std::chrono::high_resolution_clock::now();
    const std::size_t num_rows = X.size();
    const std::size_t num_trees = this->trees.size();
    std::vector<std::vector<float>> calibration_data(num_rows, std::vector<float>(num_trees, 0.0f));
    for (std::size_t i = 0; i < num_rows; ++i) {
        const auto &row = X[i];
        auto &scores = calibration_data[i];
        float weight_sum = this->base_logit;
        for (std::size_t j = 0; j < num_trees; ++j) {
            weight_sum += this->trees[j].get_leaf_weight(row);
            scores[j] = weight_sum;
        }
    }
    const auto end = std::chrono::high_resolution_clock::now();
    // Convert straight to (fractional) microseconds. Truncating to whole
    // milliseconds first and multiplying by 1000 would discard sub-ms precision.
    const float calibration_time_us = std::chrono::duration<float, std::micro>(end - start).count();
    return std::make_pair(std::move(calibration_data), calibration_time_us);
}

CLBF CLBF::configure_burger_assembly(long int n, const BurgerConfig &config) {
    /*
    Build and return a new, configured CLBF for n keys; `this` is only read.

    Each filter gets round(-ln(p) / (ln 2)^2) bits per item (at least 1) for its
    target rate p, and capacity n * g + 1 for its share g of the keys. For d > 0,
    a rate >= 1 - 1e-9 gets a placeholder BloomFilter(1, 1) instead.

    Throws std::runtime_error if config.d > 0 but this object has no trained
    model, because its booster would otherwise be used uninitialised.
    */
    CLBF clbf;

    if (config.d == 0) {
        // No learned stages: a single Bloom filter sized for all n keys.
        clbf.d = 0;
        clbf.k = 0;
        clbf.ml_model_trained = true;
        clbf.configured = true;
        const int bits_per_item = std::max(1, static_cast<int>(-std::log(config.t[0]) / std::log(2) / std::log(2) + 0.5));
        clbf.bf_t.push_back(bloomfilter::BloomFilter<uint64_t, false>(n + 1, bits_per_item));
        return clbf;
    }

    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before configuring the burger assembly.");
    }

    // Builds one filter for target rate `fpr` holding roughly n * `fraction` keys.
    // Generic parameters keep the arithmetic in the config's own element types.
    const auto make_filter = [n](auto fpr, auto fraction) {
        if (fpr >= 1 - 1e-9) {
            return bloomfilter::BloomFilter<uint64_t, false>(1, 1);
        }
        const int bits_per_item = std::max(1, static_cast<int>(-std::log(fpr) / std::log(2) / std::log(2) + 0.5));
        return bloomfilter::BloomFilter<uint64_t, false>(n * fraction + 1, bits_per_item);
    };

    // Derive a reduced booster from this object's booster and prepare the tree data for inference.
    clbf.booster = reduce_xgboost_model(this->booster, config.d);
    clbf.prepare_tree_data_for_inference();

    // Configure the burger assembly
    clbf.d = config.d;
    clbf.k = config.k;
    clbf.th_b = config.th_b;
    clbf.th_f = config.th_f;

    clbf.bf_t.reserve(clbf.d);
    for (size_t i = 0; i < clbf.d; ++i) {
        clbf.bf_t.push_back(make_filter(config.t[i], config.g_t[i]));
    }
    // Early-exit filters exist for every stage except the last one.
    clbf.bf_b.reserve(clbf.d - 1);
    for (size_t i = 0; i + 1 < clbf.d; ++i) {
        clbf.bf_b.push_back(make_filter(config.b[i], config.g_b[i]));
    }
    clbf.bf_f.reserve(clbf.k);
    for (size_t i = 0; i < clbf.k; ++i) {
        clbf.bf_f.push_back(make_filter(config.f[i], config.g_f[i]));
    }

    clbf.ml_model_trained = true;
    clbf.configured = true;
    return clbf;
}

void CLBF::save_model(const std::string &path) {
    /*
    Serialize this CLBF to `path` in the binary layout read back by load_model():
      size_t d, size_t k,
      the XGBoost booster via save_xgboost_model() (only when d > 0),
      bool ml_model_trained, bool configured,
      int count + raw floats for th_b, then for th_f,
      int count + filters via save_bloom_filter() for bf_t, bf_b, bf_f.
    Values are written raw, in native size and byte order.

    Throws std::runtime_error if the file cannot be opened or any write fails.
    */
    std::ofstream ofs(path, std::ios::binary);
    if (!ofs) {
        throw std::runtime_error("Failed to open file for saving model.");
    }

    // Writes an int element count followed by the raw float values.
    const auto write_floats = [&ofs](const std::vector<float> &values) {
        const int count = static_cast<int>(values.size());
        ofs.write(reinterpret_cast<const char*>(&count), sizeof(count));
        ofs.write(reinterpret_cast<const char*>(values.data()), static_cast<std::streamsize>(count * sizeof(float)));
    };
    // Writes an int element count followed by each Bloom filter.
    const auto write_filters = [&ofs](auto &filters) {
        const int count = static_cast<int>(filters.size());
        ofs.write(reinterpret_cast<const char*>(&count), sizeof(count));
        for (auto &filter : filters) {
            save_bloom_filter(filter, ofs);
        }
    };

    // Save d, k
    ofs.write(reinterpret_cast<const char*>(&this->d), sizeof(this->d));
    ofs.write(reinterpret_cast<const char*>(&this->k), sizeof(this->k));

    // Save the XGBoost model (a d == 0 CLBF has no booster)
    if (this->d > 0) {
        std::cout << "[INFO] Saving XGBoost model..." << std::endl;
        save_xgboost_model(this->booster, ofs);
    }

    // Save the other attributes
    ofs.write(reinterpret_cast<const char*>(&this->ml_model_trained), sizeof(this->ml_model_trained));
    ofs.write(reinterpret_cast<const char*>(&this->configured), sizeof(this->configured));

    // Save th_b, th_f and the Bloom filters
    write_floats(this->th_b);
    write_floats(this->th_f);
    write_filters(this->bf_t);
    write_filters(this->bf_b);
    write_filters(this->bf_f);

    ofs.close();
    // Any failed write (or the final flush) leaves the stream in a failed state.
    if (!ofs) {
        throw std::runtime_error("Failed to write model to file: " + path);
    }
    std::cout << "[INFO] Model saved to " << path << std::endl;
}

void CLBF::load_model(const std::string &path) {
    if (!std::filesystem::exists(path)) {
        throw std::runtime_error("Model file does not exist: " + path); 
    }

    std::ifstream ifs(path, std::ios::binary);
    if (!ifs) {
        throw std::runtime_error("Failed to open file for loading model: " + path);
    }

    // Load d, k
    {
        ifs.read(reinterpret_cast<char*>(&this->d), sizeof(this->d));
        ifs.read(reinterpret_cast<char*>(&this->k), sizeof(this->k));
    }

    // Load the XGBoost model
    if (this->d > 0) {
        std::cout << "[INFO] Loading XGBoost model..." << std::endl;
        this->booster = load_xgboost_model(ifs);
        // Load the tree data
        this->prepare_tree_data_for_inference();
    }

    // Load the other attributes
    {
        ifs.read(reinterpret_cast<char*>(&ml_model_trained), sizeof(ml_model_trained));
        ifs.read(reinterpret_cast<char*>(&configured), sizeof(configured));
    }

    // Load the BloomFilters
    {
        // Load th_b, th_f
        int th_b_size;
        ifs.read(reinterpret_cast<char*>(&th_b_size), sizeof(th_b_size));
        this->th_b.resize(th_b_size);
        ifs.read(reinterpret_cast<char*>(this->th_b.data()), th_b_size * sizeof(float));
        int th_f_size;
        ifs.read(reinterpret_cast<char*>(&th_f_size), sizeof(th_f_size));
        this->th_f.resize(th_f_size);
        ifs.read(reinterpret_cast<char*>(this->th_f.data()), th_f_size * sizeof(float));
        // Load bf_t
        int bf_t_size;
        ifs.read(reinterpret_cast<char*>(&bf_t_size), sizeof(bf_t_size));
        this->bf_t.clear();
        for (int i = 0; i < bf_t_size; ++i) {
            this->bf_t.push_back(load_bloom_filter(ifs));
        }
        // Load bf_b
        int bf_b_size;
        ifs.read(reinterpret_cast<char*>(&bf_b_size), sizeof(bf_b_size));
        this->bf_b.clear();
        for (int i = 0; i < bf_b_size; ++i) {
            this->bf_b.push_back(load_bloom_filter(ifs));
        }
        // Load bf_f
        int bf_f_size;
        ifs.read(reinterpret_cast<char*>(&bf_f_size), sizeof(bf_f_size));
        this->bf_f.clear();
        for (int i = 0; i < bf_f_size; ++i) {
            this->bf_f.push_back(load_bloom_filter(ifs));
        }
    }
    ifs.close();
    std::cout << "[INFO] Model loaded from " << path << std::endl;
}

#endif // CLBF_H
