#ifndef FIND_OPTIMAL_PARAMETERS_H
#define FIND_OPTIMAL_PARAMETERS_H

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <vector>

#include "burger_config.h"

// Returns true when `a` and `b` have the same length and every pair of
// corresponding elements differs by at most 1e-4 (absolute tolerance).
// A pair whose difference is NaN is not reported as a mismatch.
bool equal_1d_vector(const std::vector<float>& a, const std::vector<float>& b) {
    if (a.size() != b.size()) return false;
    return std::equal(a.begin(), a.end(), b.begin(), [](float lhs, float rhs) {
        return !(std::abs(lhs - rhs) > 1e-4);
    });
}

// Returns true when `a` and `b` have identical row counts and row lengths and
// every element pair differs by at most 1e-4. On the first element mismatch
// the offending position and both values are printed to stdout.
bool equal_2d_vector(const std::vector<std::vector<float>>& a, const std::vector<std::vector<float>>& b) {
    if (a.size() != b.size()) {
        return false;
    }
    for (size_t row = 0; row < a.size(); ++row) {
        const std::vector<float>& lhs = a[row];
        const std::vector<float>& rhs = b[row];
        if (lhs.size() != rhs.size()) {
            return false;
        }
        for (size_t col = 0; col < lhs.size(); ++col) {
            const bool too_far = std::abs(lhs[col] - rhs[col]) > 1e-4;
            if (!too_far) {
                continue;
            }
            std::cout << "i: " << row << " j: " << col << " a[i][j]: " << lhs[col] << " b[i][j]: " << rhs[col] << std::endl;
            return false;
        }
    }
    return true;
}

// Per-layer calibration statistics, as filled by get_statistics /
// get_statistics_th. Fields ending in `_g` are computed from the positive
// samples and fields ending in `_h` from the negative samples; every ratio is
// relative to the total number of samples of that class. Member order matters:
// instances are built with aggregate initialization.
struct Statistics {
    std::vector<float> raw_g;         // [layer] fraction of positives reaching the layer
    std::vector<float> raw_h;         // [layer] fraction of negatives reaching the layer
    std::vector<float> branch_raw_g;  // [layer < D-1] fraction of positives branching out there
    std::vector<float> branch_raw_h;  // [layer < D-1] fraction of negatives branching out there
    std::vector<std::vector<float>> final_raw_g;    // [layer][segment] positives per final segment (k segments)
    std::vector<std::vector<float>> final_raw_h;    // [layer][segment] negatives per final segment (k segments)
    std::vector<std::vector<float>> th_f;           // [layer] k-1 final-segment thresholds
    std::vector<std::vector<float>> final_raw_g_N;  // [layer][bucket] positives per fine bucket (N buckets)
    std::vector<std::vector<float>> final_raw_h_N;  // [layer][bucket] negatives per fine bucket (N buckets)
    std::vector<std::vector<float>> th_f_N;         // [layer] N-1 fine-bucket cut points
};

#include "bloom.h"
// When true, get_time_per_item_us_bf prints a warning if the measured number of
// false positives reaches twice the expected count (2 * F * test_num).
bool warn_too_many_false_positives = false;

/// Measures the average Bloom-filter lookup time, in microseconds, for keys that
/// were never inserted.
///
/// Builds a `bloomfilter::BloomFilter<uint64_t, false>` for n + 1 items using
/// round(-ln(F) / (ln 2)^2) bits per item (at least 1), inserts keys 0 .. n-1,
/// then times 1'000'000 lookups of keys n .. n+999'999.
///
/// @param n  number of inserted keys
/// @param F  target false-positive rate used to size the filter
/// @return   elapsed lookup time divided by the number of lookups, in microseconds
float get_time_per_item_us_bf(long int n, float F) {
    int bits_per_item = std::max(1, (int)(-std::log(F) / std::log(2) / std::log(2) + 0.5));
    bloomfilter::BloomFilter<uint64_t, false> bf(n + 1, bits_per_item);
    // Counters are `long int` like `n`: an `int` counter would overflow (UB)
    // for n near INT_MAX, and `n + test_num` must not be computed in `int`.
    for (long int i = 0; i < n; i++) {
        bf.Add(static_cast<uint64_t>(i));
    }
    const long int test_num = 1000000;
    long int false_positive_num = 0;
    auto start = std::chrono::high_resolution_clock::now();
    for (long int i = n; i < n + test_num; i++) {
        // Every probed key is absent, so any hit is a false positive.
        if (bf.Contain(static_cast<uint64_t>(i)) == bloomfilter::Status::Ok) {
            false_positive_num++;
        }
    }
    auto end = std::chrono::high_resolution_clock::now();
    if (false_positive_num >= F * test_num * 2.0 && warn_too_many_false_positives) {
        std::cout << "[WARN] Too many false positives." << std::endl;
    }
    float total_ns = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
    return total_ns / test_num / 1000;
}
// Estimated Bloom-filter size in bits for n items at false-positive rate F:
// n * max(1, round(-ln(F) / (ln 2)^2)), i.e. the same per-item bit budget that
// get_time_per_item_us_bf uses when building its filter.
float get_mem_usage_bits_bf(long int n, float F) {
    const double ln2 = std::log(2);
    const double raw_bits_per_item = -std::log(F) / ln2 / ln2 + 0.5;
    const int bits_per_element = std::max(1, static_cast<int>(raw_bits_per_item));
    return n * bits_per_element;
}

