#ifndef FIND_OPTIMAL_PARAMETERS_H
#define FIND_OPTIMAL_PARAMETERS_H

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <stdexcept>
#include <vector>

#include "burger_config.h"

// Pair of per-bucket fraction vectors produced by get_statistics().
// Member order matters: callers rely on aggregate initialisation and on
// structured bindings that unpack (final_g, final_h) in this order.
struct Statistics {
    // Per-bucket count of positive scores divided by the number of positives.
    std::vector<float> final_g;
    // Per-bucket count of negative scores divided by the number of negatives.
    std::vector<float> final_h;
};

/**
 * Builds a grid of candidate threshold vectors from the negative scores.
 *
 * One candidate is produced for every combination of a group count
 * (8 through 12) and a growth ratio c (ten values starting at 1.6 and
 * advancing by 0.1). For each combination the boundaries are placed from the
 * top down so that the i-th group from the top receives about
 * num_piece * c^(i-1) negative scores, where
 * num_piece = neg_stats.size() / (1 + c + ... + c^(num_group-1)).
 *
 * @param pos_stats Not consulted; kept so existing callers keep compiling.
 * @param neg_stats Negative scores. The boundary is picked by position, so the
 *                  values are expected in ascending order (the caller in this
 *                  file sorts them before the call). The sentinels -0.1 / 1.1
 *                  suggest scores in [0, 1] -- TODO confirm.
 * @return 50 vectors, ordered by group count and then by c; a candidate for
 *         `num_group` groups holds the num_group - 1 interior boundaries
 *         (the two outer sentinels are stripped).
 */
inline std::vector<std::vector<float>> get_optimal_thresholds_candidates(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats
) {
    (void)pos_stats;  // intentionally unused, see the parameter note above

    const int num_group_min = 8;
    const int num_group_max = 12;
    const float c_min = 1.6f;
    const float c_max = 2.5f;
    const float c_step = 0.1f;
    // Fix the number of ratio steps up front. Testing `c <= c_max` on an
    // accumulated float lets rounding error decide whether the last step runs.
    const int num_c_steps = static_cast<int>(std::lround((c_max - c_min) / c_step)) + 1;

    const int num_negative = static_cast<int>(neg_stats.size());

    std::vector<std::vector<float>> optimal_thresholds_candidates;
    optimal_thresholds_candidates.reserve(
        static_cast<std::size_t>(num_group_max - num_group_min + 1) * static_cast<std::size_t>(num_c_steps));

    for (int num_group = num_group_min; num_group <= num_group_max; ++num_group) {
        // c is still accumulated step by step so every candidate sees exactly
        // the ratio values the float-driven loop used to produce.
        float c = c_min;
        for (int step = 0; step < num_c_steps; ++step, c += c_step) {
            // Boundaries including the two outer sentinels.
            std::vector<float> thresholds(num_group + 1, 0.0f);
            thresholds.front() = -0.1f;
            thresholds.back() = 1.1f;

            // tau = 1 + c + ... + c^(num_group-1): size of all groups measured
            // in units of the smallest (top) group.
            float tau = 0.0f;
            for (int i = 0; i < num_group; ++i) {
                tau += std::pow(c, i);
            }
            const int num_piece = static_cast<int>(num_negative / tau);

            // Walk from the top boundary downwards.
            for (int i = 1; i < num_group; ++i) {
                const std::size_t upper_ix = thresholds.size() - static_cast<std::size_t>(i);
                const float upper = thresholds[upper_ix];
                float& lower = thresholds[upper_ix - 1];

                if (!(upper > 0)) {
                    lower = 1.0f;
                    continue;
                }

                // Scores strictly below the boundary that was just fixed.
                std::vector<float> below;
                for (const float score : neg_stats) {
                    if (score < upper) {
                        below.push_back(score);
                    }
                }

                const long piece = std::lround(num_piece * std::pow(c, i - 1));
                if (piece <= 0) {
                    // Zero-sized group (fewer negatives than tau). Indexing
                    // below[below.size() - 0] would read one past the end, so
                    // the group collapses onto its upper boundary instead,
                    // clamped to 1 to stay inside the score range.
                    // NOTE(review): fallback chosen to remove the
                    // out-of-bounds read; confirm it suits the intended
                    // algorithm.
                    lower = std::min(upper, 1.0f);
                } else if (static_cast<std::size_t>(piece) <= below.size()) {
                    lower = below[below.size() - static_cast<std::size_t>(piece)];
                } else {
                    // Not enough scores left to fill this group.
                    lower = 0.0f;
                }
            }

            // Strip the sentinels; only interior boundaries are returned.
            thresholds.erase(thresholds.begin());
            thresholds.pop_back();
            optimal_thresholds_candidates.push_back(thresholds);
        }
    }

    return optimal_thresholds_candidates;
}

