#ifndef SandwichedLBF_H
#define SandwichedLBF_H

#include "burger_config.h"
#include "../utils/tree.h"
#include "../utils/bloom_io.h"
#include "../utils/xgboost_io.h"
#include "../utils/reduce_xgboost_model.h"
#include <xgboost/c_api.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>  // Include a JSON library like nlohmann/json

#include "bloom.h"

/**
 * Sandwiched learned Bloom filter ("burger assembly").
 *
 * Membership is decided in up to three stages (see contains_all()):
 *   1. bf_t, the pre-filter Bloom filter, which holds every added key;
 *   2. an XGBoost tree ensemble whose raw score (base_logit plus the leaf weights of the
 *      first d trees) is compared against the threshold th_f;
 *   3. bf_f, the final Bloom filter, which holds the added keys that scored below th_f.
 *
 * Usage: load a trained model (load_trained_xgboost_model), derive a configured instance
 * (configure_burger_assembly), insert keys (add_all), then query (contains_all).
 * A configured instance can be persisted with save_model()/load_model().
 *
 * NOTE(review): `booster` is a raw XGBoost handle that this class never frees, and copies
 * of an instance share the same handle; ownership is not managed here.
 */
class SandwichedLBF {
public:
    SandwichedLBF() = default;

    /// Single-key insertion. Not implemented: always throws std::runtime_error.
    void add(const std::vector<float> &x, const uint64_t& key);
    /// Inserts keys[i] (feature row X[i]) into the pre-filter, and into the final filter
    /// when its ensemble score is below th_f.
    void add_all(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys);
    /// Single-key query. Not implemented: always throws std::runtime_error.
    bloomfilter::Status contains(const std::vector<float> &x, const uint64_t& key);
    /// Queries keys[i] (feature row X[i]); returns Status::NotFound or Status::Ok per key.
    std::vector<bloomfilter::Status> contains_all(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys);
    /// Loads a size-prefixed serialized XGBoost model from `path` and prepares it for inference.
    void load_trained_xgboost_model(const std::string &path);
    /// For every row, returns the cumulative score after each tree, plus the elapsed time in microseconds.
    std::pair<std::vector<std::vector<float>>, float> calibrate(const std::vector<std::vector<float>> &X);
    /// Returns a new instance with a reduced model and empty Bloom filters sized from n and config.
    SandwichedLBF configure_burger_assembly(long int n, const BurgerConfig &config);
    bool is_ml_model_trained() const { return this->ml_model_trained; }
    /// TODO: in-process training is not implemented; currently a no-op.
    void train_xgboost(const std::vector<std::vector<float>> &X, const std::vector<std::vector<float>> &y, int num_rounds);
    /// Serializes the model, flags, parameters and both Bloom filters to `path`.
    void save_model(const std::string &path);
    /// Restores the state written by save_model().
    void load_model(const std::string &path);
    /// Model size as reported by get_xgboost_model_size_kb(); set by load_trained_xgboost_model().
    float xgboost_model_size_kb = 0.0f;

private:
    // Flags
    bool ml_model_trained = false;   // true once a model is available for inference
    bool configured = false;         // true once the Bloom filters and parameters are set
    // XGBoost model
    std::vector<Tree> trees;         // trees extracted from `booster` for fast inference
    float base_logit = 0.0f;         // logit of XGBoost's base_score; starting value of every score
    BoosterHandle booster = nullptr;
    // Bloom Filters and parameters
    size_t d = 0;                    // number of leading trees summed when scoring a key
    size_t k = 0;                    // copied from BurgerConfig::k; not read by the scoring/query code
    float th_f = 0.0f;               // keys scoring below this threshold must also pass bf_f
    std::vector<bloomfilter::BloomFilter<uint64_t, false>> bf_t;  // pre-filter (element [0] is used)
    std::vector<bloomfilter::BloomFilter<uint64_t, false>> bf_f;  // final filter (element [0] is used)

    /// Re-extracts `trees` and `base_logit` from `booster`.
    void prepare_tree_data_for_inference();
};

/**
 * Single-key insertion (not implemented).
 *
 * Checks that the model is trained and the burger assembly is configured, then always
 * throws. Use add_all() to insert keys.
 *
 * @throws std::runtime_error always (with a more specific message if a precondition fails).
 */
inline void SandwichedLBF::add([[maybe_unused]] const std::vector<float> &x, [[maybe_unused]] const uint64_t& key) {
    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before adding.");
    }
    if (!this->configured) {
        throw std::runtime_error("Burger assembly not configured. Please configure the burger assembly first before adding.");
    }
    throw std::runtime_error("Not implemented.");
}

