#include <emp-zk/emp-zk.h>
#include <emp-tool/emp-tool.h>
#include <cstdint>
#include <cstring>
#include <iostream>
#include "source/utils.cpp"

using namespace emp;
using namespace std;

// Base network port and this process's party id; both are filled in by
// parse_party_and_port() at the start of main().
int port, party;
// Number of BoolIO<NetIO> channels that main() opens (on port, port+1, ...)
// and that every test hands to setup_zk_bool.
const int threads = 12;


// Minimal end-to-end check of the boolean ZK setup: two 32-bit Integers with
// values 3 and 2 are created with party ALICE, their difference is revealed
// to PUBLIC as a uint32_t and printed.
//
// ios   - array of `threads` channels passed to setup_zk_bool.
// party - party id forwarded to setup_zk_bool.
//
// Calls error("cheat!\n") when finalize_zk_bool returns true.
void test_circuit_zk(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    Integer minuend(32, 3, ALICE);
    Integer subtrahend(32, 2, ALICE);
    Integer difference = minuend - subtrahend;
    std::cout << difference.reveal<uint32_t>(PUBLIC) << std::endl;

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}

// XORs one bit of a locally sampled random block (fed in with party ALICE)
// with one bit of a second random block that the non-ALICE party samples and
// sends over ios[0], then reveals and prints the result.
//
// ios   - array of `threads` channels; ios[0] carries the challenge block.
// party - party id; decides whether this side sends or receives the block.
//
// Calls error("cheat!\n") when finalize_zk_bool returns true.
void test_random_challenge(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);
    BoolIO<NetIO> *channel = ios[0];

    // Local randomness: only bit 0 of the block enters the circuit.
    PRG local_prg;
    block local_rand;
    local_prg.random_block(&local_rand, 1);
    bool local_bits[128];
    block_to_bool(local_bits, local_rand);
    Bit alice_bit(local_bits[0], ALICE);

    // Challenge block: sampled and sent by the non-ALICE side, received by ALICE.
    block challenge;
    if (party != ALICE) {
        PRG challenge_prg;
        challenge_prg.random_block(&challenge, 1);
        channel->send_block(&challenge, 1);
        channel->flush();
    } else {
        channel->recv_block(&challenge, 1);
    }
    bool challenge_bits[128];
    block_to_bool(challenge_bits, challenge);

    Bit public_bit(challenge_bits[0], PUBLIC);
    Bit coin = alice_bit ^ public_bit;
    std::cout << "random coin: " << coin.reveal() << std::endl;

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}

// Prints the revealed bits of `x` on one line, starting at index 0, followed
// by a newline.
//
// x      - Integer whose bits are revealed one at a time.
// int_sz - number of bit positions to print (defaults to 32).
void print_Int_bits(Integer x, size_t int_sz=32) {
    // The index is size_t so the comparison with int_sz is not a
    // signed/unsigned mix; it is cast back to int for the subscript.
    for (size_t i = 0; i < int_sz; ++i) {
        // Alternative debug format: "[" << i << ": " << bit << "] "
        std::cout << x[static_cast<int>(i)].reveal();
    }
    std::cout << std::endl;
}

// Prints the 32 bits of the object representation of `x` on one line, least
// significant bit first, followed by a newline.
//
// x - native float whose bit pattern is printed.
void print_c_float_bits(float x) {
    static_assert(sizeof(float) == sizeof(std::uint32_t),
                  "print_c_float_bits expects a 32-bit float");
    // Copy the bytes instead of dereferencing a casted pointer: reading a
    // float through a uint32_t* violates strict aliasing.
    std::uint32_t fw = 0;
    std::memcpy(&fw, &x, sizeof fw);
    for (int i = 0; i < 32; ++i) {
        std::cout << (fw & 1u);
        fw >>= 1;
    }
    std::cout << std::endl;
}

// Prints the bits of float_word() applied to a Float holding -2.5 (party
// ALICE), then the bits of the native float -2.5, so the two lines can be
// compared by eye.
//
// ios   - array of `threads` channels passed to setup_zk_bool.
// party - party id forwarded to setup_zk_bool.
//
// Calls error("cheat!\n") when finalize_zk_bool returns true.
void test_float_word(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    const float native_value = -2.5;

    Float circuit_value(-2.5, ALICE);
    Integer word = float_word(circuit_value);
    print_Int_bits(word);

    print_c_float_bits(native_value);

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}

