#ifndef FIND_OPTIMAL_PARAMETERS_H
#define FIND_OPTIMAL_PARAMETERS_H

#include <algorithm>
#include <chrono>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include "burger_config.h"

/// Per-bin class fractions produced by get_statistics.
struct Statistics {
    std::vector<float> final_g;  // fraction of positive samples in each bin
    std::vector<float> final_h;  // fraction of negative samples in each bin
};

/// Smallest value found in either statistics vector.
///
/// Declared `inline` because this header defines the function; without it,
/// including the header from more than one translation unit breaks linking.
///
/// @param pos_stats statistics of positive samples (may be empty)
/// @param neg_stats statistics of negative samples (may be empty)
/// @return the minimum over both vectors, or +infinity when both are empty
inline float get_min_stats(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats
) {
    float min_stats = std::numeric_limits<float>::infinity();
    for (const auto* stats : {&pos_stats, &neg_stats}) {
        for (float st : *stats) {
            min_stats = std::min(min_stats, st);
        }
    }
    return min_stats;
}

/// Largest value found in either statistics vector.
///
/// Declared `inline` because this header defines the function; without it,
/// including the header from more than one translation unit breaks linking.
///
/// @param pos_stats statistics of positive samples (may be empty)
/// @param neg_stats statistics of negative samples (may be empty)
/// @return the maximum over both vectors, or -infinity when both are empty
inline float get_max_stats(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats
) {
    float max_stats = -std::numeric_limits<float>::infinity();
    for (const auto* stats : {&pos_stats, &neg_stats}) {
        for (float st : *stats) {
            max_stats = std::max(max_stats, st);
        }
    }
    return max_stats;
}

/// Dynamic program that splits N equal-width bins spanning [min, max] of the
/// pooled statistics into contiguous segments, maximising the sum over
/// segments of  P * log2(P / Q), where P and Q are the (add-one smoothed)
/// fractions of positive and negative samples that fall in the segment.
///
/// A value v is assigned to bin lower_bound(edges, v), where edges are the
/// N - 1 interior bin edges min + (max - min) * j / N for j = 1 .. N - 1.
///
/// @param pos_stats statistics of positive samples
/// @param neg_stats statistics of negative samples
/// @param N number of bins (>= 1)
/// @param k maximum number of segments (>= 1)
/// @return {dp, pre_dp}, both of shape (k + 1) x N. dp[i][j] is the best score
///         for splitting bins 0..j into i segments. pre_dp[i][j] is the last
///         bin of segment i - 1 in that optimum (-1 for i == 1). Entries with
///         j < i - 1 are unreachable and stay -inf / -1. Row 0 is unused.
/// @throws std::invalid_argument if N < 1 or k < 1
inline std::pair<std::vector<std::vector<float>>, std::vector<std::vector<int>>> get_dp(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats,
    int N,
    int k
) {
    if (N < 1 || k < 1) {
        throw std::invalid_argument("get_dp: N and k must both be at least 1.");
    }
    const float min_stats = get_min_stats(pos_stats, neg_stats);
    const float max_stats = get_max_stats(pos_stats, neg_stats);

    // Interior bin edges.
    std::vector<float> th_f_N;
    th_f_N.reserve(N - 1);
    for (int j = 1; j < N; j++) {
        th_f_N.push_back(min_stats + (max_stats - min_stats) * j / N);
    }
    auto bin_of = [&th_f_N](float st) {
        return std::lower_bound(th_f_N.begin(), th_f_N.end(), st) - th_f_N.begin();
    };

    // Counts start at 1 (add-one smoothing) so every bin, and therefore every
    // segment, has strictly positive mass and the log2 ratios stay finite.
    std::vector<long long> pos_cnt(N, 1), neg_cnt(N, 1);
    for (float st : pos_stats) {
        pos_cnt[bin_of(st)]++;
    }
    for (float st : neg_stats) {
        neg_cnt[bin_of(st)]++;
    }

    // Prefix sums are kept as exact integer counts rather than float ratios.
    // A segment's mass is then always >= one smoothed sample. It can never cancel
    // to zero through rounding, which would yield 0 * log2(0) = NaN and leave
    // pre_dp entries at -1.
    std::vector<long long> acc_pos(N + 1, 0), acc_neg(N + 1, 0);
    for (int j = 0; j < N; j++) {
        acc_pos[j + 1] = acc_pos[j] + pos_cnt[j];
        acc_neg[j + 1] = acc_neg[j] + neg_cnt[j];
    }
    const double total_pos_num = static_cast<double>(acc_pos[N]);
    const double total_neg_num = static_cast<double>(acc_neg[N]);

    // Score of the segment made of bins lo .. hi - 1.
    auto segment_gain = [&](int lo, int hi) -> float {
        const double pos_sum = (acc_pos[hi] - acc_pos[lo]) / total_pos_num;
        const double neg_sum = (acc_neg[hi] - acc_neg[lo]) / total_neg_num;
        return static_cast<float>(pos_sum * std::log2(pos_sum / neg_sum));
    };

    std::vector<std::vector<float>> dp(k + 1, std::vector<float>(N, -std::numeric_limits<float>::infinity()));
    std::vector<std::vector<int>> pre_dp(k + 1, std::vector<int>(N, -1));
    for (int j = 0; j < N; j++) {
        dp[1][j] = segment_gain(0, j + 1);  // single segment: bins 0 .. j
        pre_dp[1][j] = -1;
    }
    for (int i = 2; i <= k; i++) {
        for (int j = i - 1; j < N; j++) {
            // Segment i covers bins l + 1 .. j; segments 1 .. i - 1 cover 0 .. l.
            for (int l = i - 2; l < j; l++) {
                const float tmp = dp[i - 1][l] + segment_gain(l + 1, j + 1);
                if (tmp > dp[i][j]) {
                    dp[i][j] = tmp;
                    pre_dp[i][j] = l;
                }
            }
        }
    }

    return {std::move(dp), std::move(pre_dp)};
}