/**
 * Inserts a batch of keys into the sandwiched filter.
 *
 * Every keys[i] is added to the pre-filter bf_t. Its score (base_logit plus the leaf weights
 * of the first d trees on row X[i]) is then computed. If the score is below th_f, the key is
 * also added to the final filter bf_f, because contains_all() will consult bf_f for exactly
 * those keys.
 *
 * @param X    Feature rows; X[i] belongs to keys[i].
 * @param keys Keys to insert; must have the same length as X.
 * @throws std::runtime_error if the model is not trained, the assembly is not configured,
 *         X and keys differ in length, or d exceeds the number of available trees.
 */
inline void SandwichedLBF::add_all(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys) {
    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before adding.");
    }
    if (!this->configured) {
        throw std::runtime_error("Burger assembly not configured. Please configure the burger assembly first before adding.");
    }
    // Both checks prevent out-of-bounds reads of keys[i] / trees[j] below.
    if (X.size() != keys.size()) {
        throw std::runtime_error("add_all: X and keys must have the same number of elements.");
    }
    if (this->d > this->trees.size()) {
        throw std::runtime_error("add_all: configured tree count d exceeds the number of available trees.");
    }

    for (std::size_t i = 0; i < X.size(); ++i) {
        this->bf_t[0].Add(keys[i]);

        float weight_sum = this->base_logit;
        for (std::size_t j = 0; j < this->d; ++j) {
            weight_sum += this->trees[j].get_leaf_weight(X[i]);
        }
        if (weight_sum < this->th_f) {
            this->bf_f[0].Add(keys[i]);
        }
    }
}

/**
 * Single-key query (not implemented).
 *
 * Performs the same precondition checks as add(), then always throws. Use contains_all()
 * to query keys.
 *
 * @throws std::runtime_error always (with a more specific message if a precondition fails).
 */
inline bloomfilter::Status SandwichedLBF::contains([[maybe_unused]] const std::vector<float> &x, [[maybe_unused]] const uint64_t& key) {
    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before querying.");
    }
    if (!this->configured) {
        throw std::runtime_error("Burger assembly not configured. Please configure the burger assembly first before querying.");
    }
    throw std::runtime_error("Not implemented.");
}

/**
 * Queries a batch of keys against the sandwiched filter.
 *
 * For each keys[i]:
 *   - if the pre-filter bf_t reports NotFound, the result is NotFound;
 *   - otherwise the score (base_logit plus the leaf weights of the first d trees on X[i]) is
 *     computed; if it is below th_f and the final filter bf_f reports NotFound, the result
 *     is NotFound;
 *   - in every other case the result is Ok.
 *
 * @param X    Feature rows; X[i] belongs to keys[i].
 * @param keys Keys to query; must have the same length as X.
 * @return One Status per key, in input order.
 * @throws std::runtime_error if the model is not trained, the assembly is not configured,
 *         X and keys differ in length, or d exceeds the number of available trees.
 */
inline std::vector<bloomfilter::Status> SandwichedLBF::contains_all(const std::vector<std::vector<float>> &X, const std::vector<uint64_t> &keys) {
    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before querying.");
    }
    if (!this->configured) {
        throw std::runtime_error("Burger assembly not configured. Please configure the burger assembly first before querying.");
    }
    // Both checks prevent out-of-bounds reads of keys[i] / trees[j] below.
    if (X.size() != keys.size()) {
        throw std::runtime_error("contains_all: X and keys must have the same number of elements.");
    }
    if (this->d > this->trees.size()) {
        throw std::runtime_error("contains_all: configured tree count d exceeds the number of available trees.");
    }

    std::vector<bloomfilter::Status> predictions(X.size(), bloomfilter::Status::Ok);
    for (std::size_t i = 0; i < X.size(); ++i) {
        // Stage 1: a pre-filter miss is definitive.
        if (this->bf_t[0].Contain(keys[i]) == bloomfilter::Status::NotFound) {
            predictions[i] = bloomfilter::Status::NotFound;
            continue;
        }
        // Stage 2: score the key with the first d trees.
        float weight_sum = this->base_logit;
        for (std::size_t j = 0; j < this->d; ++j) {
            weight_sum += this->trees[j].get_leaf_weight(X[i]);
        }
        // Stage 3: low-scoring keys must also be present in the final filter.
        if (weight_sum < this->th_f && this->bf_f[0].Contain(keys[i]) == bloomfilter::Status::NotFound) {
            predictions[i] = bloomfilter::Status::NotFound;
        }
    }
    return predictions;
}