// Applies int_relu() to the public 32-bit Integers -5 and 15 and prints each
// revealed result on its own line.
//
// ios   - array of `threads` channels passed to setup_zk_bool.
// party - party id forwarded to setup_zk_bool.
//
// Calls error("cheat!\n") when finalize_zk_bool returns true.
void test_int_relu(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    Integer negative_input(32, -5, PUBLIC);
    Integer positive_input(32, 15, PUBLIC);

    const int relu_of_negative = int_relu(negative_input).reveal<int>();
    std::cout << relu_of_negative << std::endl;

    const int relu_of_positive = int_relu(positive_input).reveal<int>();
    std::cout << relu_of_positive << std::endl;

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}

// Runs float_to_int() on six public Float inputs and prints each revealed int
// on its own line; the 1<<27 case also prints the expected value.
//
// ios   - array of `threads` channels passed to setup_zk_bool.
// party - party id forwarded to setup_zk_bool.
//
// Calls error("cheat!\n") when finalize_zk_bool returns true.
void test_float_to_int(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    // All inputs are constructed up front, before any conversion is revealed.
    Float fractional(15.5, PUBLIC);
    Float below_one(0.003, PUBLIC);
    Float whole(10002, PUBLIC);
    Float power_of_two(1<<27, PUBLIC);
    Float zero(0.0, PUBLIC);
    Float very_small(4.203895392974451212771189E-45, PUBLIC);

    // Converts one input and prints the revealed result.
    auto show = [](Float &input) {
        std::cout << float_to_int(input).reveal<int>() << std::endl;
    };

    show(fractional);
    show(below_one);
    show(whole);
    std::cout << float_to_int(power_of_two).reveal<int>()
              << " should be: " << (1<<27) << std::endl;
    show(zero);
    show(very_small);

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}

// Runs int_to_float() on seven public 32-bit Integer inputs and prints each
// revealed double on its own line; the (1<<28)+(1<<27) case also prints the
// expected value.
//
// ios   - array of `threads` channels passed to setup_zk_bool.
// party - party id forwarded to setup_zk_bool.
//
// Calls error("cheat!\n") when finalize_zk_bool returns true.
void test_int_to_float(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    // Kept as a double so the same value is used for the input and for the
    // "should be" output.
    const double expected_large = (1<<28)+(1<<27) * 1.0;

    // All inputs are constructed up front, before any conversion is revealed.
    Integer one(32, 1, PUBLIC);
    Integer minus_one(32, -1, PUBLIC);
    Integer zero(32, 0, PUBLIC);
    Integer two(32, 2, PUBLIC);
    Integer five(32, 5, PUBLIC);
    Integer large(32, expected_large, PUBLIC);
    Integer fifteen(32, 15, PUBLIC);

    // Converts one input and prints the revealed result.
    auto show = [](Integer &input) {
        std::cout << int_to_float(input).reveal<double>() << std::endl;
    };

    show(one);
    show(minus_one);
    show(zero);
    show(two);
    show(five);
    std::cout << int_to_float(large).reveal<double>()
              << " should be: " << expected_large << std::endl;
    show(fifteen);

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}


// Entry point: reads the party id and base port from the command line, opens
// `threads` network channels on consecutive ports, runs the selected test and
// releases the channels.
//
// Returns 0 on success, 1 when the required arguments are missing.
int main(int argc, char **argv) {
    // parse_party_and_port is given argv without a count; guard here so that
    // missing arguments produce a usage message instead of a null read.
    // NOTE(review): assumes it reads argv[1] and argv[2] — confirm against emp-tool.
    if (argc < 3) {
        std::cerr << "usage: " << (argc > 0 ? argv[0] : "test")
                  << " <party> <port>" << std::endl;
        return 1;
    }
    parse_party_and_port(argv, &party, &port);

    BoolIO<NetIO> *ios[threads];
    for (int i = 0; i < threads; ++i) {
        // ALICE passes a null address, the other party connects to localhost;
        // channel i uses port + i.
        ios[i] = new BoolIO<NetIO>(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + i), party == ALICE);
    }

    // Enable exactly one test at a time.
    //test_circuit_zk(ios, party);
    //test_random_challenge(ios, party);
    //test_float_word(ios, party);
    //test_int_relu(ios, party);
    //test_float_to_int(ios, party);
    test_int_to_float(ios, party);

    for (int i = 0; i < threads; ++i) {
        // The NetIO was allocated separately above, so it is deleted
        // explicitly before its BoolIO wrapper.
        delete ios[i]->io;
        delete ios[i];
    }
    return 0;
}