#include "../../src/utils/utils.h"
#include "habf/habf.h"

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <unistd.h>

void construct_and_test_habf(
    const std::vector<std::string>& all_pos_key_str,
    const std::vector<std::string>& X_train_key,
    const std::vector<std::vector<float>>& y_train,
    const std::vector<std::string>& X_val_key,
    const std::vector<std::vector<float>>& y_val,
    float bits_per_key,
    const std::vector<std::string>& X_test_key,
    const std::vector<std::vector<float>>& y_test,
    std::string habf_folder_path,
    std::string result_folder_path,
    float pos_query_ratio,
    int query_num
) {
    std::string result_json_path = result_folder_path + "/result.json";

    ////////////////////
    // Construct HABF //
    ////////////////////

    std::string habf_model_path = habf_folder_path + "/habf_model.bin";
    std::string habf_info_path = habf_folder_path + "/info.json";

    size_t pos_count = all_pos_key_str.size();
    habf::HABFilter habf(bits_per_key, pos_count);
    float add_keys_time_ms = 0;
    float model_construct_time_ms = 0;
    float fpr_on_val = 0.0;
    {
        auto construction_start = std::chrono::high_resolution_clock::now();

        // prepare positive and negative samples
        auto prepare_start = std::chrono::high_resolution_clock::now();
        std::vector<Slice *> pos_keys_;
        std::vector<Slice *> neg_keys_;
        for (int i = 0; i < (int)all_pos_key_str.size(); i++) {
            Slice *s = new Slice();
            s->str = all_pos_key_str[i];
            s->cost = 1.0;
            pos_keys_.push_back(s);
        }
        for (int i = 0; i < (int)X_train_key.size(); i++) {
            if (y_train[i][0] != 1) {
                Slice *s = new Slice();
                s->str = X_train_key[i];
                s->cost = 1.0;
                neg_keys_.push_back(s);
            }
        }
        for (int i = 0; i < (int)X_val_key.size(); i++) {
            if (y_val[i][0] != 1) {
                Slice *s = new Slice();
                s->str = X_val_key[i];
                s->cost = 1.0;
                neg_keys_.push_back(s);
            }
        }
        auto prepare_end = std::chrono::high_resolution_clock::now();

        // add keys
        auto add_keys_start = std::chrono::high_resolution_clock::now();
        habf.AddAndOptimize(pos_keys_, neg_keys_);
        auto add_keys_end = std::chrono::high_resolution_clock::now();
        auto construction_end = std::chrono::high_resolution_clock::now();

        model_construct_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(construction_end - construction_start).count();
        add_keys_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(add_keys_end - add_keys_start).count();

        // Check FNR == 0
        for (size_t i = 0; i < pos_keys_.size(); i++) {
            assert(habf.Contain(*pos_keys_[i]));
        }
        // Check FPR for negative samples
        {
            double fpr_c = 0;
            double total_ = 0.0;
            for (size_t i = 0; i < neg_keys_.size(); i++) {
                if (habf.Contain(*neg_keys_[i])) {
                    fpr_c += 1;
                }
                total_ += 1;
            }
            fpr_on_val = fpr_c / total_;
            std::cerr << "FPR on validation data: " << fpr_on_val << "\n";
        }
    }

    float habf_size_bits = (int)all_pos_key_str.size() * bits_per_key;
    float habf_size_kb = habf_size_bits / 8 / 1024;

    // save HABF model
    // habf.Save(habf_model_path);

    // save info.json
    {
        std::ofstream ofs(habf_info_path);
        ofs << "{\n";
        ofs << "  \"model_type\": \"habf\",\n";
        ofs << "  \"bits_per_key\": " << bits_per_key << ",\n";
        ofs << "  \"model_size_kb\": " << habf_size_kb << ",\n";
        ofs << "  \"add_keys_time_ms\": " << add_keys_time_ms << ",\n";
        ofs << "  \"model_construct_time_ms\": " << model_construct_time_ms << "\n";
        ofs << "}\n";
        ofs.close();
    }

    ///////////////
    // Test HABF //
    ///////////////
    // prepare test data
    std::vector<std::vector<float>> sampled_y_test;
    std::vector<Slice *> sampled_X_test_key;
    {
        int pos_query_num = pos_query_ratio * query_num;
        int neg_query_num = query_num - pos_query_num;

        std::vector<Slice *> X_test_key_pos;
        std::vector<Slice *> X_test_key_neg;

        for (int i = 0; i < (int)y_test.size(); i++) {
            if (y_test[i][0] == 1) {
                Slice *s = new Slice();
                s->str = X_test_key[i];
                s->cost = 1.0;
                X_test_key_pos.push_back(s);
            } else {
                Slice *s = new Slice();
                s->str = X_test_key[i];
                s->cost = 1.0;
                X_test_key_neg.push_back(s);
            }
        }

        if (pos_query_num > (int)X_test_key_pos.size() || neg_query_num > (int)X_test_key_neg.size()) {
            std::cerr << "Error: Not enough positive or negative queries in the test data.\n";
            return;
        }

        // randomly sample positive and negative queries (duplicate not allowed)
        std::mt19937 g(42);
        std::vector<int> pos_indices(X_test_key_pos.size());
        std::iota(pos_indices.begin(), pos_indices.end(), 0);
        std::shuffle(pos_indices.begin(), pos_indices.end(), g);
        std::vector<int> neg_indices(X_test_key_neg.size());
        std::iota(neg_indices.begin(), neg_indices.end(), 0);
        std::shuffle(neg_indices.begin(), neg_indices.end(), g);
        for (int i = 0; i < pos_query_num; i++) {
            sampled_y_test.push_back({1});
            sampled_X_test_key.push_back(X_test_key_pos[pos_indices[i]]);
        }
        for (int i = 0; i < neg_query_num; i++) {
            sampled_y_test.push_back({0});
            sampled_X_test_key.push_back(X_test_key_neg[neg_indices[i]]);
        }

        // shuffle the sampled data
        std::vector<int> indices(sampled_X_test_key.size());
        std::iota(indices.begin(), indices.end(), 0);
        std::shuffle(indices.begin(), indices.end(), g);
        std::vector<std::vector<float>> shuffled_sampled_y_test(sampled_y_test.size());
        std::vector<Slice *> shuffled_sampled_X_test_key(sampled_X_test_key.size());

        for (int i = 0; i < (int)indices.size(); i++) {
            shuffled_sampled_y_test[i] = sampled_y_test[indices[i]];
            shuffled_sampled_X_test_key[i] = sampled_X_test_key[indices[i]];
        }

        sampled_y_test = shuffled_sampled_y_test;
        sampled_X_test_key = shuffled_sampled_X_test_key;
    }

    // warm-up
    const int warmup_num = 10;
    const int test_num = 10;
    for (int i = 0; i < warmup_num; i++) {
        for (size_t i = 0; i < sampled_X_test_key.size(); i++) {
            bool res = habf.Contain(*sampled_X_test_key[i]);
        }
    }

    // test
    int query_pos_num = 0;
    int query_neg_num = 0;
    int false_positives = 0;
    int false_negatives = 0;
    float fpr = 1.0;
    float fnr = 1.0;
    float total_test_time = 0.0;
    std::vector<float> test_time_list;
    float test_time = 0.0;
    for (int i = 0; i < test_num; i++) {
        auto start = std::chrono::high_resolution_clock::now();
        std::vector<bool> predictions(sampled_X_test_key.size());
        for (size_t i = 0; i < sampled_X_test_key.size(); i++) {
            predictions[i] = habf.Contain(*sampled_X_test_key[i]);
        }
        auto end = std::chrono::high_resolution_clock::now();

        fpr = calculate_false_positive_rate(predictions, sampled_y_test);
        fnr = calculate_false_negative_rate(predictions, sampled_y_test);
        query_pos_num = std::count(sampled_y_test.begin(), sampled_y_test.end(), std::vector<float>({1}));
        query_neg_num = std::count(sampled_y_test.begin(), sampled_y_test.end(), std::vector<float>({0}));
        false_positives = (int)(fpr * query_neg_num + 0.5);
        false_negatives = (int)(fnr * query_pos_num + 0.5);
        total_test_time += (float)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000;
        test_time_list.push_back((float)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000);
    }
    test_time = total_test_time / test_num;
    std::cout << "Total test time: " << total_test_time << " ms\n";
    std::cout << "Average test time: " << test_time << " ms\n";
    std::cout << "Test time per key: " << test_time / query_num * 1000000 << " ns\n";

    // save to result.json
    {
        auto vector_to_json_float = [](const std::vector<float>& v) {
            std::ostringstream s;
            s << "[";
            for (int i = 0; i < (int)v.size(); i++) {
                if (std::isnan(v[i])) {
                    s << "null";
                } else if (std::isinf(v[i])) {
                    s << (v[i] > 0 ? "Infinity" : "-Infinity");
                } else {
                    s << std::setprecision(20) << v[i];
                }
                if (i < (int)v.size() - 1) {
                    s << ", ";
                }
            }
            s << "]";
            return s.str();
        };

        std::ofstream ofs(result_json_path);
        ofs << "{\n";
        ofs << "  \"bf_model_folder_path\": \"" << habf_folder_path << "\",\n";
        ofs << "  \"pos_query_ratio\": " << pos_query_ratio << ",\n";
        ofs << "  \"query_num\": " << query_num << ",\n";
        ofs << "  \"fpr\": " << fpr << ",\n";
        ofs << "  \"fpr_on_val\": " << fpr_on_val << ",\n";
        ofs << "  \"fnr\": " << fnr << ",\n";
        ofs << "  \"query_pos_num\": " << query_pos_num << ",\n";
        ofs << "  \"query_neg_num\": " << query_neg_num << ",\n";
        ofs << "  \"false_positives\": " << false_positives << ",\n";
        ofs << "  \"false_negatives\": " << false_negatives << ",\n";
        ofs << "  \"test_time_ms\": " << test_time << ",\n";
        ofs << "  \"test_time_list_ms\": " << vector_to_json_float(test_time_list) << "\n";
        ofs << "}\n";
        ofs.close();
    }
}

