#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

#include "emp-tool/emp-tool.h"
#include "2pc-backend/emp-dpf.h"

using namespace std;
using namespace emp;
// Number of 128-bit blocks in one retrieval-database entry:
// 768 * 768 values of 32 bits each, packed 128 bits per block (= 147456 blocks).
static constexpr int matrix_block_size = 768 * 768 * 32 / 128;
// Bit / Integer types instantiated on the GbWire wire type of the 2PC backend.
using Bit = Bit_T<GbWire>;
using Integer = Integer_T<GbWire>;

// Passed to init_backend() and forwarded to the MPCBackend constructor.
// NOTE(review): its exact meaning is defined by MPCBackend — confirm there.
const int depth = 20;

// Global backend created by init_backend(), which also installs it as
// emp::backend. Not deleted anywhere in this file as shown.
MPCBackend<NetIO> *base;

/// Create the global MPC backend for `party` and install it as emp::backend.
///
/// A fresh random delta is drawn, XORed with a block built from its own low
/// two bits, then XORed with a party-specific tag (1 for ALICE, 3 otherwise)
/// before being handed to MPCBackend.
///
/// @param ios      NetIO channels forwarded to MPCBackend.
/// @param party    ALICE or BOB; selects the low-bit tag of delta.
/// @param depth    forwarded to MPCBackend.
/// @param threads  forwarded to MPCBackend.
/// @param size_of_element, round_key_size, addr  accepted but not used here.
void init_backend(NetIO **ios, int party, int depth, int size_of_element, int threads = 2, int round_key_size = 94208, const char *addr = "127.0.0.1")
{
    block global_delta;
    PRG rng;
    rng.random_block(&global_delta, 1);

    // XOR with a block holding delta's own two low bits (as reported by
    // getSLSB/getLSB); presumably this clears them — confirm against emp-tool.
    const block own_low_bits = makeBlock(0, (getSLSB(global_delta) << 1) | getLSB(global_delta));
    global_delta ^= own_low_bits;

    // Party-specific tag in the low bits: ALICE -> 1, any other party -> 3.
    global_delta ^= makeBlock(0, party == ALICE ? 1 : 3);

    base = new MPCBackend<NetIO>(party, threads, ios, global_delta, depth);
    emp::backend = base;
}

/// Write the `size` least-significant bits of `input` into `out`, LSB first,
/// so that out[i] is bit i of `input` (two's complement).
///
/// Positions at or beyond the bit width of `int` receive the sign bit, as in
/// sign extension. The original shifted by >= width there, which is undefined
/// behaviour.
///
/// @param out    destination array with room for at least `size` elements.
/// @param input  value to decompose.
/// @param size   number of bits to write; a non-positive size writes nothing.
void bit_decomp(bool *out, int input, int size)
{
    const int width = static_cast<int>(sizeof(input) * 8);
    const unsigned int bits = static_cast<unsigned int>(input);
    for (int i = 0; i < size; ++i)
    {
        out[i] = (i < width) ? (((bits >> i) & 1u) != 0) : (input < 0);
    }
}
/// Accumulate the squared Euclidean distance between `a` and `b` into `res`:
/// res = res + sum_i (a[i] - b[i])^2.
///
/// @param a, b  vectors of equal length.
/// @param res   accumulator; the caller is responsible for initialising it.
/// @throws std::invalid_argument if `a` and `b` differ in length. The original
///         read past the end of `b` when `b` was shorter.
void l2_distance(vector<Integer> &a, vector<Integer> &b, Integer &res)
{
    if (a.size() != b.size())
    {
        throw std::invalid_argument("l2_distance: vectors must have the same length");
    }
    for (size_t i = 0; i < a.size(); ++i)
    {
        Integer diff = a[i] - b[i];
        res = res + diff * diff;
    }
}