/// Optimal k-segment split of the pooled statistics' range, using N = 100
/// equal-width bins. The split is returned as its k - 1 interior thresholds,
/// in non-decreasing order.
///
/// The function backtracks through get_dp's pre_dp table, starting from the
/// full range (last bin N - 1). Each threshold is the upper edge of a
/// segment's last bin. It is computed with the same expression get_dp uses for
/// its bin edges, so the thresholds coincide with those edges.
///
/// @param pos_stats statistics of positive samples
/// @param neg_stats statistics of negative samples
/// @param k number of segments, in [1, 100]
/// @return k - 1 thresholds (empty when k == 1)
/// @throws std::invalid_argument if k is outside [1, 100]
inline std::vector<float> get_optimal_thresholds(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats,
    int k
) {
    const int N = 100;
    // With k > N the DP has no way to place k non-empty segments and the
    // backtrack below would index pre_dp with -1.
    if (k < 1 || k > N) {
        throw std::invalid_argument("get_optimal_thresholds: k must be in [1, 100].");
    }
    const std::vector<std::vector<int>> pre_dp = get_dp(pos_stats, neg_stats, N, k).second;
    const float min_stats = get_min_stats(pos_stats, neg_stats);
    const float max_stats = get_max_stats(pos_stats, neg_stats);

    // best_idx[i] is the last bin of segment i (1-based). best_idx[k] == N - 1,
    // and best_idx[0] ends up as -1, pre_dp's "no previous segment" marker.
    std::vector<int> best_idx(k + 1, 0);
    best_idx[k] = N - 1;
    for (int i = k - 1; i >= 0; i--) {
        best_idx[i] = pre_dp[i + 1][best_idx[i + 1]];
    }

    std::vector<float> th_f;
    th_f.reserve(k - 1);
    for (int i = 1; i < k; i++) {
        const int j = best_idx[i] + 1;
        th_f.push_back(min_stats + (max_stats - min_stats) * j / N);
    }

    return th_f;
}

/// Candidate threshold sets for a k-segment split of the pooled statistics'
/// range, using N = 100 equal-width bins.
///
/// There is one candidate per choice of last_idx, the last bin of segment
/// k - 1, for last_idx = N - 2 down to k - 2:
/// - the final segment then spans bins last_idx + 1 .. N - 1, which is never
///   empty;
/// - bins 0 .. last_idx use the optimal (k - 1)-segment DP split.
///
/// Each candidate holds k - 1 thresholds computed with get_dp's bin-edge
/// formula. For k == 1 every candidate is empty.
///
/// @param pos_stats statistics of positive samples
/// @param neg_stats statistics of negative samples
/// @param k number of segments (>= 1)
/// @return the candidates, ordered by decreasing last_idx
/// @throws std::invalid_argument if k < 1
inline std::vector<std::vector<float>> get_optimal_thresholds_candidates(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats,
    int k
) {
    if (k < 1) {
        throw std::invalid_argument("get_optimal_thresholds_candidates: k must be at least 1.");
    }
    const int N = 100;
    // Only rows 1 .. k - 1 of pre_dp are read below, so row k is not computed.
    const std::vector<std::vector<int>> pre_dp =
        get_dp(pos_stats, neg_stats, N, std::max(1, k - 1)).second;
    const float min_stats = get_min_stats(pos_stats, neg_stats);
    const float max_stats = get_max_stats(pos_stats, neg_stats);

    std::vector<std::vector<float>> optimal_thresholds_candidates;

    for (int last_idx = N - 2; last_idx >= k - 2; last_idx--) {
        // best_idx[i] is the last bin of segment i + 1 (0-based here);
        // best_idx[0] ends up as -1 once k >= 2.
        std::vector<int> best_idx(k, 0);
        best_idx[k - 1] = last_idx;
        for (int i = k - 2; i >= 0; i--) {
            best_idx[i] = pre_dp[i + 1][best_idx[i + 1]];
        }

        std::vector<float> th_f;
        th_f.reserve(k - 1);
        for (int i = 1; i < k; i++) {
            const int j = best_idx[i] + 1;
            th_f.push_back(min_stats + (max_stats - min_stats) * j / N);
        }

        optimal_thresholds_candidates.push_back(std::move(th_f));
    }

    return optimal_thresholds_candidates;
}