/// Partitions N fine buckets into k contiguous segments maximising
///     sum_s P_s * log2(P_s / Q_s),
/// where P_s / Q_s are the summed `pos_ratio` / `neg_ratio` of segment s.
///
/// @param pos_ratio  positive mass of each fine bucket (size N)
/// @param neg_ratio  negative mass of each fine bucket (size N)
/// @param th_f_N_j   cut points between the fine buckets; a segment boundary
///                   placed after bucket b maps to threshold th_f_N_j[b]
/// @param k          number of segments (the fallback below assumes N >= k)
/// @return {pos_ratio_k, neg_ratio_k, th_f}: the k per-segment positive and
///         negative masses and the k-1 segment thresholds.
///
/// A segment with zero positive mass scores NaN (0 * log2(0) or 0 * log2(0/0)),
/// which never wins a `>` comparison, so such segments are never selected. If
/// no k-segment partition avoids them, DP cell (k, N-1) is never reached and
/// the function falls back to boundaries after buckets 0, 1, ..., k-2.
std::tuple<std::vector<float>, std::vector<float>, std::vector<float>> get_optimal_thresholds_N(
    const std::vector<float>& pos_ratio,
    const std::vector<float>& neg_ratio,
    const std::vector<float>& th_f_N_j,
    int k
) {
    int N = pos_ratio.size();
    // Prefix sums: acc_x[j] = x[0] + ... + x[j - 1].
    std::vector<float> acc_pos_ratio(N + 1, 0.0f);
    std::vector<float> acc_neg_ratio(N + 1, 0.0f);
    for (int j = 0; j < N; j++) {
        acc_pos_ratio[j + 1] = acc_pos_ratio[j] + pos_ratio[j];
        acc_neg_ratio[j + 1] = acc_neg_ratio[j] + neg_ratio[j];
    }

    // dp[i][j]: best score for splitting buckets 0..j into i segments.
    // pre_dp[i][j]: last bucket of the first i-1 segments in that split
    // (-1 for i == 1, and also -1 when (i, j) was never reached).
    std::vector<std::vector<float>> dp(k + 1, std::vector<float>(N, -std::numeric_limits<float>::infinity()));
    std::vector<std::vector<int>> pre_dp(k + 1, std::vector<int>(N, -1));
    for(int j = 0; j < N; j++) {
        float pos_sum = acc_pos_ratio[j + 1]; // 0 ~ j
        float neg_sum = acc_neg_ratio[j + 1]; // 0 ~ j
        dp[1][j] = pos_sum * std::log2(pos_sum / neg_sum);
        pre_dp[1][j] = -1;
    }
    for (int i = 2; i <= k; i++) {
        for (int j = i - 1; j < N; j++) {
            for (int l = i - 2; l < j; l++) {
                float pos_sum = acc_pos_ratio[j + 1] - acc_pos_ratio[l + 1]; // (l + 1) ~ j
                float neg_sum = acc_neg_ratio[j + 1] - acc_neg_ratio[l + 1]; // (l + 1) ~ j
                float tmp = dp[i - 1][l] + pos_sum * std::log2(pos_sum / neg_sum);
                if (tmp > dp[i][j]) {
                    dp[i][j] = tmp;
                    pre_dp[i][j] = l;
                }
            }
        }
    }

    // Backtrack the segment ends: best_idx[i] is the last bucket of segment i-1,
    // with best_idx[0] == -1 and best_idx[k] == N-1 for a complete path.
    std::vector<int> best_idx(k + 1, 0);
    int now_idx = N - 1;
    best_idx[k] = now_idx;
    bool reachable = true;
    for (int i = k - 1; i >= 0; i--) {
        // A negative index above the base row means the cell was never reached;
        // using it to index pre_dp would read out of bounds.
        if (now_idx < 0) {
            reachable = false;
            break;
        }
        now_idx = pre_dp[i + 1][now_idx];
        best_idx[i] = now_idx;
    }

    if (!reachable || !std::is_sorted(best_idx.begin(), best_idx.end())) {
        // Fallback: single-bucket segments at the bottom, the rest in the last one.
        for (int i = 0; i < k; i++) {
            best_idx[i] = i - 1;
        }
        best_idx[k] = N - 1;
    }

    std::vector<float> th_f;
    for (int i = 1; i < k; i++) {
        int j = best_idx[i] + 1;
        float th = th_f_N_j[j - 1];
        th_f.push_back(th);
    }

    std::vector<float> pos_ratio_k(k), neg_ratio_k(k);
    for (int i = 0; i < k; i++) {
        pos_ratio_k[i] = acc_pos_ratio[best_idx[i + 1] + 1] - acc_pos_ratio[best_idx[i] + 1];
        neg_ratio_k[i] = acc_neg_ratio[best_idx[i + 1] + 1] - acc_neg_ratio[best_idx[i] + 1];
    }

    return {pos_ratio_k, neg_ratio_k, th_f};
}

/// Chooses k-1 thresholds on [min_stats, max_stats] that split the samples into
/// k contiguous segments maximising sum_s P_s * log2(P_s / Q_s), where P_s / Q_s
/// are the fractions of `pos_stats` / `neg_stats` falling into segment s.
///
/// Values are first bucketed into N = 100 uniform buckets (cut points
/// min + (max - min) * j / N, j = 1..N-1); thresholds are always such cut points.
///
/// @param pos_stats  positive-sample statistics
/// @param neg_stats  negative-sample statistics
/// @param min_stats  lower end of the bucket grid
/// @param max_stats  upper end of the bucket grid
/// @param k          number of segments (the fallback below assumes k <= N)
/// @return the k-1 thresholds in increasing order.
///
/// Segments with zero positive mass score NaN and are never selected. If no
/// valid k-segment split exists (for example `pos_stats` is empty), the
/// thresholds fall back to the first k-1 cut points, matching the fallback in
/// get_optimal_thresholds_N.
std::vector<float> get_optimal_thresholds(
    const std::vector<float>& pos_stats,
    const std::vector<float>& neg_stats,
    float min_stats, float max_stats,
    int k
) {
    int N = 100;
    std::vector<float> th_f_N;
    for (int j = 1; j < N; j++) {
        th_f_N.push_back(min_stats + (max_stats - min_stats) * j / N);
    }

    // Histogram: a value lands in bucket #(cut points strictly below it).
    // Integer counters stay exact beyond 2^24 samples, unlike float counters.
    std::vector<int> pos_cnt(N, 0), neg_cnt(N, 0);
    for (float st : pos_stats) {
        pos_cnt[std::lower_bound(th_f_N.begin(), th_f_N.end(), st) - th_f_N.begin()]++;
    }
    for (float st : neg_stats) {
        neg_cnt[std::lower_bound(th_f_N.begin(), th_f_N.end(), st) - th_f_N.begin()]++;
    }
    // Every sample falls into exactly one bucket, so the totals are the input sizes.
    int total_pos_num = pos_stats.size();
    int total_neg_num = neg_stats.size();
    std::vector<float> pos_ratio(N, 0), neg_ratio(N, 0);
    for (int j = 0; j < N; j++) {
        pos_ratio[j] = static_cast<float>(pos_cnt[j]) / total_pos_num;
        neg_ratio[j] = static_cast<float>(neg_cnt[j]) / total_neg_num;
    }

    // Prefix sums: acc_x[j] = x[0] + ... + x[j - 1].
    std::vector<float> acc_pos_ratio(N + 1, 0.0f);
    std::vector<float> acc_neg_ratio(N + 1, 0.0f);
    for (int j = 0; j < N; j++) {
        acc_pos_ratio[j + 1] = acc_pos_ratio[j] + pos_ratio[j];
        acc_neg_ratio[j + 1] = acc_neg_ratio[j] + neg_ratio[j];
    }

    // dp[i][j]: best score splitting buckets 0..j into i segments;
    // pre_dp[i][j]: last bucket of the first i-1 segments (-1 if i == 1 or unreached).
    std::vector<std::vector<float>> dp(k + 1, std::vector<float>(N, -std::numeric_limits<float>::infinity()));
    std::vector<std::vector<int>> pre_dp(k + 1, std::vector<int>(N, -1));
    for(int j = 0; j < N; j++) {
        float pos_sum = acc_pos_ratio[j + 1]; // 0 ~ j
        float neg_sum = acc_neg_ratio[j + 1]; // 0 ~ j
        dp[1][j] = pos_sum * std::log2(pos_sum / neg_sum);
        pre_dp[1][j] = -1;
    }
    for (int i = 2; i <= k; i++) {
        for (int j = i - 1; j < N; j++) {
            for (int l = i - 2; l < j; l++) {
                float pos_sum = acc_pos_ratio[j + 1] - acc_pos_ratio[l + 1]; // (l + 1) ~ j
                float neg_sum = acc_neg_ratio[j + 1] - acc_neg_ratio[l + 1]; // (l + 1) ~ j
                float tmp = dp[i - 1][l] + pos_sum * std::log2(pos_sum / neg_sum);
                if (tmp > dp[i][j]) {
                    dp[i][j] = tmp;
                    pre_dp[i][j] = l;
                }
            }
        }
    }

    std::vector<int> best_idx(k + 1, 0);
    int now_idx = N - 1;
    best_idx[k] = now_idx;
    bool reachable = true;
    for (int i = k - 1; i >= 0; i--) {
        // A negative index above the base row means the DP cell was never
        // reached; indexing pre_dp with it would read out of bounds.
        if (now_idx < 0) {
            reachable = false;
            break;
        }
        now_idx = pre_dp[i + 1][now_idx];
        best_idx[i] = now_idx;
    }
    if (!reachable || !std::is_sorted(best_idx.begin(), best_idx.end())) {
        for (int i = 0; i < k; i++) {
            best_idx[i] = i - 1;
        }
        best_idx[k] = N - 1;
    }

    std::vector<float> th_f;
    for (int i = 1; i < k; i++) {
        int j = best_idx[i] + 1;
        float th = min_stats + (max_stats - min_stats) * j / N;
        th_f.push_back(th);
    }
    return th_f;
}