/// Securely find the database row closest (squared L2) to `query`.
///
/// Scans the first `db_size` rows of `fuzzy_db`. A running minimum distance and
/// its index are updated obliviously with Integer::If. Both are then revealed
/// publicly and printed.
///
/// @param fuzzy_db   database rows; must hold at least `db_size` rows.
/// @param query      query vector, expected to match each row's length.
/// @param db_size    number of rows to scan.
/// @param entry_len  used only for the summary print-out.
/// @return time spent in the matching loop, in seconds.
/// @throws std::invalid_argument if db_size exceeds fuzzy_db.size().
double test_fuzzy_matching(vector<vector<Integer>> &fuzzy_db, vector<Integer> &query, int db_size = 100, int entry_len = 77 * 768)
{
    cout << "==================test_fuzzy_matching==================" << endl;

    // Guard against reading past the end of fuzzy_db.
    if (db_size > 0 && static_cast<size_t>(db_size) > fuzzy_db.size())
    {
        throw std::invalid_argument("test_fuzzy_matching: db_size exceeds fuzzy_db.size()");
    }

    auto t1 = clock_start();

    // Running minimum: start at the largest distance and an invalid index (-1).
    Integer min_dist = Integer(32, INT_MAX, PUBLIC);
    Integer min_idx = Integer(32, -1, PUBLIC);
    for (int i = 0; i < db_size; ++i)
    {
        Integer dist = Integer(32, 0, PUBLIC);
        l2_distance(fuzzy_db[i], query, dist);
        // The comparison bit is computed once and drives both oblivious
        // updates; the strict '<' keeps the earliest index on ties.
        Bit cond = dist < min_dist;
        min_idx = min_idx.If(cond, Integer(32, i));
        min_dist = min_dist.If(cond, dist);
    }
    double fuzzy_matching_time = time_from(t1) / 1e3 / 1e3;
    cout << "Fuzzy matching time: " << fuzzy_matching_time << "s" << endl;

    cout << "min dist: " << min_dist.reveal<int32_t>(PUBLIC) << endl;
    cout << "min idx: " << min_idx.reveal<int32_t>(PUBLIC) << endl;
    cout << "Database size: " << db_size << endl;
    cout << "Entry length: " << entry_len << endl;
    return fuzzy_matching_time;
}