/// Bins positive and negative statistics by the sorted thresholds th_f. For
/// each of the k bins it reports the fraction of each class that falls there.
///
/// A value v goes to bin lower_bound(th_f, v), so bin b receives values in
/// (th_f[b - 1], th_f[b]], with the obvious open ends for the first and last
/// bin.
///
/// @param total_pos_stats statistics of positive samples
/// @param total_neg_stats statistics of negative samples
/// @param th_f ascending thresholds; must satisfy th_f.size() < k
/// @param k number of bins (>= 1)
/// @return Statistics{final_g, final_h}, each of length k. A class with no
///         samples yields all-zero fractions instead of 0/0.
/// @throws std::invalid_argument if k < 1 or th_f.size() >= k, either of which
///         would let a bin index fall outside [0, k)
inline Statistics get_statistics(
    const std::vector<float>& total_pos_stats,
    const std::vector<float>& total_neg_stats,
    const std::vector<float>& th_f,
    int k
) {
    if (k < 1 || th_f.size() >= static_cast<std::size_t>(k)) {
        throw std::invalid_argument("get_statistics: need k >= 1 and th_f.size() < k.");
    }
    const std::size_t bins = static_cast<std::size_t>(k);

    // Integer counters: a float counter stops increasing once it reaches 2^24.
    auto histogram = [&th_f, bins](const std::vector<float>& stats) {
        std::vector<std::size_t> counts(bins, 0);
        for (float st : stats) {
            counts[std::lower_bound(th_f.begin(), th_f.end(), st) - th_f.begin()]++;
        }
        return counts;
    };
    auto to_fractions = [bins](const std::vector<std::size_t>& counts, std::size_t total) {
        std::vector<float> fractions(bins, 0.0f);
        if (total == 0) {
            return fractions;  // empty class: avoid 0/0 = NaN
        }
        for (std::size_t j = 0; j < bins; j++) {
            fractions[j] = static_cast<float>(static_cast<double>(counts[j]) / static_cast<double>(total));
        }
        return fractions;
    };

    return {
        to_fractions(histogram(total_pos_stats), total_pos_stats.size()),
        to_fractions(histogram(total_neg_stats), total_neg_stats.size())
    };
}