void SandwichedLBF::prepare_tree_data_for_inference() {
    // Save the model into a JSON file
    std::string model_json_path = "/tmp/xgboost_model.json";
    XGBoosterSaveModel(this->booster, model_json_path.c_str());

    // Load the JSON file
    nlohmann::json xgboost_dict;
    std::ifstream(model_json_path) >> xgboost_dict;

    // Remove the temporary file
    if (std::remove(model_json_path.c_str()) != 0) {
        throw std::runtime_error("Failed to delete the temporary file.");
    }

    // Prepare the trees and base_logit
    this->trees.clear();
    for (const auto &tree_dict : xgboost_dict["learner"]["gradient_booster"]["model"]["trees"]) {
        Tree t(
            tree_dict["base_weights"].get<std::vector<float>>(),
            tree_dict["left_children"].get<std::vector<int>>(),
            tree_dict["right_children"].get<std::vector<int>>(),
            tree_dict["split_conditions"].get<std::vector<float>>(),
            tree_dict["split_indices"].get<std::vector<int>>()
        );
        this->trees.push_back(t);
    }
    std::string base_score_str = xgboost_dict["learner"]["learner_model_param"]["base_score"];
    float base_score = std::stof(base_score_str);
    this->base_logit = std::log(base_score / (1 - base_score));
}

/**
 * Intended to train an XGBoost model in-process. TODO: not implemented.
 *
 * Currently a no-op. It does not touch the booster, the trees, or the ml_model_trained
 * flag, so is_ml_model_trained() is unchanged after the call. Use
 * load_trained_xgboost_model() to obtain a model instead.
 *
 * @param X_train    Feature rows (unused for now).
 * @param y_train    Labels (unused for now).
 * @param num_rounds Number of boosting rounds (unused for now).
 */
inline void SandwichedLBF::train_xgboost([[maybe_unused]] const std::vector<std::vector<float>> &X_train,
                                         [[maybe_unused]] const std::vector<std::vector<float>> &y_train,
                                         [[maybe_unused]] int num_rounds) {
    /*
    Disabled reference sketch for the TODO. When enabling it, assign this->booster directly;
    an earlier draft declared a local `BoosterHandle booster`, which shadowed the member and
    leaked the trained model.

    int num_rows = X_train.size();
    int num_cols = X_train[0].size();
    std::vector<float> X_train_flat;
    std::vector<float> y_train_flat;
    for (const auto &row : X_train) {
        X_train_flat.insert(X_train_flat.end(), row.begin(), row.end());
    }
    for (const auto &row : y_train) {
        y_train_flat.insert(y_train_flat.end(), row.begin(), row.end());
    }
    DMatrixHandle dtrain;
    XGDMatrixCreateFromMat(X_train_flat.data(), num_rows, num_cols, -1, &dtrain);
    XGDMatrixSetFloatInfo(dtrain, "label", y_train_flat.data(), num_rows);

    XGBoosterCreate(&dtrain, 1, &this->booster);
    XGBoosterSetParam(this->booster, "objective", "binary:logistic");
    XGBoosterSetParam(this->booster, "max_depth", "4");
    XGBoosterSetParam(this->booster, "eta", "0.1");
    XGBoosterSetParam(this->booster, "verbosity", "0");
    for (int i = 0; i < num_rounds; ++i) {
        XGBoosterUpdateOneIter(this->booster, i, dtrain);
    }

    this->ml_model_trained = true;
    XGDMatrixFree(dtrain);
    this->prepare_tree_data_for_inference();
    */
}

void SandwichedLBF::load_trained_xgboost_model(const std::string &path) {
    std::ifstream ifs(path, std::ios::binary);
    if (!ifs) {
        throw std::runtime_error("Failed to open file for loading XGBoost model.");
    }

    // Read the XGBoost model size
    bst_ulong out_len;
    ifs.read(reinterpret_cast<char*>(&out_len), sizeof(out_len));
    if (!ifs) {
        throw std::runtime_error("Failed to read the model size from the file.");
    }

    // Read the XGBoost model data
    std::vector<char> model_data(out_len);
    ifs.read(model_data.data(), out_len);
    if (!ifs) {
        throw std::runtime_error("Failed to read the model data from the file.");
    }

    // Save the model data to a temporary file
    std::string tmp_path = "/tmp/xgboost_model.bin";
    std::ofstream ofs(tmp_path, std::ios::binary);
    if (!ofs) {
        throw std::runtime_error("Failed to open temporary file for writing model.");
    }
    ofs.write(model_data.data(), out_len);

    // Load the XGBoost model from the temporary file
    BoosterHandle tmp_booster;
    XGBoosterCreate(0, 0, &tmp_booster);
    int load_result = XGBoosterLoadModel(tmp_booster, tmp_path.c_str());
    if (load_result != 0) {
        throw std::runtime_error("Failed to load model from the temporary file.");
    }
    // Remove the temporary file
    if (std::remove(tmp_path.c_str()) != 0) {
        throw std::runtime_error("Failed to delete the temporary file.");
    }
    this->booster = tmp_booster;

    this->prepare_tree_data_for_inference();

    this->ml_model_trained = true;
    this->xgboost_model_size_kb = get_xgboost_model_size_kb(this->booster);
}

