#include <iostream>
#include <memory>
#include <string>

#include <emp-zk/emp-zk.h>
#include <emp-tool/emp-tool.h>

#include "source/sample.cpp"
#include "source/utils.cpp"

using namespace emp;
using namespace std;

int port;   // base network port, set by main() from the command line
int party;  // this process's role (ALICE or BOB), set by main()
constexpr int threads = 12;  // number of BoolIO channels main() opens


// Parses the test node-name preset and prints the node names as "[ a  b ... ]"
// followed by the tree depth that parse_node_file() reports via its out-param.
void test_parse_node_file() {
    string fn = "noise_presets/test_node_names.txt";
    // initialized so a parser that leaves it untouched prints 0, not garbage
    size_t tree_depth = 0;
    vector<int> xs = parse_node_file(fn, tree_depth);
    cout << "[";
    // range-for avoids the int-vs-size_t comparison of the old index loop
    for (int x : xs) {
        cout << " " << x << " ";
    }
    cout << "]   depth: " << tree_depth << endl;
}

// Parses the test leaf preset (node-name file + leaf-element file) and prints
// the resulting name -> value map as "{ (k, v) (k, v) ...  }".
void test_parse_leaf_files() {
    string names_path = "noise_presets/test_node_names.txt";
    string elems_path = "noise_presets/test_leaf_elements.txt";
    unordered_map<int, int> leaf_map = parse_leaf_files(names_path, elems_path);

    cout << "{ ";
    for (auto it = leaf_map.begin(); it != leaf_map.end(); ++it)
        cout << "(" << it->first << ", " << it->second << ") ";
    cout << " }" << endl;
}


// Smoke test of the boolean ZK backend: ALICE inputs two 32-bit integers
// (3 and 2), and their difference is revealed publicly and printed.
void test_circuit_zk(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    Integer lhs(32, 3, ALICE);
    Integer rhs(32, 2, ALICE);
    Integer diff = lhs - rhs;
    cout << diff.reveal<uint32_t>(PUBLIC) << endl;

    if (finalize_zk_bool<BoolIO<NetIO>>())
        error("cheat!\n");
}

// Two-party coin toss inside the ZK session.
// ALICE first commits to the low bit of a fresh random block as an ALICE
// input. BOB then samples a random block and sends it in the clear over
// ios[0]. The coin is ALICE's committed bit XOR the low bit of BOB's block,
// and it is revealed and printed.
void test_random_challenge(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    BoolIO<NetIO> *chan = ios[0];

    // ALICE's share, committed before BOB's share is seen
    PRG local_prg;
    block alice_block;
    local_prg.random_block(&alice_block, 1);
    bool alice_bits[128];
    block_to_bool(alice_bits, alice_block);
    Bit alice_coin(alice_bits[0], ALICE);

    // BOB's share, sent in the clear
    bool bob_bits[128];
    block bob_block;
    if (party != ALICE) {
        PRG bob_prg;
        bob_prg.random_block(&bob_block, 1);
        chan->send_block(&bob_block, 1);
        chan->flush();
    } else {
        chan->recv_block(&bob_block, 1);
    }
    block_to_bool(bob_bits, bob_block);

    Bit bob_coin(bob_bits[0], PUBLIC);
    Bit coin = alice_coin ^ bob_coin;
    cout << "random coin: " << coin.reveal() << endl;

    if (finalize_zk_bool<BoolIO<NetIO>>())
        error("cheat!\n");
}


// Runs interactive_seed_generation() on fixed 128-bit local inputs and prints
// the resulting Bit vector.
// The first local input has every 4th bit set; the second has all bits set.
// The output vector starts as 128 PUBLIC zero bits (presumably overwritten by
// interactive_seed_generation — confirm).
void test_seed_gen(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    BoolIO<NetIO> *chan = ios[0];

    const int n_bits = 128;
    bool alice_bits[n_bits];
    bool bob_bits[n_bits];
    vector<Bit> seed;
    for (int k = 0; k < n_bits; ++k) {
        alice_bits[k] = (k % 4 == 0);
        bob_bits[k] = true;
        seed.push_back(Bit(0, PUBLIC));
    }

    interactive_seed_generation(chan, party, alice_bits, bob_bits, n_bits, seed);
    print_Bit_vec(seed);

    if (finalize_zk_bool<BoolIO<NetIO>>())
        error("cheat!\n");
}




// ##############################################################################
// ##################### tree sampler unit tests ################################
// ##############################################################################



// --------------------- test 01, full binary tree 3 layers ---------------------

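// ex_node_names_01() and ex_leaf_vals_01() now live in source/utils.cpp; the old copies are kept below for reference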
/*
vector<int> ex_node_names_01() {
    vector<int> nn;
    for (int i=1; i<=7; ++i) {
        nn.push_back(i);
    }
    return nn;
}

unordered_map<int, int> ex_leaf_vals_01() {
    unordered_map<int, int> lv = {
        {4, 2*4},
        {5, 2*5},
        {6, 2*6},
        {7, 2*7}
    };
    return lv;
}
*/

// full binary tree with 7 nodes
unordered_map<int, int> ex_map_01() {
    unordered_map<int, int> m = {
        {1, 1},
        {2, 2},
        {3, 3},
        {4, 4},
        {5, 5},
        {6, 6},
        {7, 7}
    };
    return m;
}

// select sequence for the tree from ex_map_01()
// internal nodes in topological order from root, reversed
// --> 3, 2, 1
vector<int> ex_select_sequence_01() {
    vector<int> select_sequence;
    select_sequence.push_back(3);
    select_sequence.push_back(2);
    select_sequence.push_back(1);
    return select_sequence;
}

vector<Integer> ex_val_01() {
    vector<Integer> val;
    // initialize all values in val to -1
    for (int i=0; i<8; ++i) {
        val.push_back(Integer(32, -1, ALICE));
    }
    // set leaf values to 2*4 through 2*7
    for (int i=4; i<=7; ++i) {
        val[i] = Integer(32, 2*i, ALICE);
    }
    return val;
}

// 2 bits -- 01
vector<Bit> ex_r_01() {
    vector<Bit> r;
    r.push_back(Bit(0, ALICE));
    r.push_back(Bit(1, ALICE));
    return r;
}

vector<Bit> two_bit_vec(bool a, bool b) {
    vector<Bit> r;
    r.push_back(Bit(a, ALICE));
    r.push_back(Bit(b, ALICE));
    return r;
}

void ts_init_unit_01(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    vector<int> node_names = ex_node_names_01();
    unordered_map<int, int> leaf_vals = ex_leaf_vals_01();

    vector<Integer> val;
    vector<int> select_sequence;
    unordered_map<int, int> m;
    ts_init(node_names, leaf_vals, val, select_sequence, m);

    vector<Integer> check_val = ex_val_01();
    cout << "vals equal?: " << is_equal(val, check_val) << endl;

    vector<int> check_ss = ex_select_sequence_01();
    cout << "select_sequences equal?: " << equal(select_sequence.begin(), select_sequence.begin() + select_sequence.size(), check_ss.begin()) << endl;

    unordered_map<int, int> check_m = ex_map_01();
    cout << "ms equal?: " << (m == check_m) << endl;

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat)error("cheat!\n");
}

void tree_sample_unit_01(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    vector<Integer> val = ex_val_01();
    vector<Bit> r = ex_r_01();
    vector<int> select_sequence = ex_select_sequence_01();
    size_t num_selects = 3;
    unordered_map<int, int> m = ex_map_01();

    r = two_bit_vec(0,0);
    Integer out = tree_sample(val, r, select_sequence, num_selects, m);
    cout << "tree_sample_unit_01 result (should be 8): " << out.reveal<int>() << endl;

    r = two_bit_vec(0,1);
    out = tree_sample(val, r, select_sequence, num_selects, m);
    cout << "tree_sample_unit_01 result (should be 10): " << out.reveal<int>() << endl;

    r = two_bit_vec(1,0);
    out = tree_sample(val, r, select_sequence, num_selects, m);
    cout << "tree_sample_unit_01 result (should be 12): " << out.reveal<int>() << endl;

    r = two_bit_vec(1,1);
    out = tree_sample(val, r, select_sequence, num_selects, m);
    cout << "tree_sample_unit_01 result (should be 14): " << out.reveal<int>() << endl;


    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat)error("cheat!\n");
}





// --------------------- test 02, 3 layers w/ nodes 4 & 5 missing ---------------------

vector<int> ex_node_names_02() {
    vector<int> nn = {1, 2, 3, 6, 7};
    return nn;
}

unordered_map<int, int> ex_leaf_vals_02() {
    unordered_map<int, int> lv = {
        {2, 2*2},
        {6, 2*6},
        {7, 2*7}
    };
    return lv;
}

// full binary tree with 7 nodes
unordered_map<int, int> ex_map_02() {
    unordered_map<int, int> m = {
        {1, 1},
        {2, 2},
        {3, 3},
        {6, 4},
        {7, 5}
    };
    return m;
}

// select sequence for the tree from ex_map_01()
// internal nodes in topological order from root, reversed
// --> 3, 1
vector<int> ex_select_sequence_02() {
    vector<int> select_sequence;
    select_sequence.push_back(3);
    select_sequence.push_back(1);
    return select_sequence;
}

vector<Integer> ex_val_02() {
    vector<Integer> val;
    // initialize all values in val to -1
    for (int i=0; i<6; ++i) {
        val.push_back(Integer(32, -1, ALICE));
    }
    // set leaf values
    val[2] = Integer(32, 2*2, ALICE);
    val[6] = Integer(32, 6*2, ALICE);
    val[7] = Integer(32, 7*2, ALICE);
    return val;
}

// 2 bits -- 01
vector<Bit> ex_r_02() {
    vector<Bit> r;
    r.push_back(Bit(0, ALICE));
    r.push_back(Bit(1, ALICE));
    return r;
}

void ts_init_unit_02(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    vector<int> node_names = ex_node_names_02();
    unordered_map<int, int> leaf_vals = ex_leaf_vals_02();

    vector<Integer> val;
    vector<int> select_sequence;
    unordered_map<int, int> m;

    ts_init(node_names, leaf_vals, val, select_sequence, m);

    
    vector<Integer> check_val = ex_val_02();
    cout << "vals equal?: " << is_equal(val, check_val) << endl;

    vector<int> check_ss = ex_select_sequence_02();
    cout << "select_sequences equal?: " << equal(select_sequence.begin(), select_sequence.begin() + select_sequence.size(), check_ss.begin()) << endl;

    unordered_map<int, int> check_m = ex_map_02();
    cout << "ms equal?: " << (m == check_m) << endl;
    

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat)error("cheat!\n");
}


void toy_parse_and_sample(BoolIO<NetIO> *ios[threads], int party, size_t num_noises) {
    cout << "a\n";
    
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    BoolIO<NetIO> *rand_io = ios[0];
    // initialize tree
    size_t tree_depth;
    cout << "b\n";
    vector<int> node_names = parse_node_file("noise_presets/test_node_names.txt", tree_depth);
    print_vec(node_names);
    unordered_map<int, int> leaf_vals = parse_leaf_files("noise_presets/test_leaf_names.txt", "noise_presets/test_leaf_elements.txt");
    vector<Integer> val;
    vector<int> select_sequence;
    unordered_map<int, int> m;
    cout << "c\n";
    ts_init(node_names, leaf_vals, val, select_sequence, m);
    cout << "d 1\n";
    // initialize randomness
    vector<Bit> r;
    for (int i=0; i<tree_depth; ++i) {
        r.push_back(Bit(0, PUBLIC));
    }
    vector<Integer> fp_noise_vec;
    for (int i=0; i<num_noises; ++i) {
        fp_noise_vec.push_back(Integer(32, 0, PUBLIC));
    }

    bool alice_local_rand[tree_depth];
    bool bob_local_rand[tree_depth];
    auto start = clock_start();
    cout << "d\n";
    // sampling
    for (int i=0; i<num_noises; ++i) {
        cout << i << "/" << num_noises << endl;
        random_bits(alice_local_rand, tree_depth);
        random_bits(bob_local_rand, tree_depth);
        interactive_seed_generation(rand_io, party, alice_local_rand, bob_local_rand, tree_depth, r);
        fp_noise_vec[i] = tree_sample(val, r, select_sequence, select_sequence.size(), m);
    }

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat)error("cheat!\n");

    int t = time_from(start);
    cout << "bench_sigma_1 -- num_noises: " << num_noises << "\t total time (microsec): " << t << "\t time per noise: " << t/(num_noises*1.0) << endl;
}


int main(int argc, char **argv) {
    parse_party_and_port(argv, &party, &port);
    BoolIO<NetIO> *ios[threads];
    for (int i = 0; i < threads; ++i)
        ios[i] = new BoolIO<NetIO>(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + i), party == ALICE);
    //NetIO *randomness_io = new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + 1000);
    //test_circuit_zk(ios, party);
    //ts_init_unit_01(ios, party);
    //ts_init_unit_02(ios, party);
    //tree_sample_unit_01(ios, party);
    //test_seed_gen(ios, party);
    //test_parse_node_file();
    //test_parse_leaf_files();
    toy_parse_and_sample(ios, party, 3);

    for (int i = 0; i < threads; ++i) {
        delete ios[i]->io;
        delete ios[i];
    }
    return 0;
}