#include <iostream>
#include <memory>
#include <emp-zk/emp-zk.h>
#include <emp-tool/emp-tool.h>
#include "source/sample.cpp"
#include "source/utils.cpp"

using namespace emp;
using namespace std;

// Networking configuration, filled in from the command line by main().
int port;
int party;
// Number of parallel BoolIO<NetIO> channels main() opens between the two
// parties; also the array bound of every `ios` parameter in this file.
constexpr int threads = 12;


// Smoke test for the boolean ZK backend.
// ALICE provides two 32-bit integers (3 and 2) as private inputs. Their
// difference is revealed publicly and printed. The run aborts via error() if
// finalize_zk_bool() reports cheating.
void test_circuit_zk(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    Integer lhs(32, 3, ALICE);
    Integer rhs(32, 2, ALICE);
    Integer diff = lhs - rhs;
    cout << diff.reveal<uint32_t>(PUBLIC) << endl;

    if (finalize_zk_bool<BoolIO<NetIO>>())
        error("cheat!\n");
}

// Interactive coin flip inside the boolean ZK protocol:
//   1. ALICE feeds the first bit of a locally sampled block as a private
//      input `a`. (The other party also samples one; presumably only
//      ALICE's value is used for an ALICE input.)
//   2. AFTER `a` is input, the non-ALICE party (BOB) samples a random block
//      and sends it to ALICE in the clear over ios[0].
//   3. Both parties take the first bit of BOB's block as a PUBLIC bit `b`.
//   4. The coin c = a XOR b is revealed and printed.
// Aborts via error() if finalize_zk_bool() reports cheating.
void test_random_challenge(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    BoolIO<NetIO> *io = ios[0];

    // A single PRG serves both ALICE's private bit and BOB's challenge
    // (no shadowing second PRG inside BOB's branch).
    PRG prg;
    block r;
    prg.random_block(&r, 1);
    bool r_in_bits[128];
    block_to_bool(r_in_bits, r);
    Bit a(r_in_bits[0], ALICE);

    // BOB's challenge block, exchanged in the clear. It must be exchanged
    // only after `a` has been input above, so keep this ordering.
    block s;
    if (party == ALICE) {
        io->recv_block(&s, 1);
    } else {
        prg.random_block(&s, 1);
        io->send_block(&s, 1);
        io->flush();
    }
    bool s_in_bits[128];
    block_to_bool(s_in_bits, s);

    Bit b(s_in_bits[0], PUBLIC);
    Bit c = a ^ b;
    cout << "random coin: " << c.reveal() << endl;

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat) error("cheat!\n");
}

void bench_noise_preset(BoolIO<NetIO> *ios[threads], int party, string node_file, string leaf_name_file, string leaf_vals_file, size_t num_noises) {
    //cout << "a\n";
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    BoolIO<NetIO> *rand_io = ios[0];
    // initialize tree
    size_t tree_depth;
    //cout << "b\n";
    vector<int> node_names = parse_node_file(node_file, tree_depth);
    //print_vec(node_names);
    unordered_map<int, int> leaf_vals = parse_leaf_files(leaf_name_file, leaf_vals_file);
    vector<Integer> val;
    vector<int> select_sequence;
    unordered_map<int, int> m;
    //cout << "c\n";
    ts_init(node_names, leaf_vals, val, select_sequence, m);
    //cout << "d 1\n";
    // initialize randomness
    vector<Bit> r;
    for (int i=0; i<tree_depth; ++i) {
        r.push_back(Bit(0, PUBLIC));
    }
    vector<Integer> fp_noise_vec;
    for (int i=0; i<num_noises; ++i) {
        fp_noise_vec.push_back(Integer(32, 0, PUBLIC));
    }

    bool alice_local_rand[tree_depth];
    bool bob_local_rand[tree_depth];
    vector< vector<Bit> > rs;
    auto start = clock_start();

    for (int i=0; i<num_noises; ++i) {
        random_bits(alice_local_rand, tree_depth);
        random_bits(bob_local_rand, tree_depth);
        interactive_seed_generation(rand_io, party, alice_local_rand, bob_local_rand, tree_depth, r);
        rs.push_back(r);
    }
    int t_seedgen = time_from(start);
    //cout << "d\n";
    // sampling
    auto s2 = clock_start();
    for (int i=0; i<num_noises; ++i) {
        //cout << i << "/" << num_noises << endl;
        fp_noise_vec[i] = tree_sample(val, rs[i], select_sequence, select_sequence.size(), m);
    }
    int t_tree_sample = time_from(s2);

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat)error("cheat!\n");

    int t = time_from(start);
    cout << "bench_marking: " << node_file << " -- num_noises: " << num_noises << "\t total time (sec): " << t/(1.0*1000000) << "\t time per noise (ms): " << t/(num_noises*1000.0) << endl;
    cout << "tree depth: " << tree_depth << "  num_noises: " << num_noises << "\t seed generation (sec): " << t_seedgen/(1000000.0) << "   per noise (ms): " << t_seedgen/(1000.0*num_noises) << "\t sample time (sec): " << t_tree_sample/(1000000.0) << "  per noise (ms): " << t_tree_sample/(num_noises*1000.0) << endl;
}