/**
 * Computes cumulative ensemble scores for calibration.
 *
 * For each row X[i], column j of the result holds base_logit plus the leaf weights of
 * trees 0..j. That is the raw score of an ensemble truncated to its first j + 1 trees.
 * All loaded trees are used (not just the configured d).
 *
 * @param X Feature rows, each passed to Tree::get_leaf_weight.
 * @return {scores of shape X.size() x trees.size(), elapsed wall time in microseconds}.
 */
inline std::pair<std::vector<std::vector<float>>, float> SandwichedLBF::calibrate(const std::vector<std::vector<float>> &X) {
    // steady_clock is monotonic, which is what interval timing needs.
    const auto start = std::chrono::steady_clock::now();

    const std::size_t num_trees = this->trees.size();
    std::vector<std::vector<float>> calibration_data;
    calibration_data.reserve(X.size());
    for (const auto &row : X) {
        std::vector<float> prefix_scores(num_trees, 0.0f);
        float weight_sum = this->base_logit;
        for (std::size_t j = 0; j < num_trees; ++j) {
            weight_sum += this->trees[j].get_leaf_weight(row);
            prefix_scores[j] = weight_sum;
        }
        calibration_data.push_back(std::move(prefix_scores));
    }

    const auto end = std::chrono::steady_clock::now();
    // Measure directly in microseconds. Casting to milliseconds first (then * 1000)
    // truncated every sub-millisecond run to 0.
    const float calibration_time_us = std::chrono::duration<float, std::micro>(end - start).count();
    return std::make_pair(std::move(calibration_data), calibration_time_us);
}

/**
 * Builds a new, configured SandwichedLBF from this trained model. `this` is not modified.
 *
 * The booster is reduced with reduce_xgboost_model(booster, config.d), its trees are
 * extracted for inference, and two empty Bloom filters are created:
 *   - bf_t (pre-filter): target false-positive rate config.t, sized for n + 1 keys;
 *   - bf_f (final filter): target false-positive rate config.f, sized for n * config.g + 1 keys.
 * Bits per key are round(-ln(rate) / ln(2)^2), with a minimum of 1. A rate >= 1 - 1e-9
 * yields a minimal filter constructed as (1, 1).
 *
 * @param n      Number of keys the filters are sized for; must be non-negative.
 * @param config Burger assembly parameters (d, k, th_f, t, f, g).
 * @return The configured filter; keys still need to be inserted with add_all().
 * @throws std::runtime_error if no model is loaded, n < 0, or a rate is not positive.
 */
inline SandwichedLBF SandwichedLBF::configure_burger_assembly(long int n, const BurgerConfig &config) {
    // reduce_xgboost_model needs a valid booster; without a loaded model the handle is unset.
    if (!this->ml_model_trained) {
        throw std::runtime_error("ML model not trained. Please train the model first before configuring.");
    }
    if (n < 0) {
        throw std::runtime_error("Number of keys n must be non-negative.");
    }

    // Creates one Bloom filter for a target false-positive rate and expected key count.
    // Generic parameters keep the caller's arithmetic types (float vs double) unchanged.
    auto make_filter = [](auto fpr, auto capacity) {
        if (fpr >= 1 - 1e-9) {
            // A target rate of (about) 1 needs no real filtering capacity.
            return bloomfilter::BloomFilter<uint64_t, false>(1, 1);
        }
        // log(0) is -inf and log(<0) is NaN; both would make the int conversion below UB.
        if (!(fpr > 0)) {
            throw std::runtime_error("Bloom filter false-positive rate must be positive.");
        }
        const int bits_per_item = std::max(1, static_cast<int>(-std::log(fpr) / std::log(2) / std::log(2) + 0.5));
        return bloomfilter::BloomFilter<uint64_t, false>(capacity, bits_per_item);
    };

    // Validate and build the filters before creating a new booster, so bad rates don't leak it.
    auto pre_filter = make_filter(config.t, n * 1.0 + 1);
    auto final_filter = make_filter(config.f, n * config.g + 1);

    SandwichedLBF sandwichedlbf;
    // Reduce the XGBoost model and prepare the tree data for inference
    sandwichedlbf.booster = reduce_xgboost_model(this->booster, config.d);
    sandwichedlbf.prepare_tree_data_for_inference();
    sandwichedlbf.xgboost_model_size_kb = get_xgboost_model_size_kb(sandwichedlbf.booster);
    // Configure the burger assembly
    sandwichedlbf.d = config.d;
    sandwichedlbf.k = config.k;
    sandwichedlbf.th_f = config.th_f;
    sandwichedlbf.bf_t.clear();
    sandwichedlbf.bf_t.push_back(std::move(pre_filter));
    sandwichedlbf.bf_f.clear();
    sandwichedlbf.bf_f.push_back(std::move(final_filter));
    sandwichedlbf.ml_model_trained = true;
    sandwichedlbf.configured = true;
    return sandwichedlbf;
}