/*
    One candidate set of branch thresholds.
    th_b[i] is the element at index min(n - 1, floor(n * alpha)) of layer i's
    negative statistics sorted in DESCENDING order (n = number of negatives),
    i.e. roughly the top-alpha quantile; alpha == -1 encodes +infinity.
*/
struct Th_B_Candidate {
    float alpha;              // quantile used to derive th_b (-1 => +infinity)
    std::vector<float> th_b;  // one threshold per layer 0 .. D-2
};

/// Builds the branch-threshold candidates from the negative calibration statistics.
///
/// @param total_neg_stats  one row per negative sample with D per-layer values;
///                         must be non-empty.
/// @return one candidate per alpha in {0.1, 0.01, 0.001, 0.0001, 0.0, -1.0}, in
///         that order; per layer the thresholds are non-decreasing across the
///         candidates (checked by the trailing assert).
/// @throws std::invalid_argument if total_neg_stats is empty.
std::vector<Th_B_Candidate> get_th_b_candidates(
    const std::vector<std::vector<float>>& total_neg_stats
) {
    if (total_neg_stats.empty()) {
        // total_neg_stats[0] below would otherwise be out of bounds.
        throw std::invalid_argument("get_th_b_candidates: total_neg_stats is empty");
    }

    std::vector<Th_B_Candidate> th_b_candidates;

    std::vector<float> alpha_0_candidates = {
        0.1,
        0.01,
        0.001,
        0.0001,
        0.0, -1.0
    };

    int D = total_neg_stats[0].size();
    int total_neg_num = total_neg_stats.size();
    // Only layers 0 .. D-2 receive a branch threshold, so only those columns are sorted.
    int num_branch_layers = std::max(D - 1, 0);

    // Column-wise copy of the statistics, each column sorted in descending order.
    std::vector<std::vector<float>> sorted_neg_stats(num_branch_layers, std::vector<float>(total_neg_num));
    for (int i = 0; i < num_branch_layers; i++) {
        for (int j = 0; j < total_neg_num; j++) {
            sorted_neg_stats[i][j] = total_neg_stats[j][i];
        }
        std::sort(sorted_neg_stats[i].begin(), sorted_neg_stats[i].end(), std::greater<float>());
    }

    for (float alpha : alpha_0_candidates) {
        std::vector<float> th_b;
        for (int i = 0; i < num_branch_layers; i++) {
            if (alpha == -1.0) {
                th_b.push_back(std::numeric_limits<float>::infinity());
            } else {
                int idx = std::min(total_neg_num - 1, static_cast<int>(total_neg_num * alpha));
                th_b.push_back(sorted_neg_stats[i][idx]);
            }
        }
        th_b_candidates.push_back({alpha, th_b});
    }

    // Per layer, thresholds must be non-decreasing in candidate order
    // (get_statistics_th relies on this monotonicity).
    for (int i = 0; i < num_branch_layers; i++) {
        std::vector<float> th_b_row;
        for (int j = 0; j < (int)th_b_candidates.size(); j++) {
            th_b_row.push_back(th_b_candidates[j].th_b[i]);
        }
        assert(std::is_sorted(th_b_row.begin(), th_b_row.end()));
    }

    return th_b_candidates;
}


