#ifndef XGBOOST_IO_H
#define XGBOOST_IO_H

#include "tree.h"

#include <xgboost/c_api.h>

#include <cstdio>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

/// Returns a path under /tmp that does not currently exist, used to stage an
/// XGBoost model file.
///
/// Candidates are tried in order: "/tmp/xgboost_model.bin", then
/// "/tmp/xgboost_model_0.bin" through "/tmp/xgboost_model_99.bin".
///
/// NOTE: the path is only checked, not reserved. Another process could create
/// it between this check and its use (check-then-use race).
///
/// `inline` is required because this function is defined in a header; without
/// it, including the header from two translation units violates the ODR.
///
/// @return A path that did not exist at the time of the check.
/// @throws std::runtime_error if every candidate already exists.
inline std::string get_tmp_path() {
    const std::string base_path = "/tmp/xgboost_model.bin";
    if (!std::filesystem::exists(base_path)) {
        return base_path;
    }

    constexpr int kMaxSuffixedCandidates = 100;
    for (int i = 0; i < kMaxSuffixedCandidates; ++i) {
        std::string candidate = "/tmp/xgboost_model_" + std::to_string(i) + ".bin";
        if (!std::filesystem::exists(candidate)) {
            return candidate;
        }
    }
    throw std::runtime_error("Failed to create a temporary file.");
}

// save and load xgboost model
void save_xgboost_model(const BoosterHandle& booster, std::ofstream& ofs) {
    // Save the XGBoost model to a tmp file
    std::string tmp_path = get_tmp_path();

    int save_to_tmp_result = XGBoosterSaveModel(booster, tmp_path.c_str());
    if (save_to_tmp_result != 0) {
        throw std::runtime_error("Failed to save model to a temporary file.");
    }

    // Read the XGBoost model from the tmp file
    std::ifstream ifs(tmp_path, std::ios::binary);
    if (!ifs) {
        throw std::runtime_error("Failed to open temporary file for reading model.");
    }
    ifs.seekg(0, std::ios::end);
    std::streampos out_len = ifs.tellg();
    ifs.seekg(0, std::ios::beg);
    std::vector<char> model_data(out_len);
    ifs.read(model_data.data(), out_len);
    ifs.close();

    // Save the XGBoost model size
    bst_ulong out_len_ = static_cast<bst_ulong>(out_len);
    ofs.write(reinterpret_cast<char*>(&out_len_), sizeof(out_len_));

    // Save the XGBoost model
    ofs.write(model_data.data(), out_len);

    // Remove the temporary file
    if (std::remove(tmp_path.c_str()) != 0) {
        throw std::runtime_error("Failed to delete the temporary file.");
    }

    std::cout << "[INFO] XGBoost Model Size: " << out_len << " bytes" << std::endl;
}

/// Deserializes an XGBoost booster from a stream positioned at a record written
/// by save_xgboost_model(): [bst_ulong size][size bytes of model data].
///
/// The model bytes are staged in a temporary file and loaded with
/// XGBoosterLoadModel.
///
/// @param ifs Source stream, advanced past the record on success.
/// @return A newly created booster. The caller owns it and is presumably
///         responsible for XGBoosterFree.
/// @throws std::runtime_error on read, temp-file, create, load, or delete
///         failure. On every error path the booster (if created) is freed and
///         the temp file is removed.
inline BoosterHandle load_xgboost_model(std::ifstream& ifs) {
    // Read the XGBoost model size
    bst_ulong out_len = 0;
    ifs.read(reinterpret_cast<char*>(&out_len), sizeof(out_len));
    if (!ifs) {
        throw std::runtime_error("Failed to read the model size from the file.");
    }

    // Read the XGBoost model data
    std::vector<char> model_data(static_cast<std::size_t>(out_len));
    ifs.read(model_data.data(), static_cast<std::streamsize>(model_data.size()));
    if (!ifs) {
        throw std::runtime_error("Failed to read the model data from the file.");
    }

    // Save the model data to a temporary file
    const std::string tmp_path = get_tmp_path();

    // Deletes the temp file if we leave via an exception; disarmed once we
    // remove it explicitly.
    struct TmpFileGuard {
        const std::string& path;
        bool armed = true;
        ~TmpFileGuard() {
            if (armed) {
                std::remove(path.c_str());
            }
        }
    };
    TmpFileGuard guard{tmp_path};

    {
        std::ofstream ofs_tmp(tmp_path, std::ios::binary);
        if (!ofs_tmp) {
            throw std::runtime_error("Failed to open temporary file for writing model.");
        }
        ofs_tmp.write(model_data.data(), static_cast<std::streamsize>(model_data.size()));
        ofs_tmp.close();  // flush before XGBoost reads the file
        if (!ofs_tmp) {
            throw std::runtime_error("Failed to write model to the temporary file.");
        }
    }

    // Load the XGBoost model from the temporary file
    BoosterHandle booster = nullptr;
    if (XGBoosterCreate(nullptr, 0, &booster) != 0) {
        throw std::runtime_error("Failed to create an XGBoost booster.");
    }
    const int load_result = XGBoosterLoadModel(booster, tmp_path.c_str());

    // Remove the temporary file before reporting any error, so it never
    // outlives this call.
    guard.armed = false;
    const bool removed = (std::remove(tmp_path.c_str()) == 0);

    if (load_result != 0) {
        XGBoosterFree(booster);  // don't leak the booster on failure
        throw std::runtime_error("Failed to load model from the temporary file.");
    }
    if (!removed) {
        XGBoosterFree(booster);  // the caller never receives it, so free it here
        throw std::runtime_error("Failed to delete the temporary file.");
    }

    std::cout << "[INFO] XGBoost Model Size: " << out_len << " bytes" << std::endl;

    return booster;
}

/// Returns the serialized size of an XGBoost booster in KiB (bytes / 1024).
///
/// The size is measured by saving the booster to a temporary file with
/// XGBoosterSaveModel and taking that file's length. The model bytes are not
/// read, because only the length is needed.
///
/// @param booster Handle of the booster to measure.
/// @return File size in bytes divided by 1024, as a float.
/// @throws std::runtime_error if saving, opening, sizing, or deleting the temp
///         file fails. The temp file is removed on every error path.
inline float get_xgboost_model_size_kb(const BoosterHandle& booster) {
    // Save the XGBoost model to a tmp file
    const std::string tmp_path = get_tmp_path();

    // Deletes the temp file if we leave via an exception; disarmed once we
    // remove it explicitly.
    struct TmpFileGuard {
        const std::string& path;
        bool armed = true;
        ~TmpFileGuard() {
            if (armed) {
                std::remove(path.c_str());
            }
        }
    };
    TmpFileGuard guard{tmp_path};

    if (XGBoosterSaveModel(booster, tmp_path.c_str()) != 0) {
        throw std::runtime_error("Failed to save model to a temporary file.");
    }

    // Measure the file by opening it positioned at its end. The stream is
    // scoped so it is closed before the file is removed.
    std::streamoff out_len = -1;
    {
        std::ifstream ifs(tmp_path, std::ios::binary | std::ios::ate);
        if (!ifs) {
            throw std::runtime_error("Failed to open temporary file for reading model.");
        }
        out_len = ifs.tellg();
    }
    if (out_len < 0) {  // tellg() reports failure as -1
        throw std::runtime_error("Failed to determine the size of the temporary model file.");
    }

    guard.armed = false;
    if (std::remove(tmp_path.c_str()) != 0) {
        throw std::runtime_error("Failed to delete the temporary file.");
    }
    return static_cast<float>(out_len) / 1024;
}

#endif // XGBOOST_IO_H