/**
 * Serializes the configured filter to `path`.
 *
 * Layout, in order:
 *   - the XGBoost model (via save_xgboost_model);
 *   - the ml_model_trained and configured flags;
 *   - d, k and th_f as raw bytes;
 *   - bf_t[0] and bf_f[0] (via save_bloom_filter).
 * load_model() reads the same layout back.
 *
 * @throws std::runtime_error if the filter is not trained and configured, the file cannot
 *         be opened, or writing fails.
 */
inline void SandwichedLBF::save_model(const std::string &path) {
    // bf_t[0] / bf_f[0] below would be out of bounds on an unconfigured instance.
    if (!this->ml_model_trained || !this->configured || this->bf_t.empty() || this->bf_f.empty()) {
        throw std::runtime_error("Model is not trained and configured. Please configure the burger assembly before saving.");
    }

    std::ofstream ofs(path, std::ios::binary);
    if (!ofs) {
        throw std::runtime_error("Failed to open file for saving model.");
    }

    // Save the XGBoost model
    {
        std::cout << "[INFO] Saving XGBoost model..." << std::endl;
        save_xgboost_model(this->booster, ofs);
    }

    // Save the flags
    {
        ofs.write(reinterpret_cast<const char*>(&this->ml_model_trained), sizeof(this->ml_model_trained));
        ofs.write(reinterpret_cast<const char*>(&this->configured), sizeof(this->configured));
    }

    // Save d, k, th_f and the Bloom filters
    {
        ofs.write(reinterpret_cast<const char*>(&this->d), sizeof(this->d));
        ofs.write(reinterpret_cast<const char*>(&this->k), sizeof(this->k));
        ofs.write(reinterpret_cast<const char*>(&this->th_f), sizeof(this->th_f));
        save_bloom_filter(this->bf_t[0], ofs);
        save_bloom_filter(this->bf_f[0], ofs);
    }

    // close() flushes; a failed flush or any earlier failed write leaves the stream in a fail state.
    ofs.close();
    if (!ofs) {
        throw std::runtime_error("Failed to write model to file: " + path);
    }
    std::cout << "[INFO] Model saved to " << path << std::endl;
}

void SandwichedLBF::load_model(const std::string &path) {
    if (!std::filesystem::exists(path)) {
        throw std::runtime_error("Model file does not exist: " + path); 
    }

    std::ifstream ifs(path, std::ios::binary);
    if (!ifs) {
        throw std::runtime_error("Failed to open file for loading model: " + path);
    }

    // Load the XGBoost model
    {
        std::cout << "[INFO] Loading XGBoost model..." << std::endl;
        this->booster = load_xgboost_model(ifs);
        // Load the tree data
        this->prepare_tree_data_for_inference();
    }

    // Load the other attributes
    {
        ifs.read(reinterpret_cast<char*>(&ml_model_trained), sizeof(ml_model_trained));
        ifs.read(reinterpret_cast<char*>(&configured), sizeof(configured));
    }

    // Load the BloomFilters
    {
        // Load d, k, th_f
        ifs.read(reinterpret_cast<char*>(&this->d), sizeof(this->d));
        ifs.read(reinterpret_cast<char*>(&this->k), sizeof(this->k));
        ifs.read(reinterpret_cast<char*>(&this->th_f), sizeof(this->th_f));
        // Load bf_t
        this->bf_t.clear();
        this->bf_t.push_back(load_bloom_filter(ifs));
        // Load bf_f
        this->bf_f.clear();
        this->bf_f.push_back(load_bloom_filter(ifs));
    }
    ifs.close();
    std::cout << "[INFO] Model loaded from " << path << std::endl;
}

#endif // SandwichedLBF_H