/// Benchmark: run the fuzzy-matching circuit, then a 1-out-of-`size` retrieval.
/// Bob masks every database entry with a key-derived pad, and Alice learns the
/// pad of one entry through ceil(log2(size)) OTs.
///
/// Protocol as implemented here (L = ceil(log2(size))):
///  - Bob draws two random keys k0[j], k1[j] for every index bit j < L. Entry i
///    is XOR-masked with PRG(H(k_{bit 0 of i}[0] || ... || k_{bit L-1 of i}[L-1])).
///    The whole masked database is then sent to Alice.
///  - Alice obtains k_{bit j of idx}[j] for each j via OT, recomputes the pad
///    of entry `idx`, and unmasks only that entry.
///
/// @param fuzzy_db          database forwarded to test_fuzzy_matching().
/// @param query             query forwarded to test_fuzzy_matching().
/// @param size              number of database entries; must be positive.
/// @param fuzzy_entry_size  entry length forwarded to test_fuzzy_matching().
/// @param party             ALICE runs the receiver; any other value runs the sender.
/// @param db                sender only: at least `size` entries of
///                          matrix_block_size blocks, allocated with new[].
///                          They are masked in place, sent, then released
///                          with delete[] and reset to nullptr.
/// @param ios               I/O channels; only ios[0] is used.
/// @param threads           unused unless the commented-out FerretCOT is enabled.
/// @throws std::invalid_argument if size <= 0, if the sender's `db` holds fewer
///         than `size` entries, or if the receiver's index is out of range.
void test_ot_hash(vector<vector<Integer>> &fuzzy_db, vector<Integer> &query, int size, int fuzzy_entry_size, int party, vector<block *> &db, NetIO **ios, int threads)
{
    // Validate up front, before any communication takes place.
    if (size <= 0)
    {
        throw std::invalid_argument("test_ot_hash: size must be positive");
    }
    if (party != ALICE && db.size() < static_cast<size_t>(size))
    {
        throw std::invalid_argument("test_ot_hash: sender database has fewer than size entries");
    }
    (void)threads; // only needed by the FerretCOT alternative below

    auto pre_fuzzy_comm = ios[0]->counter;
    double fuzzy_time = test_fuzzy_matching(fuzzy_db, query, size, fuzzy_entry_size);
    double fuzzy_comm = (ios[0]->counter - pre_fuzzy_comm) / 1000.0 / 1000.0;
    cout << "fuzzy comm: " << fuzzy_comm << "MB\n";
    auto pre_comm = ios[0]->counter;
    cout << "==================test_ot_Hash==================" << endl;
    cout << "Size: " << size << endl;
    cout << "Party: " << party << endl;
    auto single_entry_size = matrix_block_size * sizeof(block) / 1000.0 / 1000.0;
    cout << "single entry size: " << single_entry_size << "MB" << endl;

    // Number of index bits = number of OTs (0 when size == 1).
    const int log_size = static_cast<int>(std::ceil(std::log2(size)));
    // Hash::hash_for_block() takes a length in BYTES. The original passed
    // log_size (a block count), which hashes only the first log_size bytes of
    // the concatenated keys (for log_size <= 16, part of the first key alone).
    // The pad of entry i then depends only on bit 0 of i, so a single OT key
    // unmasks half of the database.
    const int key_bytes = log_size * static_cast<int>(sizeof(block));

    if (party == ALICE) // Receiver
    {
        // std::unique_ptr<FerretCOT<NetIO>> ot_recv(new FerretCOT<NetIO>(BOB, threads, ios, false, true, ferret_b11));
        std::unique_ptr<IKNP<NetIO>> ot_recv(new IKNP<NetIO>(ios[0], false));
        const int idx = 9; // Index Alice wants to retrieve
        if (idx >= size)
        {
            throw std::invalid_argument("Alice Index out of bounds");
        }

        // OT choice bits: bit j of idx, least significant first.
        std::unique_ptr<bool[]> idx_bits(new bool[log_size]);
        bit_decomp(idx_bits.get(), idx, log_size);

        // Receive the masked database from Bob.
        auto t1 = clock_start();
        vector<std::unique_ptr<block[]>> enc_db(size);
        for (auto &entry : enc_db)
        {
            entry.reset(new block[matrix_block_size]); // filled by recv_block below
        }
        for (auto &entry : enc_db)
        {
            ios[0]->recv_block(entry.get(), matrix_block_size);
        }
        auto t_recv_db = time_from(t1) / 1e3 / 1e3;
        cout << "Received encrypted database: " << t_recv_db << "s" << endl;

        // Obtain one key per index bit via OT.
        t1 = clock_start();
        vector<block> received_keys(log_size);
        ot_recv->recv(received_keys.data(), idx_bits.get(), log_size);
        auto t_recv_keys = time_from(t1) / 1e3 / 1e3;
        cout << "Received keys (OT): " << t_recv_keys << "s" << endl;

        // Recompute the pad of entry idx and unmask that entry in place.
        t1 = clock_start();
        Hash hash;
        // Heap buffer: matrix_block_size blocks (~2.4 MB) is too large for the stack.
        std::unique_ptr<block[]> mask(new block[matrix_block_size]);
        block digest_key = hash.hash_for_block(received_keys.data(), key_bytes);
        PRG(&digest_key).random_block(mask.get(), matrix_block_size);
        xorBlocks_arr(enc_db[idx].get(), enc_db[idx].get(), mask.get(), matrix_block_size);
        auto t_decrypt = time_from(t1) / 1e3 / 1e3;
        cout << "Decryption time: " << t_decrypt << "s" << endl;
        cout << "Decrypted entry at " << idx << ": " << enc_db[idx][0] << endl;

        // Cleanup: release the received database and the OT instance early.
        enc_db.clear();
        ot_recv.reset();
        auto t_recv = t_recv_db + t_recv_keys + t_decrypt;
        double ot_comm = 0, db_comm = 0;
        ios[0]->recv_data(&ot_comm, sizeof(double));
        ios[0]->recv_data(&db_comm, sizeof(double));
        cout << "Receiver time: " << t_recv << "s" << endl;
        cout << "Receiver OT comm:" << ot_comm << "MB\n";
        cout << "Receiver db comm:" << db_comm << "MB\n";
        cout << "size, entry_size, total_time, recv_db_time, recv_keys_time, decrypt_time, ot_comm, db_comm, fuzzy_comm, fuzzy_time\n";
        cout << size << "," << single_entry_size << "MB," << t_recv << ","
             << t_recv_db << "," << t_recv_keys << "," << t_decrypt << "," << ot_comm << "MB," << db_comm << "MB," << fuzzy_comm << "," << fuzzy_time << "\n";
    }
    else // BOB - Sender
    {
        auto pre_ot_comm = ios[0]->counter;
        // std::unique_ptr<FerretCOT<NetIO>> ot_send(new FerretCOT<NetIO>(ALICE, threads, ios, false, true, ferret_b11));
        std::unique_ptr<IKNP<NetIO>> ot_send(new IKNP<NetIO>(ios[0], false));

        double ot_comm = (ios[0]->counter - pre_ot_comm) / 1000.0 / 1000.0;

        // Two random keys per index bit: prg_keys0[j] for bit value 0, prg_keys1[j] for 1.
        vector<block> prg_keys0(log_size);
        vector<block> prg_keys1(log_size);
        PRG().random_block(prg_keys0.data(), prg_keys0.size());
        PRG().random_block(prg_keys1.data(), prg_keys1.size());

        auto t1 = clock_start();

        // Mask every entry in place with the pad derived from its index bits.
        // One scratch buffer is reused instead of allocating
        // size * matrix_block_size mask blocks at once.
        Hash hash;
        std::unique_ptr<block[]> mask(new block[matrix_block_size]);
        vector<block> concate_key(log_size);
        for (int i = 0; i < size; ++i)
        {
            for (int j = 0; j < log_size; ++j)
            {
                concate_key[j] = ((i >> j) & 1) ? prg_keys1[j] : prg_keys0[j];
            }
            block digest_key = hash.hash_for_block(concate_key.data(), key_bytes);
            PRG(&digest_key).random_block(mask.get(), matrix_block_size);
            xorBlocks_arr(db[i], db[i], mask.get(), matrix_block_size);
        }
        mask.reset();
        auto t_encrypt = time_from(t1) / 1e3 / 1e3;
        cout << "Database encryption time: " << t_encrypt << "s" << endl;

        t1 = clock_start();
        auto pre_db_comm = ios[0]->counter;
        for (int i = 0; i < size; i++)
        {
            // Send encrypted entry to Alice
            ios[0]->send_block(db[i], matrix_block_size);
        }
        ios[0]->flush();
        double db_comm = (ios[0]->counter - pre_db_comm) / 1000.0 / 1000.0;
        auto t_send_db = time_from(t1) / 1e3 / 1e3;
        cout << "sending database: " << t_send_db << "s" << endl;

        // Send keys via OT.
        t1 = clock_start();
        pre_ot_comm = ios[0]->counter;
        ot_send->send(prg_keys0.data(), prg_keys1.data(), log_size);
        ot_comm += (ios[0]->counter - pre_ot_comm) / 1000.0 / 1000.0;

        auto t_send_keys = time_from(t1) / 1e3 / 1e3;
        cout << "sending keys (OT): " << t_send_keys << "s" << endl;
        auto t_send = t_send_db + t_send_keys + t_encrypt;
        cout << "Sender time: " << t_send << "s" << endl;
        double send_comm = (ios[0]->counter - pre_comm) / 1000.0 / 1000.0;
        cout << "Total comm:" << send_comm << "MB\n";
        cout << "OT comm: " << ot_comm << "MB\n";
        cout << "db comm: " << db_comm << "MB\n";
        // Report the sender-side communication to Alice so she can print it.
        ios[0]->send_data(&ot_comm, sizeof(double));
        ios[0]->send_data(&db_comm, sizeof(double));

        // Cleanup: the caller allocated the entries with new[].
        for (int i = 0; i < size; ++i)
        {
            delete[] db[i];
            db[i] = nullptr; // do not leave dangling pointers in the caller's vector
        }
        ot_send.reset();
        cout << "size, entry_size, total_time, encrypt_time, send_db_time, send_keys_time, ot_comm, db_comm\n";
        cout << size << "," << single_entry_size << " MB," << t_send << ","
             << t_encrypt << "," << t_send_db << "," << t_send_keys << "," << ot_comm << "MB," << db_comm << "MB\n";
    }
}