/// Searches the candidate threshold sets for the configuration with the
/// smallest estimated memory cost.
///
/// For each candidate it:
/// 1. bins the calibration statistics into k = 8 bins;
/// 2. assigns a per-bin FPR f[i] by fixed-point iteration;
/// 3. estimates memory as  model_kb * 1000 * 8 + sum_i c * n * g_i * log2(1 / f_i)
///    with c = log2(e).
///
/// @param n number of keys, used as a multiplier in the memory estimate
/// @param F target overall false-positive rate
/// @param calibration_stats_pos per-sample statistic rows for positives. Only
///        element D - 1 of each row is used, where D is the size of the first
///        positive row.
/// @param calibration_stats_neg per-sample statistic rows for negatives
///        (same indexing)
/// @param calibration_time_pos_us not used by this function
/// @param calibration_time_neg_us not used by this function
/// @param xgboost_model_size_kb model size added to every candidate's cost
///        (the `* 1000 * 8` presumably converts KB to bits; confirm)
/// @return BurgerConfig{d, k, n, f, th_f, final_g} of the cheapest candidate
/// @throws std::invalid_argument if calibration_stats_pos is empty or its first row is empty
/// @throws std::out_of_range if any calibration row has fewer than D elements
/// @throws std::runtime_error if a candidate's expected FPR exceeds 1.01 * F,
///         or if the per-bin FPR iteration fails to settle
inline BurgerConfig find_optimal_parameters(
    long int n, float F,
    const std::vector<std::vector<float>>& calibration_stats_pos,
    const std::vector<std::vector<float>>& calibration_stats_neg,
    float calibration_time_pos_us, float calibration_time_neg_us,
    float xgboost_model_size_kb
) {
    if (calibration_stats_pos.empty() || calibration_stats_pos[0].empty()) {
        throw std::invalid_argument("calibration_stats_pos must contain at least one non-empty row.");
    }
    int D = static_cast<int>(calibration_stats_pos[0].size());
    int d = D;
    int k = 8;
    // log2(e) == 1 / ln(2). This avoids M_E, which is not part of standard C++.
    const float c = static_cast<float>(1.0 / std::log(2.0));
    const float total_xgboost_model_size_kb = xgboost_model_size_kb;

    std::vector<float> total_pos_stats;
    total_pos_stats.reserve(calibration_stats_pos.size());
    for (const auto& row : calibration_stats_pos) {
        total_pos_stats.push_back(row.at(D - 1));
    }
    std::vector<float> total_neg_stats;
    total_neg_stats.reserve(calibration_stats_neg.size());
    for (const auto& row : calibration_stats_neg) {
        total_neg_stats.push_back(row.at(D - 1));
    }

    const float kMinFpr = 1e-18f;
    // Per-bin FPR, given the positive/negative mass (G_f1, H_f1) already
    // pinned at FPR 1.
    // - Bins without negatives (h == 0) are pinned at 1.
    // - Bins without positives (g == 0) get the floor FPR; their memory term
    //   g * log2(1 / f) is 0 anyway.
    // The g == 0 test must come before the 1 - G_f1 <= 0 test. Once the pinned
    // bins hold all positive mass, the positive-free bins would otherwise be
    // set to FPR 1, driving the expected FPR to 1.
    auto tilde_f = [F, kMinFpr](float g, float h, float G_f1, float H_f1) -> float {
        if (h == 0) return 1.0f;
        if (g == 0) return kMinFpr;
        if (1 - G_f1 <= 0) return 1.0f;
        if (F - H_f1 <= 0) return kMinFpr;
        const float nume = (F - H_f1) * g;
        const float deno = (1 - G_f1) * h;
        return std::min(1.0f, std::max(kMinFpr, nume / deno));
    };
    // The iteration normally settles within about k + 1 rounds. The generous
    // cap turns a pathological oscillation into an error instead of a hang.
    const int max_fpr_iterations = 4 * (k + 1);

    float best_mem_sum = std::numeric_limits<float>::infinity();
    BurgerConfig best_parameters;

    const std::vector<std::vector<float>> th_f_candidates =
        get_optimal_thresholds_candidates(total_pos_stats, total_neg_stats, k);
    for (const auto& th_f : th_f_candidates) {
        auto [final_g, final_h] = get_statistics(total_pos_stats, total_neg_stats, th_f, k);

        // Fixed-point iteration: pin bins at FPR 1, recompute the others from
        // the remaining budget, and stop once the pinned count is unchanged.
        std::vector<float> f(k, F);
        bool converged = false;
        for (int iter = 0; iter < max_fpr_iterations && !converged; ++iter) {
            int before_cnt = 0;
            float G_f1 = 0;  // sum of g over bins currently at FPR 1
            float H_f1 = 0;  // sum of h over bins currently at FPR 1
            for (int i = 0; i < k; ++i) {
                if (f[i] == 1.0f) {
                    before_cnt += 1;
                    G_f1 += final_g[i];
                    H_f1 += final_h[i];
                }
            }
            for (int i = 0; i < k; ++i) {
                f[i] = tilde_f(final_g[i], final_h[i], G_f1, H_f1);
            }
            const auto after_cnt = std::count(f.begin(), f.end(), 1.0f);
            converged = (before_cnt == after_cnt);
        }
        if (!converged) {
            throw std::runtime_error("Per-bin FPR assignment did not converge.");
        }

        float mem_sum = total_xgboost_model_size_kb * 1000 * 8;
        for (int i = 0; i < k; ++i) {
            mem_sum += c * n * final_g[i] * std::log2(1 / f[i]);
        }
        float expected_fpr = 0.0f;
        for (int i = 0; i < k; ++i) {
            expected_fpr += final_h[i] * f[i];
        }

        if (expected_fpr > F * 1.01) {
            throw std::runtime_error("Expected FPR does not match the target FPR.");
        }

        if (mem_sum < best_mem_sum) {
            best_mem_sum = mem_sum;
            best_parameters = {
                d, k, n, f, th_f, final_g
            };
        }
    }

    return best_parameters;
}


#endif // FIND_OPTIMAL_PARAMETERS_H