/// Reference (per-candidate) computation of the layer statistics for one set of
/// branch thresholds `th_b`.
///
/// Walking layers 0..D-1, rows whose layer value exceeds th_b[layer] branch out
/// and are excluded from all later layers. For every layer it records the
/// fraction of each class still present, and the class distributions when that
/// layer is treated as the final one. It does this both over N = 100 uniform
/// buckets and over k segments chosen by get_optimal_thresholds. All ratios are
/// relative to the original class totals.
Statistics get_statistics(
    const std::vector<std::vector<float>>& total_pos_stats,
    const std::vector<std::vector<float>>& total_neg_stats,
    const std::vector<float>& th_b,
    int k
) {
    const int D = total_pos_stats[0].size();
    const int total_pos_num = total_pos_stats.size();
    const int total_neg_num = total_neg_stats.size();
    const int N = 100;

    // Rows still present at the current layer; shrinks as rows branch out.
    std::vector<std::vector<float>> pos_alive = total_pos_stats;
    std::vector<std::vector<float>> neg_alive = total_neg_stats;

    // Per-layer value range over all positive and negative rows.
    std::vector<float> lo(D, std::numeric_limits<float>::infinity());
    std::vector<float> hi(D, -std::numeric_limits<float>::infinity());
    auto widen_range = [&](const std::vector<std::vector<float>>& rows) {
        for (const std::vector<float>& row : rows) {
            for (int layer = 0; layer < D; layer++) {
                lo[layer] = std::min(lo[layer], row[layer]);
                hi[layer] = std::max(hi[layer], row[layer]);
            }
        }
    };
    widen_range(pos_alive);
    widen_range(neg_alive);

    // Counts column `layer` of `rows` into `buckets` bins delimited by the sorted
    // cut points `cuts`; a value lands in bin #(cuts strictly below it).
    auto histogram = [](const std::vector<std::vector<float>>& rows, int layer,
                        const std::vector<float>& cuts, int buckets) {
        std::vector<float> counts(buckets, 0);
        for (const std::vector<float>& row : rows) {
            const auto slot = std::lower_bound(cuts.begin(), cuts.end(), row[layer]) - cuts.begin();
            counts[slot]++;
        }
        return counts;
    };
    // Divides every bin count by the class total.
    auto normalize = [](const std::vector<float>& counts, int total) {
        std::vector<float> ratios(counts.size(), 0);
        for (size_t b = 0; b < counts.size(); b++) {
            ratios[b] = counts[b] / total;
        }
        return ratios;
    };

    Statistics result;
    for (int layer = 0; layer < D; layer++) {
        // Fraction of each class that reaches this layer.
        result.raw_g.push_back(static_cast<float>(pos_alive.size()) / total_pos_num);
        result.raw_h.push_back(static_cast<float>(neg_alive.size()) / total_neg_num);

        // This layer as the final layer, split into N uniform segments.
        std::vector<float> grid;
        for (int j = 1; j < N; j++) {
            grid.push_back(lo[layer] + (hi[layer] - lo[layer]) * j / N);
        }
        result.th_f_N.push_back(grid);
        result.final_raw_g_N.push_back(normalize(histogram(pos_alive, layer, grid, N), total_pos_num));
        result.final_raw_h_N.push_back(normalize(histogram(neg_alive, layer, grid, N), total_neg_num));

        // This layer as the final layer, split into k optimised segments.
        std::vector<float> pos_column;
        pos_column.reserve(pos_alive.size());
        for (const std::vector<float>& row : pos_alive) {
            pos_column.push_back(row[layer]);
        }
        std::vector<float> neg_column;
        neg_column.reserve(neg_alive.size());
        for (const std::vector<float>& row : neg_alive) {
            neg_column.push_back(row[layer]);
        }
        const std::vector<float> cuts = get_optimal_thresholds(pos_column, neg_column, lo[layer], hi[layer], k);
        result.th_f.push_back(cuts);
        result.final_raw_g.push_back(normalize(histogram(pos_alive, layer, cuts, k), total_pos_num));
        result.final_raw_h.push_back(normalize(histogram(neg_alive, layer, cuts, k), total_neg_num));

        // Rows above the branch threshold leave before the next layer.
        if (layer + 1 < D) {
            const float cut = th_b[layer];
            auto branches_out = [layer, cut](const std::vector<float>& row) { return row[layer] > cut; };
            const size_t pos_before = pos_alive.size();
            const size_t neg_before = neg_alive.size();
            pos_alive.erase(std::remove_if(pos_alive.begin(), pos_alive.end(), branches_out), pos_alive.end());
            neg_alive.erase(std::remove_if(neg_alive.begin(), neg_alive.end(), branches_out), neg_alive.end());
            result.branch_raw_g.push_back(static_cast<float>(pos_before - pos_alive.size()) / total_pos_num);
            result.branch_raw_h.push_back(static_cast<float>(neg_before - neg_alive.size()) / total_neg_num);
        }
    }

    return result;
}