void bench_sigma_1(BoolIO<NetIO> *ios[threads], int party, size_t num_noises) {
    //cout << "a\n";
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    BoolIO<NetIO> *rand_io = ios[0];
    // initialize tree
    size_t tree_depth;
    //cout << "b\n";
    vector<int> node_names = parse_node_file("noise_presets/tree_names_sigma_1.txt", tree_depth);
    //print_vec(node_names);
    unordered_map<int, int> leaf_vals = parse_leaf_files("noise_presets/leaf_names_sigma_1.txt", "noise_presets/leaf_elements_sigma_1.txt");
    vector<Integer> val;
    vector<int> select_sequence;
    unordered_map<int, int> m;
    //cout << "c\n";
    ts_init(node_names, leaf_vals, val, select_sequence, m);
    //cout << "d 1\n";
    // initialize randomness
    vector<Bit> r;
    for (int i=0; i<tree_depth; ++i) {
        r.push_back(Bit(0, PUBLIC));
    }
    vector<Integer> fp_noise_vec;
    for (int i=0; i<num_noises; ++i) {
        fp_noise_vec.push_back(Integer(32, 0, PUBLIC));
    }

    bool alice_local_rand[tree_depth];
    bool bob_local_rand[tree_depth];
    vector< vector<Bit> > rs;
    auto start = clock_start();

    for (int i=0; i<num_noises; ++i) {
        random_bits(alice_local_rand, tree_depth);
        random_bits(bob_local_rand, tree_depth);
        interactive_seed_generation(rand_io, party, alice_local_rand, bob_local_rand, tree_depth, r);
        rs.push_back(r);
    }
    int t_seedgen = time_from(start);
    //cout << "d\n";
    // sampling
    auto s2 = clock_start();
    for (int i=0; i<num_noises; ++i) {
        //cout << i << "/" << num_noises << endl;
        fp_noise_vec[i] = tree_sample(val, rs[i], select_sequence, select_sequence.size(), m);
    }
    int t_tree_sample = time_from(s2);

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat)error("cheat!\n");

    int t = time_from(start);
    cout << "bench_sigma_1 -- num_noises: " << num_noises << "\t total time (sec): " << t/(1.0*1000000) << "\t time per noise (ms): " << t/(num_noises*1000.0) << endl;
    cout << "bench_sigma_1 -- num_noises: " << num_noises << "\t seed generation (sec): " << t_seedgen/(1000000.0) << "   per noise (ms): " << t_seedgen/(1000.0*num_noises) << "\t sample time (sec): " << t_tree_sample/(1000000.0) << "  per noise (ms): " << t_tree_sample/(num_noises*1000.0) << endl;
}



int main(int argc, char **argv) {
    parse_party_and_port(argv, &party, &port);
    BoolIO<NetIO> *ios[threads];
    for (int i = 0; i < threads; ++i)
        ios[i] = new BoolIO<NetIO>(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + i), party == ALICE);
    //test_circuit_zk(ios, party);
    //test_random_challenge(ios, party);
    //bench_sigma_1(ios, party, 28*28);
    string s = "noise_presets/MNIST_sigma_100/";
    bench_noise_preset(ios, party, s + "tree_names.txt", s + "leaf_names.txt", s+"leaf_elements.txt", 10000);

    //string t = "noise_presets/";
    //bench_noise_preset(ios, party, t + "tree_names_sigma_1.txt", t + "leaf_names_sigma_1.txt", t + "leaf_elements_sigma_1.txt", 1000);

    for (int i = 0; i < threads; ++i) {
        delete ios[i]->io;
        delete ios[i];
    }
    return 0;
}