#ifndef FIND_OPTIMAL_PARAMETERS_H
#define FIND_OPTIMAL_PARAMETERS_H

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <stdexcept>
#include <vector>

#include "burger_config.h"

/// Per-bucket class fractions produced by get_statistics():
/// final_g[j] is the share of positive samples that fell into bucket j,
/// final_h[j] is the share of negative samples that fell into bucket j.
struct Statistics {
    std::vector<float> final_g;
    std::vector<float> final_h;
};

/**
 * @brief Smallest value found across two score vectors.
 *
 * NaN entries are effectively skipped: std::min keeps the running value when
 * compared against NaN.
 *
 * @param pos_stats Scores of positive samples.
 * @param neg_stats Scores of negative samples.
 * @return The minimum value, or +infinity when both vectors are empty.
 */
// `inline` because this function is defined in a header; without it, including
// the header from more than one translation unit breaks the one-definition rule.
inline float get_min_stats(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats
) {
    float min_stats = std::numeric_limits<float>::infinity();
    for (const float v : pos_stats) {
        min_stats = std::min(min_stats, v);
    }
    for (const float v : neg_stats) {
        min_stats = std::min(min_stats, v);
    }
    return min_stats;
}

/**
 * @brief Largest value found across two score vectors.
 *
 * NaN entries are effectively skipped: std::max keeps the running value when
 * compared against NaN.
 *
 * @param pos_stats Scores of positive samples.
 * @param neg_stats Scores of negative samples.
 * @return The maximum value, or -infinity when both vectors are empty.
 */
// `inline` because this function is defined in a header; without it, including
// the header from more than one translation unit breaks the one-definition rule.
inline float get_max_stats(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats
) {
    float max_stats = -std::numeric_limits<float>::infinity();
    for (const float v : pos_stats) {
        max_stats = std::max(max_stats, v);
    }
    for (const float v : neg_stats) {
        max_stats = std::max(max_stats, v);
    }
    return max_stats;
}

/**
 * @brief Splits positive/negative scores at a threshold and returns, per
 *        bucket, the fraction of each class that landed there.
 *
 * Bucket 0 receives scores strictly below @p th_f. Bucket 1 receives every
 * other score, including NaN, because `NaN < th_f` is false. Buckets
 * 2..k-1, if any, are never populated and stay 0.
 *
 * @param total_pos_stats Scores of the positive calibration samples.
 * @param total_neg_stats Scores of the negative calibration samples.
 * @param th_f            Split threshold.
 * @param k               Number of buckets; must be at least 2.
 * @return final_g[j] is the fraction of positives in bucket j, and
 *         final_h[j] is the fraction of negatives in bucket j. A class with no
 *         samples gets all-zero fractions instead of 0/0 = NaN.
 * @throws std::invalid_argument if k < 2 (bucket 1 would be out of range).
 */
inline Statistics get_statistics(
    const std::vector<float>& total_pos_stats,
    const std::vector<float>& total_neg_stats,
    float th_f,
    int k
) {
    if (k < 2) {
        throw std::invalid_argument("get_statistics: k must be at least 2");
    }

    const auto bucket_fractions = [&](const std::vector<float>& stats) {
        // Count in an integer type: a float counter silently stops
        // incrementing once it reaches 2^24.
        std::size_t below = 0;
        for (const float st : stats) {
            if (st < th_f) {
                ++below;
            }
        }
        const std::size_t at_or_above = stats.size() - below;

        std::vector<float> frac(static_cast<std::size_t>(k), 0.0f);
        if (!stats.empty()) {
            const float total = static_cast<float>(stats.size());
            frac[0] = static_cast<float>(below) / total;
            frac[1] = static_cast<float>(at_or_above) / total;
        }
        return frac;
    };

    return {bucket_fractions(total_pos_stats), bucket_fractions(total_neg_stats)};
}

/**
 * @brief Chooses a score threshold and per-bucket false-positive rates that
 *        minimise an estimated memory footprint.
 *
 * The score used is the last column of each calibration row; the column index
 * is taken from the first positive row. The search works as follows:
 *   1. 99 candidate thresholds are spread evenly between the minimum and
 *      maximum score.
 *   2. For each threshold, samples are split into k = 2 buckets with
 *      get_statistics().
 *   3. Per-bucket FPRs are solved by a fixed-point iteration.
 *   4. The estimated memory is the model size plus a per-bucket
 *      c * n * g * log2(1/f) term; the cheapest threshold wins.
 *
 * @param n                       Key count used in the memory estimate.
 * @param F                       Target overall false-positive rate.
 * @param calibration_stats_pos   Per-sample feature rows of positives.
 * @param calibration_stats_neg   Per-sample feature rows of negatives.
 * @param calibration_time_pos_us Currently unused.
 * @param calibration_time_neg_us Currently unused.
 * @param xgboost_model_size_kb   Model size; multiplied by 1000 * 8
 *                                (presumably kB -> bits).
 * @return BurgerConfig initialised as {D, k, n, f[1], f[0] / f[1], th_f, final_g[0]}
 *         for the best threshold.
 * @throws std::invalid_argument if either calibration set is empty, or if the
 *         first positive row is empty.
 * @throws std::out_of_range if a row is shorter than the first positive row.
 * @throws std::runtime_error if a candidate's expected FPR exceeds 1.01 * F.
 * Exits the process with status 2 if no candidate has a finite memory estimate.
 */