/**
 * Computes, for the positive and the negative scores separately, the fraction
 * of scores that falls into each of k buckets delimited by th_f.
 *
 * A score s is assigned to bucket std::lower_bound(th_f, s), i.e. the number
 * of thresholds strictly smaller than s, so indices range over
 * 0 .. th_f.size().
 *
 * @param total_pos_stats Positive scores (any order).
 * @param total_neg_stats Negative scores (any order).
 * @param th_f            Bucket boundaries; std::lower_bound requires them to
 *                        be sorted ascending.
 * @param k               Number of buckets; must be at least th_f.size() + 1.
 * @return Statistics whose final_g / final_h hold the k per-bucket fractions
 *         of positives / negatives. An empty score set yields all-zero
 *         fractions instead of 0/0.
 * @throws std::invalid_argument if k is too small to hold every bucket index.
 */
inline Statistics get_statistics(
    const std::vector<float>& total_pos_stats,
    const std::vector<float>& total_neg_stats,
    const std::vector<float>& th_f,
    int k
) {
    // The largest bucket index is th_f.size(); a smaller k would make the
    // histogram writes below run past the end of the vectors.
    if (k < 0 || static_cast<std::size_t>(k) <= th_f.size()) {
        throw std::invalid_argument("get_statistics: k must be at least th_f.size() + 1");
    }

    // Histogram of `stats` over the buckets, normalised by the sample count.
    const auto bucket_fractions = [&th_f, k](const std::vector<float>& stats) {
        std::vector<float> fractions(k, 0.0f);
        if (stats.empty()) {
            return fractions;  // avoid 0/0 -> NaN
        }
        for (const float st : stats) {
            const auto idx = std::lower_bound(th_f.begin(), th_f.end(), st) - th_f.begin();
            fractions[idx]++;
        }
        const int total = static_cast<int>(stats.size());
        for (float& fraction : fractions) {
            fraction = fraction / total;
        }
        return fractions;
    };

    return {
        bucket_fractions(total_pos_stats), bucket_fractions(total_neg_stats)
    };
}

/**
 * Searches the threshold candidates for the configuration with the lowest
 * expected false-positive rate that still meets the bit budget.
 *
 * Steps, as implemented below:
 *  1. Take the last column of every calibration row as the raw score and
 *     map it through a sigmoid to obtain a value in (0, 1).
 *  2. Ask get_optimal_thresholds_candidates() for candidate boundaries on the
 *     sigmoid scale and map them back to the raw-score scale.
 *  3. For each candidate, compute the per-bucket fractions of positives and
 *     negatives (get_statistics), distribute the backup budget
 *     (bit_size_of_Ada_BF minus the model size in bits) over the buckets,
 *     derive a per-bucket rate f and keep the candidate with the smallest
 *     sum(final_h[i] * f[i]) among those whose total size is within 1000 bits
 *     of bit_size_of_Ada_BF.
 *
 * @param n                       Key count; scales every per-bucket key count.
 * @param bit_size_of_Ada_BF      Total bit budget, model included.
 * @param calibration_stats_pos   Rows of positive calibration samples; only
 *                                the last column is read. Must not be empty.
 * @param calibration_stats_neg   Rows of negative calibration samples; the
 *                                same column index is read.
 * @param calibration_time_pos_us Not used by this function.
 * @param calibration_time_neg_us Not used by this function.
 * @param xgboost_model_size_kb   Model size; converted to bits as kb * 1000 * 8.
 * @return The best configuration {d, k, n, f, th_f, final_g}, or a
 *         value-initialised BurgerConfig if no candidate fits the budget.
 * @throws std::invalid_argument if calibration_stats_pos has no rows or its
 *         first row is empty.
 * @throws std::out_of_range if any row is shorter than the first positive row.
 */