std::vector<Statistics> get_statistics_th(
    const std::vector<std::vector<float>>& total_pos_stats,
    const std::vector<std::vector<float>>& total_neg_stats,
    std::vector<Th_B_Candidate> th_b_candidates, 
    int k
) {
    // Assert the monotonicity of th_b_candidates
    int D = total_pos_stats[0].size();
    int m = th_b_candidates.size();
    int total_pos_num = total_pos_stats.size();
    int total_neg_num = total_neg_stats.size();
    for (int j = 0; j < D - 1; j++) {
        std::vector<float> th_b_row;
        for (int i = 0; i < m; i++) {
            th_b_row.push_back(th_b_candidates[i].th_b[j]);
        }
        assert(std::is_sorted(th_b_row.begin(), th_b_row.end()));
    }

    auto get_dp = [&](const std::vector<std::vector<float>>& total_stats) -> std::vector<std::vector<int>> {
        int n = total_stats.size();
        std::vector<std::vector<int>> dp(n, std::vector<int>(D, m - 1));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < D; j++) {
                if (j == 0) {
                    dp[i][j] = m - 1;
                    while (dp[i][j] > 0 && total_stats[i][j] <= th_b_candidates[dp[i][j] - 1].th_b[j]) {
                        dp[i][j]--;
                    }
                } else {
                    dp[i][j] = dp[i][j - 1];
                    while (dp[i][j] < m - 1 && total_stats[i][j] > th_b_candidates[dp[i][j]].th_b[j]) {
                        dp[i][j]++;
                    }
                }
            }
        }
        return dp;
    };

    std::vector<std::vector<int>> pos_dp = get_dp(total_pos_stats);
    std::vector<std::vector<int>> neg_dp = get_dp(total_neg_stats);

    auto get_cnt = [&](const std::vector<std::vector<int>>& dp) -> std::vector<std::vector<int>> {
        std::vector<std::vector<int>> cnt(m, std::vector<int>(D, 0));
        for (int i = 0; i < (int)dp.size(); i++) {
            for (int j = 0; j < D; j++) {
                cnt[dp[i][j]][j]++;
            }
        }
        for (int i = 0; i < m - 1; i ++) {
            for (int j = 0; j < D; j++) {
                cnt[i + 1][j] += cnt[i][j];
            }
        }
        return cnt;
    };

    int N = 100;
    std::vector<float> min_stats_list(D);
    std::vector<float> max_stats_list(D);
    {
        for (int j = 0; j < D; j++) {
            min_stats_list[j] = std::numeric_limits<float>::infinity();
            max_stats_list[j] = -std::numeric_limits<float>::infinity();
        }
        for (int i = 0; i < (int)total_pos_stats.size(); i++) {
            for (int j = 0; j < D; j++) {
                min_stats_list[j] = std::min(min_stats_list[j], total_pos_stats[i][j]);
                max_stats_list[j] = std::max(max_stats_list[j], total_pos_stats[i][j]);
            }
        }
        for (int i = 0; i < (int)total_neg_stats.size(); i++) {
            for (int j = 0; j < D; j++) {
                min_stats_list[j] = std::min(min_stats_list[j], total_neg_stats[i][j]);
                max_stats_list[j] = std::max(max_stats_list[j], total_neg_stats[i][j]);
            }
        }
    }
    std::vector<std::vector<float>> th_f_N(D, std::vector<float>(N - 1));
    for (int j = 0; j < D; j++) {
        for (int l = 1; l < N; l++) {
            th_f_N[j][l - 1] = min_stats_list[j] + (max_stats_list[j] - min_stats_list[j]) * l / N;
        }
    }

    auto get_cnt_per_segment = [&](const std::vector<std::vector<float>>& total_stats, const std::vector<std::vector<int>>& dp) -> std::vector<std::vector<std::vector<int>>> {
        std::vector<std::vector<std::vector<int>>> cnt_per_segment(m, std::vector<std::vector<int>>(D, std::vector<int>(N, 0)));
        for (int i = 0; i < (int)dp.size(); i++) {
            for (int j = 0; j < D; j++) {
                int seg = std::lower_bound(th_f_N[j].begin(), th_f_N[j].end(), total_stats[i][j]) - th_f_N[j].begin();
                if (j == 0) {
                    cnt_per_segment[0][j][seg]++;
                } else {
                    cnt_per_segment[dp[i][j - 1]][j][seg]++;
                }
            }
        }
        for (int i = 0; i < m - 1; i++) {
            for (int j = 0; j < D; j++) {
                for (int l = 0; l < N; l++) {
                    cnt_per_segment[i + 1][j][l] += cnt_per_segment[i][j][l];
                }
            }
        }
        return cnt_per_segment;
    };

    std::vector<std::vector<int>> pos_cnt = get_cnt(pos_dp);
    std::vector<std::vector<int>> neg_cnt = get_cnt(neg_dp);
    std::vector<std::vector<std::vector<int>>> pos_cnt_per_segment = get_cnt_per_segment(total_pos_stats, pos_dp);
    std::vector<std::vector<std::vector<int>>> neg_cnt_per_segment = get_cnt_per_segment(total_neg_stats, neg_dp);

    std::vector<Statistics> statistics_list;
    for (int i = 0; i < m; i++) {
        std::vector<int> raw_pos_cnt(D, 0), raw_neg_cnt(D, 0);
        std::vector<int> branch_pos_cnt(D - 1, 0), branch_neg_cnt(D - 1, 0);
        std::vector<std::vector<int>> final_pos_cnt(D, std::vector<int>(N, 0)), final_neg_cnt(D, std::vector<int>(N, 0));
        for (int j = 0; j < D; j++) {
            if (j == 0) {
                raw_pos_cnt[j] = total_pos_num;
                raw_neg_cnt[j] = total_neg_num;
            } else {
                raw_pos_cnt[j] = pos_cnt[i][j - 1];
                raw_neg_cnt[j] = neg_cnt[i][j - 1];
            }
            if (j > 0) {
                branch_pos_cnt[j - 1] = raw_pos_cnt[j - 1] - raw_pos_cnt[j];
                branch_neg_cnt[j - 1] = raw_neg_cnt[j - 1] - raw_neg_cnt[j];
            }
            for (int l = 0; l < N; l++) {
                final_pos_cnt[j][l] = pos_cnt_per_segment[i][j][l];
                final_neg_cnt[j][l] = neg_cnt_per_segment[i][j][l];
            }
        }
        std::vector<float> raw_g, raw_h, branch_raw_g, branch_raw_h;
        std::vector<std::vector<float>> final_raw_g, final_raw_h, th_f;
        std::vector<std::vector<float>> final_raw_g_N, final_raw_h_N;
        for (int j = 0; j < D; j++) {
            raw_g.push_back(static_cast<float>(raw_pos_cnt[j]) / total_pos_num);
            raw_h.push_back(static_cast<float>(raw_neg_cnt[j]) / total_neg_num);
            if (j < D - 1) {
                branch_raw_g.push_back(static_cast<float>(branch_pos_cnt[j]) / total_pos_num);
                branch_raw_h.push_back(static_cast<float>(branch_neg_cnt[j]) / total_neg_num);
            }
            std::vector<float> final_raw_g_N_j(N, 0), final_raw_h_N_j(N, 0);
            for (int l = 0; l < N; l++) {
                final_raw_g_N_j[l] = static_cast<float>(final_pos_cnt[j][l]) / total_pos_num;
                final_raw_h_N_j[l] = static_cast<float>(final_neg_cnt[j][l]) / total_neg_num;
            }
            final_raw_g_N.push_back(final_raw_g_N_j);
            final_raw_h_N.push_back(final_raw_h_N_j);
            auto [final_raw_g_j, final_raw_h_j, th_f_j] = get_optimal_thresholds_N(final_raw_g_N_j, final_raw_h_N_j, th_f_N[j], k);            
            final_raw_g.push_back(final_raw_g_j);
            final_raw_h.push_back(final_raw_h_j);
            th_f.push_back(th_f_j);
        }
        Statistics st = {raw_g, raw_h, branch_raw_g, branch_raw_h, final_raw_g, final_raw_h, th_f, final_raw_g_N, final_raw_h_N, th_f_N};
        statistics_list.push_back(st);
    }

    return statistics_list;
}