/**
 * Command-line entry point: parses options, skips the run when result.json already
 * exists, loads all datasets, and builds/tests the HABF.
 *
 * Options (all required):
 *   -p positive keys       -k training keys        -y training labels
 *   -v validation keys     -w validation labels    -b bits per key (>= 0)
 *   -t test keys           -u test labels          -o result folder
 *   -m model folder        -r positive query ratio in [0, 1]
 *   -q number of queries (>= 0)
 *
 * @return 1 on missing or malformed arguments, otherwise 0. It also returns 0 when
 *         construct_and_test_habf aborts early; its errors go to stderr.
 */
int main(int argc, char* argv[]) {
    // for constructing habf
    std::string all_pos_key_path;
    std::string X_train_key_path;
    std::string y_train_path;
    std::string X_val_key_path;
    std::string y_val_path;
    float bits_per_key = -1.0;
    // for test habf
    std::string X_test_key_path;
    std::string y_test_path;
    std::string model_folder_path;
    std::string result_folder_path;
    float pos_query_ratio = -1.0;
    int query_num = -1;

    const auto print_usage = [&]() {
        std::cerr << "Usage: " << argv[0]
                  << " -p <all_pos_key_path> -k <X_train_key_path> -y <y_train_path>"
                     " -v <X_val_key_path> -w <y_val_path> -b <bits_per_key>"
                     " -t <X_test_key_path> -u <y_test_path> -o <result_folder_path>"
                     " -m <model_folder_path> -r <pos_query_ratio> -q <query_num>\n";
    };

    int opt;
    while ((opt = getopt(argc, argv, "p:k:y:v:w:b:t:u:o:m:r:q:")) != -1) {
        // std::stof / std::stoi throw on malformed or out-of-range numbers.
        try {
            switch (opt) {
                case 'p':
                    all_pos_key_path = optarg;
                    break;
                case 'k':
                    X_train_key_path = optarg;
                    break;
                case 'y':
                    y_train_path = optarg;
                    break;
                case 'v':
                    X_val_key_path = optarg;
                    break;
                case 'w':
                    y_val_path = optarg;
                    break;
                case 'b':
                    bits_per_key = std::stof(optarg);
                    break;
                case 't':
                    X_test_key_path = optarg;
                    break;
                case 'u':
                    y_test_path = optarg;
                    break;
                case 'o':
                    result_folder_path = optarg;
                    break;
                case 'm':
                    model_folder_path = optarg;
                    break;
                case 'r':
                    pos_query_ratio = std::stof(optarg);
                    break;
                case 'q':
                    query_num = std::stoi(optarg);
                    break;
                default:
                    print_usage();
                    return 1;
            }
        } catch (const std::exception&) {
            std::cerr << "Error: Invalid numeric value '" << optarg << "' for option -"
                      << static_cast<char>(opt) << "\n";
            print_usage();
            return 1;
        }
    }

    // Every path is loaded below, so every path is required.
    const bool missing_path =
        all_pos_key_path.empty() || X_train_key_path.empty() || y_train_path.empty() ||
        X_val_key_path.empty() || y_val_path.empty() || X_test_key_path.empty() ||
        y_test_path.empty() || model_folder_path.empty() || result_folder_path.empty();
    // A ratio above 1 would request a negative number of negative queries.
    const bool bad_number =
        bits_per_key < 0.0 || pos_query_ratio < 0.0 || pos_query_ratio > 1.0 || query_num < 0;
    if (missing_path || bad_number) {
        print_usage();
        return 1;
    }

    // If result_json_path already exists, skip testing the HABF
    const std::string result_json_path = result_folder_path + "/result.json";
    if (std::filesystem::exists(result_json_path)) {
        std::cout << "Result JSON already exists at " << result_json_path << ". Skipping testing the HABF.\n";
        return 0;
    }

    const std::vector<std::string> all_pos_key = load_key(all_pos_key_path);
    const std::vector<std::string> X_train_key = load_key(X_train_key_path);
    const std::vector<std::vector<float>> y_train = load_data(y_train_path);
    const std::vector<std::string> X_val_key = load_key(X_val_key_path);
    const std::vector<std::vector<float>> y_val = load_data(y_val_path);
    const std::vector<std::string> X_test_key = load_key(X_test_key_path);
    const std::vector<std::vector<float>> y_test = load_data(y_test_path);

    construct_and_test_habf(all_pos_key, X_train_key, y_train, X_val_key, y_val, bits_per_key,
                            X_test_key, y_test, model_folder_path, result_folder_path,
                            pos_query_ratio, query_num);

    return 0;
}