inline BurgerConfig find_optimal_parameters(
    long int n, float bit_size_of_Ada_BF,
    const std::vector<std::vector<float>>& calibration_stats_pos,
    const std::vector<std::vector<float>>& calibration_stats_neg,
    float calibration_time_pos_us, float calibration_time_neg_us,
    float xgboost_model_size_kb
) {
    (void)calibration_time_pos_us;  // accepted for interface compatibility only
    (void)calibration_time_neg_us;

    if (calibration_stats_pos.empty() || calibration_stats_pos[0].empty()) {
        throw std::invalid_argument(
            "find_optimal_parameters: calibration_stats_pos must contain at least one non-empty row");
    }

    int D = static_cast<int>(calibration_stats_pos[0].size());
    int d = D;
    // log2(e), written out so the code does not depend on the non-standard
    // M_E macro.
    const float c = 1.44269504088896340736f;
    const float model_bits = xgboost_model_size_kb * 1000 * 8;
    const float bit_size_of_backup_BF = bit_size_of_Ada_BF - model_bits;

    const std::size_t score_col = static_cast<std::size_t>(D - 1);
    const auto sigmoid = [](float x) {
        return static_cast<float>(1.0 / (1.0 + std::exp(-x)));
    };
    // Extracts the score column of every row, plus its sigmoid, both sorted.
    const auto collect_sorted = [&](const std::vector<std::vector<float>>& rows,
                                    std::vector<float>& stats,
                                    std::vector<float>& probs) {
        stats.reserve(rows.size());
        probs.reserve(rows.size());
        for (const auto& row : rows) {
            const float score = row.at(score_col);  // bounds-checked on purpose
            stats.push_back(score);
            probs.push_back(sigmoid(score));
        }
        std::sort(stats.begin(), stats.end());
        std::sort(probs.begin(), probs.end());
    };

    std::vector<float> total_pos_stats;
    std::vector<float> total_pos_probs;
    collect_sorted(calibration_stats_pos, total_pos_stats, total_pos_probs);

    std::vector<float> total_neg_stats;
    std::vector<float> total_neg_probs;
    collect_sorted(calibration_stats_neg, total_neg_stats, total_neg_probs);

    float best_expected_fpr = std::numeric_limits<float>::infinity();
    // Value-initialised so the "no candidate fits" return is well defined.
    BurgerConfig best_parameters{};

    std::vector<std::vector<float>> th_f_candidates =
        get_optimal_thresholds_candidates(total_pos_probs, total_neg_probs);
    // Inverse sigmoid: bring the boundaries back to the raw-score scale.
    // Values at or outside the open interval (0, 1) become -inf / +inf.
    for (auto& candidate : th_f_candidates) {
        for (float& t : candidate) {
            if (t <= 0.0) {
                t = -std::numeric_limits<float>::infinity();
            } else if (t >= 1.0) {
                t = std::numeric_limits<float>::infinity();
            } else {
                t = static_cast<float>(-std::log(1.0 / t - 1));
            }
        }
    }

    const float log_ratio_base = std::log(0.618f);

    for (const auto& th_f : th_f_candidates) {
        int k = static_cast<int>(th_f.size()) + 1;
        auto [final_g, final_h] = get_statistics(total_pos_stats, total_neg_stats, th_f, k);

        // R[i]: bits assigned to bucket i.
        std::vector<float> R(k, bit_size_of_backup_BF);
        {
            // First bucket that contains any positive; stays 0 if none does.
            int non_empty_ix = 0;
            for (int i = 0; i < k; ++i) {
                if (final_g[i] > 0) {
                    non_empty_ix = i;
                    break;
                }
            }

            // Buckets below it hold no positives and get no bits.
            for (int i = 0; i < non_empty_ix; ++i) {
                R[i] = 0.0f;
            }

            const float count_key_non_empty_ix = n * final_g[non_empty_ix];

            // Bisection-style search on R[non_empty_ix]: the step halves each
            // round (budget * 0.5^kk) until the total is within 200 bits of
            // the backup budget or the step drops to one bit.
            float kk = 1.0;
            while (std::abs(std::accumulate(R.begin(), R.end(), 0.0f) - bit_size_of_backup_BF) > 200) {
                const float sum_R = std::accumulate(R.begin(), R.end(), 0.0f);
                const float dR = bit_size_of_backup_BF * std::pow(0.5, kk);
                if (sum_R > bit_size_of_backup_BF) {
                    R[non_empty_ix] -= dR;
                    R[non_empty_ix] = std::max(1.0f, R[non_empty_ix]);
                } else {
                    R[non_empty_ix] += dR;
                }

                // Derive every later bucket from the anchor bucket.
                for (int i = non_empty_ix + 1; i < k; ++i) {
                    if (final_h[i] == 0.0f || final_g[i] == 0.0f) {
                        R[i] = 0.0f;
                        continue;
                    }

                    const float count_key_i = n * final_g[i];
                    R[i] = std::max(
                        1.0f,
                        static_cast<float>(count_key_i * (std::log(final_h[non_empty_ix] / final_h[i]) / log_ratio_base + R[non_empty_ix] / count_key_non_empty_ix))
                    );
                }

                if (dR <= 1) break;
                kk += 1.0;
            }
        }

        // Per-bucket rate f[i] derived from the bits assigned to the bucket.
        std::vector<float> f(k);
        for (int i = 0; i < k; ++i) {
            if (final_h[i] == 0.0f) {
                f[i] = 1.0f;
            } else if (final_g[i] == 0.0f) {
                f[i] = 1e-18;
            } else if (R[i] == 0.0f) {
                f[i] = 1.0f;
            } else {
                f[i] = std::pow(2.0, - R[i] / (c * n * final_g[i]));
            }
        }

        // Total size implied by the chosen rates, model included.
        float mem_sum = model_bits;
        for (int i = 0; i < k; ++i) {
            mem_sum += c * n * final_g[i] * std::log2(1 / f[i]);
        }

        // std::abs keeps the comparison in floating point; an unqualified
        // abs() may bind to the int overload and truncate the difference.
        if (std::abs(mem_sum - bit_size_of_Ada_BF) > 1000) {
            continue;
        }

        float expected_fpr = 0.0f;
        for (int i = 0; i < k; ++i) {
            expected_fpr += final_h[i] * f[i];
        }

        if (expected_fpr < best_expected_fpr) {
            best_expected_fpr = expected_fpr;
            best_parameters = {
                d, k, n, f, th_f, final_g
            };
        }
    }

    return best_parameters;
}


#endif // FIND_OPTIMAL_PARAMETERS_H