inline BurgerConfig find_optimal_parameters(
    long int n, float F,
    const std::vector<std::vector<float>>& calibration_stats_pos,
    const std::vector<std::vector<float>>& calibration_stats_neg,
    [[maybe_unused]] float calibration_time_pos_us,
    [[maybe_unused]] float calibration_time_neg_us,
    float xgboost_model_size_kb
) {
    if (calibration_stats_pos.empty() || calibration_stats_neg.empty()) {
        throw std::invalid_argument(
            "find_optimal_parameters: calibration statistics must not be empty");
    }
    const int D = static_cast<int>(calibration_stats_pos[0].size());
    if (D == 0) {
        throw std::invalid_argument(
            "find_optimal_parameters: calibration rows must not be empty");
    }
    const int k = 2;
    // log2(e) == 1 / ln(2), written as a literal because M_E is not standard
    // C++. NOTE(review): presumably the optimal Bloom-filter bits-per-key factor.
    constexpr float c = 1.4426950408889634f;

    // Extract the last column (index D - 1) of every row; .at() turns a short
    // row into an exception instead of an out-of-bounds read.
    const auto last_column = [D](const std::vector<std::vector<float>>& rows) {
        std::vector<float> out;
        out.reserve(rows.size());
        for (const auto& row : rows) {
            out.push_back(row.at(static_cast<std::size_t>(D - 1)));
        }
        return out;
    };
    const std::vector<float> total_pos_stats = last_column(calibration_stats_pos);
    const std::vector<float> total_neg_stats = last_column(calibration_stats_neg);

    const float min_stats = get_min_stats(total_pos_stats, total_neg_stats);
    const float max_stats = get_max_stats(total_pos_stats, total_neg_stats);
    constexpr int N = 100;
    std::vector<float> th_f_candidates;
    th_f_candidates.reserve(N - 1);
    for (int j = 1; j < N; ++j) {
        th_f_candidates.push_back(min_stats + (max_stats - min_stats) * j / N);
    }

    // FPR for a bucket with positive share g and negative share h, given that
    // buckets with total shares G_f1 / H_f1 are already pinned to FPR = 1.
    // The result is clamped to [1e-18, 1].
    const auto tilde_f = [F](float g, float h, float G_f1, float H_f1) -> float {
        if (h == 0 || 1 - G_f1 <= 0) return 1.0;
        if (g == 0 || F - H_f1 <= 0) return 1e-18;
        const float nume = (F - H_f1) * g;
        const float deno = (1 - G_f1) * h;
        return std::min(1.0f, std::max(1e-18f, nume / deno));
    };

    // Upper bound on fixed-point rounds. Float rounding can leave a bucket
    // exactly on the f == 1 boundary and make the pinned count oscillate.
    constexpr int kMaxIterations = 100;

    float best_mem_sum = std::numeric_limits<float>::infinity();
    BurgerConfig best_parameters;

    for (const float th_f : th_f_candidates) {
        auto [final_g, final_h] = get_statistics(total_pos_stats, total_neg_stats, th_f, k);
        std::vector<float> f(k, F);

        // Re-solve the FPRs until the number of buckets pinned to 1 is stable.
        for (int iter = 0; iter < kMaxIterations; ++iter) {
            int before_cnt = 0;
            float G_f1 = 0.0f;  // sum of g over buckets with FPR == 1
            float H_f1 = 0.0f;  // sum of h over buckets with FPR == 1
            for (int i = 0; i < k; ++i) {
                if (f[i] == 1.0f) {
                    ++before_cnt;
                    G_f1 += final_g[i];
                    H_f1 += final_h[i];
                }
            }
            for (int i = 0; i < k; ++i) {
                f[i] = tilde_f(final_g[i], final_h[i], G_f1, H_f1);
            }
            const auto after_cnt = std::count(f.begin(), f.end(), 1.0f);
            if (before_cnt == after_cnt) {
                break;
            }
        }

        // If bucket 0 (scores below th_f) would get the higher FPR, fall back
        // to the uniform target F for both buckets.
        if (f[0] > f[1]) {
            f[0] = F;
            f[1] = F;
        }

        float mem_sum = xgboost_model_size_kb * 1000 * 8;
        for (int i = 0; i < k; ++i) {
            mem_sum += c * n * final_g[i] * std::log2(1 / f[i]);
        }
        float expected_fpr = 0.0f;
        for (int i = 0; i < k; ++i) {
            expected_fpr += final_h[i] * f[i];
        }

        // Allow 1% slack for float rounding in the solved FPRs.
        if (expected_fpr > F * 1.01) {
            throw std::runtime_error("Expected FPR does not match the target FPR.");
        }

        if (mem_sum < best_mem_sum) {
            best_mem_sum = mem_sum;
            best_parameters = {
                D, k, n, f[1], f[0] / f[1], th_f, final_g[0]
            };
        }
    }

    // Still infinite only if no candidate produced a finite (non-NaN) estimate.
    if (best_mem_sum == std::numeric_limits<float>::infinity()) {
        std::cerr << "Failed to find optimal parameters." << std::endl;
        std::exit(2);
    }

    return best_parameters;
}


#endif // FIND_OPTIMAL_PARAMETERS_H