int main(int argc, char **argv)
{
    int party = atoi(argv[1]);
    int size = 100;
    int entry_len = 768 * 32;
    if (argc < 2)
    {
        cout << "Usage: " << argv[0] << " <party> <size>" << endl;
        return 1;
    }
    if (argc < 3)
    {
        cout << "Using default size: " << size << endl;
        cout << "Using default entry length: " << entry_len << endl;
    }
    else if (argc == 3)
    {
        size = atoi(argv[2]);
    }
    else if (argc == 4)
    {
        size = atoi(argv[2]);
        entry_len = atoi(argv[3]);
    }

    int threads = 2;
    vector<block *> db;

    if (party == BOB)
    {
        db.resize(size);
        // Initialize each entry with a distinct value
        for (int i = 0; i < size; ++i)
        {
            db[i] = new block[matrix_block_size];
            for (int j = 0; j < matrix_block_size; ++j)
            {
                db[i][j] = makeBlock(0, i); // Value i at position i
            }
        }
    }
    NetIO **ios = new NetIO *[threads];
    for (int i = 0; i < max(1, threads); i++)
    {
        ios[i] = new NetIO((party - 1) ? "127.0.0.1" : nullptr, 12345 + i, true);
    }
    init_backend(ios, party, depth, 32, threads);

    int query_idx = 5;
    vector<Integer> query(entry_len);
    for (int i = 0; i < entry_len; ++i)
    {
        query[i] = Integer(32, query_idx, PUBLIC);
    }
    base->switch_to_gt();
    vector<vector<Integer>> fuzzy_db(size);
    for (int i = 0; i < size; ++i)
    {
        fuzzy_db[i] = vector<Integer>(entry_len);
        for (int j = 0; j < entry_len; ++j)
        {
            fuzzy_db[i][j] = Integer(32, i, PUBLIC);
        }
    }
    test_ot_hash(fuzzy_db, query, size, entry_len, party, db, ios, threads);


    for (int i = 0; i < threads; i++)
    {
        delete ios[i];
    }
    return 0;
}