BurgerConfig find_optimal_parameters(
    long int n, float F, float lambda_, float mu_,
    const std::vector<std::vector<float>>& calibration_stats_pos,
    const std::vector<std::vector<float>>& calibration_stats_neg,
    float calibration_time_pos_us, float calibration_time_neg_us,
    float xgboost_model_size_kb
) {
    std::vector<std::vector<float>> total_pos_stats = calibration_stats_pos;
    std::vector<std::vector<float>> total_neg_stats = calibration_stats_neg;
    int D = total_pos_stats[0].size();

    // Hyperparameters
    float total_xgboost_model_size_kb = xgboost_model_size_kb;
    float time_per_item_us = (calibration_time_neg_us + calibration_time_pos_us) / (calibration_stats_pos.size() + calibration_stats_neg.size());
    // std::cout << "[INFO] Total XGBoost model size: " << total_xgboost_model_size_kb << " KB" << std::endl;
    // std::cout << "[INFO] Time per item: " << time_per_item_us << " us" << std::endl;
    // std::cout << "[INFO] time_per_item_us / D: " << time_per_item_us / D << " us" << std::endl;
    std::vector<float> size_ml_bits(D, (total_xgboost_model_size_kb * 1000 * 8) / D);   // bits per layer
    std::vector<float> time_ml_us(D, time_per_item_us / D);                  // micro seconds per layer and item
    // std::cout << "[INFO] Size of ML model: " << size_ml_bits[0] << " bits" << std::endl;
    // std::cout << "[INFO] Time of ML model: " << time_ml_us[0] << " us" << std::endl;
    int k = 8;
    float p_base = 0.5;
    int P = 20;
    float c = std::log2(M_E);
    float mem_usage_bits_bf = get_mem_usage_bits_bf(n, F);
    float time_per_item_us_bf = get_time_per_item_us_bf(n, F);
    // std::cout << "[INFO] Memory usage per item (Bloom Filter): " << mem_usage_bits_bf << " bits" << std::endl;
    // std::cout << "[INFO] Time per item (Bloom Filter): " << time_per_item_us_bf << " us" << std::endl;

    // Get Statistics
    float best_objective = std::numeric_limits<int>::max();
    int best_d = 0;
    std::vector<float> best_branch_g, best_branch_h;
    std::vector<float> best_final_g, best_final_h;
    std::vector<float> best_b, best_t, best_f;

    std::vector<float> best_raw_g, best_raw_h;
    std::vector<float> best_branch_raw_g;
    std::vector<float> best_branch_raw_h;
    std::vector<float> best_th_b;
    std::vector<std::vector<float>> best_final_raw_g, best_final_raw_h, best_th_f;

    float best_alpha = -1.0;
    auto get_th_b_candidate_start = std::chrono::high_resolution_clock::now();
    std::vector<Th_B_Candidate> th_b_candidates = get_th_b_candidates(total_neg_stats);
    std::cout << "[INFO] Get th_b candidates time: " << std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - get_th_b_candidate_start).count() << " ms" << std::endl;

    auto get_statistics_th_start = std::chrono::high_resolution_clock::now();
    std::vector<Statistics> statistics_list = get_statistics_th(total_pos_stats, total_neg_stats, th_b_candidates, k);
    std::cout << "[INFO] Get statistics time: " << std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - get_statistics_th_start).count() << " ms" << std::endl;

    for (int th_b_candidate_idx = 0; th_b_candidate_idx < (int)th_b_candidates.size(); th_b_candidate_idx++) {
        Th_B_Candidate th_b_candidate = th_b_candidates[th_b_candidate_idx];
        float alpha = th_b_candidate.alpha;
        std::vector<float> th_b = th_b_candidate.th_b;

        auto raw_g = statistics_list[th_b_candidate_idx].raw_g;
        auto raw_h = statistics_list[th_b_candidate_idx].raw_h;
        auto branch_raw_g = statistics_list[th_b_candidate_idx].branch_raw_g;
        auto branch_raw_h = statistics_list[th_b_candidate_idx].branch_raw_h;
        auto final_raw_g = statistics_list[th_b_candidate_idx].final_raw_g;
        auto final_raw_h = statistics_list[th_b_candidate_idx].final_raw_h;
        auto th_f = statistics_list[th_b_candidate_idx].th_f;
        auto final_raw_g_N = statistics_list[th_b_candidate_idx].final_raw_g_N;
        auto final_raw_h_N = statistics_list[th_b_candidate_idx].final_raw_h_N;
        auto th_f_N = statistics_list[th_b_candidate_idx].th_f_N;

        // Get Statistics
        // auto start = std::chrono::high_resolution_clock::now();
        // const auto [raw_g_gt, raw_h_gt, branch_raw_g_gt, branch_raw_h_gt, final_raw_g_gt, final_raw_h_gt, th_f_gt, final_raw_g_N_gt, final_raw_h_N_gt, th_f_N_gt] = get_statistics(total_pos_stats, total_neg_stats, th_b, k);
        // get_statistics_time_ms_sum += std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - start).count();
        // assert(equal_1d_vector(raw_g, raw_g_gt));
        // assert(equal_1d_vector(raw_h, raw_h_gt));
        // assert(equal_1d_vector(branch_raw_g, branch_raw_g_gt));
        // assert(equal_1d_vector(branch_raw_h, branch_raw_h_gt));
        // assert(equal_2d_vector(final_raw_g, final_raw_g_gt));
        // assert(equal_2d_vector(final_raw_h, final_raw_h_gt));
        // assert(equal_2d_vector(th_f, th_f_gt));
        // assert(equal_2d_vector(final_raw_g_N, final_raw_g_N_gt));
        // assert(equal_2d_vector(final_raw_h_N, final_raw_h_N_gt));
        // assert(equal_2d_vector(th_f_N, th_f_N_gt));
        // std::cout << "OK" << std::endl;

        assert(D == raw_g.size());
        assert(D == raw_h.size());
        assert(D - 1 == branch_raw_g.size());
        assert(D - 1 == branch_raw_h.size());
        assert(D - 1 == th_b.size());
        assert(D == final_raw_g.size());
        assert(k == final_raw_g[0].size());
        assert(D == final_raw_h.size());
        assert(k == final_raw_h[0].size());
        assert(D == th_f.size());
        assert(k - 1 == th_f[0].size());
        assert(D == final_raw_g_N.size());
        assert(N == final_raw_g_N[0].size());
        assert(D == final_raw_h_N.size());
        assert(N == final_raw_h_N[0].size());
        assert(D == th_f_N.size());
        assert(N - 1 == th_f_N[0].size());

        // Define the tilde_f function
        auto tilde_f = [&](float g, float h, float G_f1 = 0, float H_f1 = 0) -> float {
            if (h == 0 || 1 - G_f1 <= 0) return 1.0;
            if (g == 0 || F - H_f1 <= 0) return 1e-18;
            float nume = (F - H_f1) * g;
            float deno = (1 - G_f1) * h;
            return std::min(1.0f, std::max(1e-18f, nume / deno));
        };

        // Define the objective function
        auto objective_func = [&](int d, const std::vector<float>& b, const std::vector<float>& t, const std::vector<float>& f) -> float {
            assert(d - 1 == b.size());
            assert(d == t.size());
            assert(k == f.size());

            float mem_sum_bits = 0.0f;
            {
                for (int i = 0; i < d; ++i) {
                    mem_sum_bits += size_ml_bits[i];
                    mem_sum_bits += c * n * raw_g[i] * std::log2(1 / t[i]);
                }
                for (int i = 0; i < d - 1; ++i) {
                    mem_sum_bits += c * n * branch_raw_g[i] * std::log2(1 / b[i]);
                }
                for (int i = 0; i < k; ++i) {
                    mem_sum_bits += c * n * final_raw_g[d - 1][i] * std::log2(1 / f[i]);
                }
            }

            float pos_query_time_us = 0.0f;
            float neg_query_time_us = 0.0f;
            {
                float prod_t = 1.0f;
                for (int i = 0; i < d; ++i) {
                    prod_t *= t[i];
                    pos_query_time_us += raw_g[i] * time_ml_us[i];
                    neg_query_time_us += raw_h[i] * prod_t * time_ml_us[i];
                }
            }

            return lambda_ * mem_sum_bits / mem_usage_bits_bf + (1 - lambda_) * (mu_ * pos_query_time_us + (1 - mu_) * neg_query_time_us) / time_per_item_us_bf;
        };

        // DP arrays
        std::vector<std::vector<float>> dp(D, std::vector<float>(P, std::numeric_limits<float>::infinity()));
        std::vector<std::vector<std::pair<int, int>>> pre_dp(D, std::vector<std::pair<int, int>>(P, {-1, -1}));

        // For each i and j + j_, pre-calculated the final layer sum.
        std::vector<std::vector<float>> final_layer_bits_sums(D, std::vector<float>(P, 0.0f));
        for (int i = 0; i < D; ++i) {
            for (int j_plus_j_ = 0; j_plus_j_ < P; ++j_plus_j_) {
                for (int l = 0; l < k; ++l) {
                    float fl_mem_bits = c * n * final_raw_g[i][l] * std::log2(1 / tilde_f(final_raw_g[i][l], final_raw_h[i][l] * std::pow(p_base, j_plus_j_)));
                    final_layer_bits_sums[i][j_plus_j_] += fl_mem_bits;
                }
            }
        }

        for (int i = D - 1; i >= 0; --i) {
            for (int j = 0; j < P; ++j) {
                float dp1 = std::numeric_limits<float>::infinity();
                int dp1_min_j_ = 0;

                for (int j_ = 0; j_ < P && j + j_ < P; ++j_) {
                    float tmp_final_layer_bits_sum = final_layer_bits_sums[i][j + j_];
                    float ti_mem_bits = c * n * raw_g[i] * std::log2(1 / std::pow(p_base, j_));
                    float mli_mem_bits = size_ml_bits[i];

                    float mem_sum_bits = ti_mem_bits + mli_mem_bits + tmp_final_layer_bits_sum;
                    float tmp_time_us = (mu_ * raw_g[i] + (1 - mu_) * raw_h[i] * std::pow(p_base, j + j_)) * time_ml_us[i];

                    float tmp_dp1 = lambda_ * mem_sum_bits / mem_usage_bits_bf + (1 - lambda_) * tmp_time_us / time_per_item_us_bf;

                    if (tmp_dp1 < dp1) {
                        dp1 = tmp_dp1;
                        dp1_min_j_ = j_;
                    }
                }
                if (i == D - 1) {
                    dp[i][j] = dp1;
                    pre_dp[i][j] = {1, dp1_min_j_};
                    continue;
                }
                float dp2 = std::numeric_limits<float>::infinity();
                int dp2_min_j_ = 0;

                for (int j_ = 0; j_ < P && j + j_ < P; ++j_) {
                    float ti_mem_bits = c * n * raw_g[i] * std::log2(1 / std::pow(p_base, j_));
                    float mli_mem_bits = size_ml_bits[i];
                    float bi_mem_bits = c * n * branch_raw_g[i] * std::log2(1 / tilde_f(branch_raw_g[i], branch_raw_h[i] * std::pow(p_base, j + j_)));

                    float mem_sum_bits = ti_mem_bits + mli_mem_bits + bi_mem_bits;
                    float tmp_time_us = (mu_ * raw_g[i] + (1 - mu_) * raw_h[i] * std::pow(p_base, j + j_)) * time_ml_us[i];

                    float tmp_dp2 = lambda_ * mem_sum_bits / mem_usage_bits_bf + (1 - lambda_) * tmp_time_us / time_per_item_us_bf + dp[i + 1][j + j_];

                    if (tmp_dp2 < dp2) {
                        dp2 = tmp_dp2;
                        dp2_min_j_ = j_;
                    }
                }

                if (dp1 < dp2) {
                    dp[i][j] = dp1;
                    pre_dp[i][j] = {1, dp1_min_j_};
                } else {
                    dp[i][j] = dp2;
                    pre_dp[i][j] = {2, dp2_min_j_};
                }
            }
        }

        // std::cout << "dp[0][0]: " << dp[0][0] << std::endl;

        std::vector<int> j_list;
        {
            int current_i = 0;
            int current_j = 0;
            while (true) {
                if (pre_dp[current_i][current_j].first == 1) {
                    j_list.push_back(pre_dp[current_i][current_j].second);
                    break;
                } else {
                    j_list.push_back(pre_dp[current_i][current_j].second);
                    current_j += pre_dp[current_i][current_j].second;
                    current_i += 1;
                }
            }
        }

        std::vector<float> t;
        for (int j : j_list) {
            t.push_back(std::pow(p_base, j));
        }
        int d = t.size();
        std::vector<float> th_f_ = th_f[d - 1];

        std::vector<float> branch_g = branch_raw_g;
        branch_g.resize(d - 1);

        std::vector<float> branch_h;
        for (int i = 0; i < d - 1; ++i) {
            branch_h.push_back(branch_raw_h[i] * std::pow(p_base, std::accumulate(j_list.begin(), j_list.begin() + i + 1, 0)));
        }

        std::vector<float> final_g = final_raw_g[d - 1];
        std::vector<float> final_h;
        for (int i = 0; i < k; ++i) {
            final_h.push_back(final_raw_h[d - 1][i] * std::pow(p_base, std::accumulate(j_list.begin(), j_list.end(), 0)));
        }

        std::vector<float> b(d - 1, F);
        std::vector<float> f(k, F);

        while (true) {
            int before_cnt = 0;
            float G_f1 = 0; // sum of g such that FPR = 1
            float H_f1 = 0; // sum of h such that FPR = 1
            for (int i = 0; i < d - 1; ++i) {
                if (b[i] == 1.0f) {
                    before_cnt += 1;
                    G_f1 += branch_g[i];
                    H_f1 += branch_h[i];
                }
            }
            for (int i = 0; i < k; ++i) {
                if (f[i] == 1.0f) {
                    before_cnt += 1;
                    G_f1 += final_g[i];
                    H_f1 += final_h[i];
                }
            }
            for (int i = 0; i < d - 1; ++i) {
                b[i] = tilde_f(branch_g[i], branch_h[i], G_f1, H_f1);
            }
            for (int i = 0; i < k; ++i) {
                f[i] = tilde_f(final_g[i], final_h[i], G_f1, H_f1);
            }

            int after_cnt = std::count(b.begin(), b.end(), 1.0f) + std::count(f.begin(), f.end(), 1.0f);
            if (before_cnt == after_cnt) {
                break;
            }
        }

        float objective = objective_func(d, b, t, f);
        // std::cout << "[INFO] Objective: " << objective << std::endl;

        float mem_sum = 0.0f;
        for (int i = 0; i < d; ++i) {
            mem_sum += size_ml_bits[i];
            mem_sum += c * n * raw_g[i] * std::log2(1 / t[i]);
        }
        for (int i = 0; i < d - 1; ++i) {
            mem_sum += c * n * branch_raw_g[i] * std::log2(1 / b[i]);
        }
        for (int i = 0; i < k; ++i) {
            mem_sum += c * n * final_raw_g[d - 1][i] * std::log2(1 / f[i]);
        }

        float expected_fpr = 0.0f;
        for (int i = 0; i < d - 1; ++i) {
            expected_fpr += branch_h[i] * b[i];
        }
        for (int i = 0; i < k; ++i) {
            expected_fpr += final_h[i] * f[i];
        }

        // std::cout << "[INFO] Expected FPR: " << expected_fpr << std::endl;
        // std::cout << "[INFO] Estimated Memory usage: " << mem_sum << " bits" << std::endl;

        if (expected_fpr > F * 1.01) {
            // std::cout << "[INFO] Min g/h of FPR = 1 Bloom Filters: " << std::numeric_limits<float>::infinity() << std::endl;
            // std::cout << "[INFO] Expected FPR: " << expected_fpr << std::endl;
            // std::cout << "[INFO] Estimated Memory usage: " << mem_sum / 8 / 1000 << " KB" << std::endl;
            // std::cout << "[INFO] d: " << d << std::endl;
            // std::cout << "[INFO] k: " << k << std::endl;
            // std::cout << "[INFO] n: " << n << std::endl;
            // std::cout << "[INFO] b: ";
            // for (float bi : b) {
            //     std::cout << bi << " ";
            // }
            // std::cout << std::endl;
            // std::cout << "[INFO] t: ";
            // for (float ti : t) {
            //     std::cout << ti << " ";
            // }
            // std::cout << std::endl;
            // std::cout << "[INFO] f: ";
            // for (float fi : f) {
            //     std::cout << fi << " ";
            // }
            // std::cout << std::endl;
            // std::cout << "[INFO] g_b: ";
            // for (float g_bi : branch_g) {
            //     std::cout << g_bi << " ";
            // }
            // std::cout << std::endl;
            // std::cout << "[INFO] g_t: ";
            // for (float g_ti : raw_g) {
            //     std::cout << g_ti << " ";
            // }
            // std::cout << std::endl;
            // std::cout << "[INFO] g_f: ";
            // for (float g_fi : final_raw_g[d - 1]) {
            //     std::cout << g_fi << " ";
            // }
            // std::cout << std::endl;
            throw std::runtime_error("Expected FPR does not match the target FPR.");
        }

        std::cout << "[INFO] Objective: " << objective << std::endl;

        if (objective < best_objective) {
            best_objective = objective;
            best_d = d;
            best_branch_g = branch_g;
            best_branch_h = branch_h;
            best_final_g = final_g;
            best_final_h = final_h;
            best_b = b;
            best_t = t;
            best_f = f;
            best_raw_g = raw_g;
            best_raw_h = raw_h;
            best_branch_raw_g = branch_raw_g;
            best_branch_raw_h = branch_raw_h;
            best_th_b = th_b;
            best_final_raw_g = final_raw_g;
            best_final_raw_h = final_raw_h;
            best_th_f = th_f;
            best_alpha = alpha;
        }
    }

    // The case where CLBF = a Bloom Filter
    {
        float mem_sum_bits = c * n * 1.0f * std::log2(1 / F);
        float pos_query_time_us = 0.0;
        float neg_query_time_us = 0.0;
        float objective = lambda_ * mem_sum_bits / mem_usage_bits_bf + (1 - lambda_) * (mu_ * pos_query_time_us + (1 - mu_) * neg_query_time_us) / time_per_item_us_bf;
        if (objective < best_objective) {
            best_d = 0;
        }
    }

    if (best_d == 0) {
        std::cout << "[INFO] CLBF is a Bloom Filter." << std::endl;
        int d = 0;
        int k = 0;
        std::vector<float> b = {};
        std::vector<float> t = {F};
        std::vector<float> f = {};
        std::vector<float> th_b_ = {};
        std::vector<float> th_f_ = {};
        std::vector<float> branch_g = {};
        std::vector<float> trunk_g = {};
        std::vector<float> final_g = {};
        // for memo
        std::vector<float> branch_h = {};
        std::vector<float> trunk_h = {};
        std::vector<float> final_h = {};
        float alpha = 0.0f;
        BurgerConfig optimal_parameters = {
            d, k, n, b, t, f, th_b_, th_f_, branch_g, trunk_g, final_g,
            branch_h, trunk_h, final_h, alpha
        };
        return optimal_parameters;
    }

    std::vector<float> raw_g = best_raw_g;
    std::vector<float> raw_h = best_raw_h;
    std::vector<float> branch_raw_g = best_branch_raw_g;
    std::vector<float> th_b = best_th_b;
    std::vector<std::vector<float>> final_raw_g = best_final_raw_g;
    std::vector<std::vector<float>> final_raw_h = best_final_raw_h;
    std::vector<std::vector<float>> th_f = best_th_f;
    int d = best_d;
    std::vector<float> trunk_g(raw_g.begin(), raw_g.begin() + d);
    std::vector<float> trunk_h;
    {
        float prod_t = 1.0f;
        for (int i = 0; i < d; ++i) {
            trunk_h.push_back(raw_h[i] * prod_t);
            prod_t *= best_t[i];
        }
    }
    std::vector<float> branch_g = best_branch_g;
    std::vector<float> branch_h = best_branch_h;
    std::vector<float> final_g = best_final_g;
    std::vector<float> final_h = best_final_h;
    std::vector<float> th_b_(th_b.begin(), th_b.begin() + d - 1);
    std::vector<float> th_f_ = th_f[d - 1];
    std::vector<float> b = best_b;
    std::vector<float> t = best_t;
    std::vector<float> f = best_f;

    float mem_sum = 0.0f;
    for (int i = 0; i < d; ++i) {
        mem_sum += size_ml_bits[i];
        mem_sum += c * n * trunk_g[i] * std::log2(1 / t[i]);
    }
    for (int i = 0; i < d - 1; ++i) {
        mem_sum += c * n * branch_g[i] * std::log2(1 / b[i]);
    }
    for (int i = 0; i < k; ++i) {
        mem_sum += c * n * final_g[i] * std::log2(1 / f[i]);
    }

    float expected_fpr = 0.0f;
    for (int i = 0; i < d - 1; ++i) {
        expected_fpr += branch_h[i] * b[i];
    }
    for (int i = 0; i < k; ++i) {
        expected_fpr += final_h[i] * f[i];
    }

    std::cout << "[INFO] Expected FPR: " << expected_fpr << std::endl;
    std::cout << "[INFO] Estimated Memory usage: " << mem_sum / 8 / 1000 << " KB" << std::endl;

    BurgerConfig optimal_parameters = {
        d, k, n, b, t, f, th_b_, th_f_, branch_g, trunk_g, final_g,
        branch_h, trunk_h, final_h, best_alpha
    };

    if (!std::is_sorted(th_f_.begin(), th_f_.end())) {
        throw std::runtime_error("th_f_ is not sorted.");
    }

    return optimal_parameters;
}


#endif // FIND_OPTIMAL_PARAMETERS_H
