#include <unistd.h>

#include <chrono>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <xgboost/c_api.h>

#include "../../src/utils/utils.h"

void train_xgboost(
    const std::vector<std::vector<float>>& X_train,
    const std::vector<std::vector<float>>& y_train,
    std::string xgboost_folder_path,
    int max_depth, int num_boost_round, float eta
) {
    // If xgboost_folder_path does not exist, create it
    if (!std::filesystem::exists(xgboost_folder_path)) {
        std::filesystem::create_directories(xgboost_folder_path);
    }

    std::string model_path = xgboost_folder_path + "/xgboost_model.bin";
    std::string info_json_path = xgboost_folder_path + "/info.json";

    std::cout << "Training XGBoost model...\n";

    // Train the XGBoost model
    BoosterHandle booster;
    float model_train_time_ms = 0;
    {
        // Prepare DMatrix
        int num_rows = X_train.size();
        int num_cols = X_train[0].size();
        std::vector<float> X_train_flat;
        std::vector<float> y_train_flat;
        for (const auto &row : X_train) {
            X_train_flat.insert(X_train_flat.end(), row.begin(), row.end());
        }
        for (const auto &row : y_train) {
            y_train_flat.insert(y_train_flat.end(), row.begin(), row.end());
        }
        DMatrixHandle dtrain;
        XGDMatrixCreateFromMat(X_train_flat.data(), num_rows, num_cols, -1, &dtrain);
        XGDMatrixSetFloatInfo(dtrain, "label", y_train_flat.data(), num_rows);

        // Set parameters
        std::string max_depth_str = std::to_string(max_depth);
        std::string eta_str = std::to_string(eta);
        std::string seed_str = std::to_string(42);
        XGBoosterCreate(&dtrain, 1, &booster);
        XGBoosterSetParam(booster, "objective", "binary:logistic");
        XGBoosterSetParam(booster, "max_depth", max_depth_str.c_str());
        XGBoosterSetParam(booster, "eta", eta_str.c_str());
        XGBoosterSetParam(booster, "verbosity", "0");
        XGBoosterSetParam(booster, "seed", seed_str.c_str());

        // Train
        auto start = std::chrono::high_resolution_clock::now();
        for (int i = 0; i < num_boost_round; i++) {
            XGBoosterUpdateOneIter(booster, i, dtrain);
        }
        auto end = std::chrono::high_resolution_clock::now();
        model_train_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();

        // Free DMatrix
        XGDMatrixFree(dtrain);
    }

    // Save the model
    float model_size_kb = 0;
    {
        // Save the XGBoost model to a tmp file
        std::string tmp_path = "/tmp/xgboost_model.bin";
        int save_to_tmp_result = XGBoosterSaveModel(booster, tmp_path.c_str());
        if (save_to_tmp_result != 0) {
            throw std::runtime_error("Failed to save model to a temporary file.");
        }

        // Read the XGBoost model from the tmp file
        std::ifstream ifs(tmp_path, std::ios::binary);
        if (!ifs) {
            throw std::runtime_error("Failed to open temporary file for reading model.");
        }
        ifs.seekg(0, std::ios::end);
        std::streampos out_len = ifs.tellg();
        ifs.seekg(0, std::ios::beg);
        std::vector<char> model_data(out_len);
        ifs.read(model_data.data(), out_len);
        ifs.close();

        // Save the XGBoost model size
        std::ofstream ofs(model_path, std::ios::binary);
        bst_ulong out_len_ = static_cast<bst_ulong>(out_len);
        ofs.write(reinterpret_cast<char*>(&out_len_), sizeof(out_len_));

        // Save the XGBoost model
        ofs.write(model_data.data(), out_len);
        ofs.close();

        // Remove the temporary file
        if (std::remove(tmp_path.c_str()) != 0) {
            throw std::runtime_error("Failed to delete the temporary file.");
        }

        model_size_kb = getBinFileSize(model_path);
    }

    // Save the info.json
    {
        std::ofstream ofs(info_json_path);
        ofs << "{\n";
        ofs << "  \"model_type\": \"xgboost\",\n";
        ofs << "  \"max_depth\": " << max_depth << ",\n";
        ofs << "  \"num_boost_round\": " << num_boost_round << ",\n";
        ofs << "  \"eta\": " << eta << ",\n";
        ofs << "  \"model_size_kb\": " << model_size_kb << ",\n";
        ofs << "  \"model_train_time_ms\": " << model_train_time_ms << "\n";
        ofs << "}\n";
        ofs.close();
    }
}

int main(int argc, char* argv[]) {
    std::string X_train_path;
    std::string y_train_path;
    std::string xgboost_folder_path;
    int max_depth = -1;
    int num_boost_round = -1;
    float eta = -1.0;

    int opt;
    while ((opt = getopt(argc, argv, "x:y:o:d:r:e:")) != -1) {
        switch (opt) {
            case 'x':
                X_train_path = optarg;
                break;
            case 'y':
                y_train_path = optarg;
                break;
            case 'o':
                xgboost_folder_path = optarg;
                break;
            case 'd':
                max_depth = std::atoi(optarg);
                break;
            case 'r':
                num_boost_round = std::atoi(optarg);
                break;
            case 'e':
                eta = std::atof(optarg);
                break;
            default:
                std::cerr << "Usage: " << argv[0] << " -x <X_train_path> -y <y_train_path> -o <xgboost_folder_path> -d <max_depth> -r <num_boost_round> -e <eta>\n";
                return 1;
        }
    }

    // Check if all required arguments are provided
    if (X_train_path.empty() || y_train_path.empty() || xgboost_folder_path.empty() || max_depth == -1 || num_boost_round == -1 || eta == -1.0) {
        std::cerr << "Error: Missing required arguments.\n";
        std::cerr << "Usage: " << argv[0] << " -x <X_train_path> -y <y_train_path> -o <xgboost_folder_path> -d <max_depth> -r <num_boost_round> -e <eta>\n";
        return 1;
    }

    // If model_path already exists, skip training the XGBoost model
    std::string model_path = xgboost_folder_path + "/xgboost_model.bin";
    if (std::filesystem::exists(model_path)) {
        std::cout << "XGBoost model already exists at " << model_path << ". Skipping training the XGBoost model.\n";
        return 0;
    }

    std::vector<std::vector<float>> X_train = load_data(X_train_path);
    std::vector<std::vector<float>> y_train = load_data(y_train_path);

    if (X_train_path.find("ember") != std::string::npos) {
        // Sampmle 10% of the data
        std::vector<std::vector<float>> sampled_X_train;
        std::vector<std::vector<float>> sampled_y_train;
        for (int i = 0; i < (int)X_train.size(); i++) {
            if (i % 10 == 0) {
                sampled_X_train.push_back(X_train[i]);
                sampled_y_train.push_back(y_train[i]);
            }
        }
        X_train = sampled_X_train;
        y_train = sampled_y_train;
    }

    train_xgboost(X_train, y_train, xgboost_folder_path, max_depth, num_boost_round, eta);

    return 0;
}
