#include <emp-zk/emp-zk.h>
#include <emp-tool/emp-tool.h>
#include <iostream>
#include "source/phased_ERM.cpp"
#include "source/utils.cpp"

using namespace emp;
using namespace std;

// Party id and base port; filled in by main() from the command line.
int port;
int party;
// Number of BoolIO<NetIO> channels (one per consecutive port) used by every test.
constexpr int threads = 12;


// Smoke test for the boolean ZK backend: two ALICE-owned 32-bit integers
// (3 and 2) are subtracted and the difference is revealed publicly and printed.
// Calls error() if finalize_zk_bool reports cheating.
void test_circuit_zk(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    Integer minuend(32, 3, ALICE);
    Integer subtrahend(32, 2, ALICE);
    Integer difference = minuend - subtrahend;
    cout << difference.reveal<uint32_t>(PUBLIC) << endl;

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}


// Runs binary_LR_inference on two hard-coded 4-feature points against one
// fixed weight vector and prints both revealed results.
// Calls error() if finalize_zk_bool reports cheating.
void test_lr(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    // Constructed in the same order as before: point a, point b, then weights.
    vector<Float> point_a = init_Float_vec({10.5, 0.0, 0.1, 0.23});
    vector<Float> point_b = init_Float_vec({-0.7, 2.1, 57.0, 1.2});
    vector<Float> weights = init_Float_vec({3.4, 0.3, -2.0, 0.5});

    // Both inferences are computed before either result is revealed.
    Float score_a = binary_LR_inference(point_a, weights);
    Float score_b = binary_LR_inference(point_b, weights);

    cout << "a res: " << score_a.reveal<double>() << endl;
    cout << "b res: " << score_b.reveal<double>() << endl;

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}

// Computes and prints single_data_grad for two points, both labelled 1.0
// (ALICE inputs), under a fixed 4-dimensional weight vector.
// Calls error() if finalize_zk_bool reports cheating.
void test_single_data_grad(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    const int dim = 4;

    vector<Float> x_a = init_Float_vec({10.5, 0.0, 0.1, 0.23});
    Float y_a(1.0, ALICE);
    vector<Float> x_b = init_Float_vec({-0.7, 2.1, 57.0, 1.2});
    Float y_b(1.0, ALICE);
    vector<Float> weights = init_Float_vec({3.4, 0.3, -2.0, 0.5});

    // One output buffer is shared by both calls below.
    vector<Float> grad = Float_zeros(dim);

    single_data_grad(x_a, weights, y_a, grad, dim);
    cout << "LR gradient for a (a is predicted correctly): \n";
    print_Float_vec(grad);

    single_data_grad(x_b, weights, y_b, grad, dim);
    cout << "LR gradient for b (b is predicted incorrectly): \n";
    print_Float_vec(grad);

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}

// Prints regularization_term for two weight pairs that share the public
// multiplier 1/15: w1 -> w2 (small change) and w2 -> w3 (big change).
// Calls error() if finalize_zk_bool reports cheating.
void test_regularization(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    vector<Float> w1 = init_Float_vec({3.4, 0.3, -2.0, 0.5});
    vector<Float> w2 = init_Float_vec({3.7, 2.3, -1.5, 1.2});
    vector<Float> w3 = init_Float_vec({3.7 * 5, 2.3 * 2.1, -1.5 * 4, 1.2 * 1.1});
    Float reg_mult(1.0/15.0, PUBLIC);
    vector<Float> reg_out = Float_zeros(4);

    // Writes the term for (from, to) into reg_out, then prints header and vector.
    auto show = [&](vector<Float> &from, vector<Float> &to, const char *header) {
        regularization_term(from, to, reg_mult, reg_out, 4);
        cout << header;
        print_Float_vec(reg_out);
    };

    show(w1, w2, "regularization term w1, w2 (small change): \n");
    show(w2, w3, "regularization term w2, w3 (big change): \n");

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}

// Runs phased_ERM_grad_check on a dummy 50 x 4 dataset twice, over the first
// half of the rows: once with Bernoulli(0.9) labels ("mostly correct") and once
// with Bernoulli(0.1) labels ("mostly incorrect"). After each run it prints the
// gradient, its revealed squared sum, and the revealed CONST_GRADBOUND_SQUARED
// value that was passed to the check.
// Calls error() if finalize_zk_bool reports cheating.
void test_phased_ERM_grad_check(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    // Single source of truth for the dataset shape; the batch is half the rows.
    const int num_pts = 50;
    const int num_features = 4;
    const int half = num_pts / 2;

    vector<Float> w = init_Float_vec({3.4, 0.3, -0.5, 0.5});
    vector<Float> w_prev = init_Float_vec({3.1, 0.25, -0.7, 0.7});
    vector< vector<Float> > D = init_dummy_D(num_pts, num_features); // 50 pts, 4 features
    vector<Float> Y_mostly_correct = bernoulli_labels(num_pts, 0.9);
    vector<Float> Y_mostly_incorrect = bernoulli_labels(num_pts, 0.1);

    // First half of the dataset (presumably the half-open range [0, half) -- confirm against indices_range).
    vector<int> Di_indices = indices_range(0, half);
    size_t batch_sz = half;

    double nu_i = 2.0; // TODO: what is a reasonable value for this to take?
    double n_i = static_cast<double>(batch_sz); // batch size as a double (25.0)
    double L = 2; // TODO: what is a reasonable value for lipschitz constant?

    Float CONST_REG_MULT(1.0/(nu_i * n_i), PUBLIC);
    Float CONST_GRADSUM_MULT(1.0/n_i, PUBLIC);
    const double grad_bound = 2.0/n_i * L;
    Float CONST_GRADBOUND_SQUARED(grad_bound * grad_bound, PUBLIC);

    vector<Float> out_grad;
    size_t sz_outgrad = num_features;

    // Runs the check for one label vector and prints the gradient plus diagnostics.
    auto run_check = [&](vector<Float> &Y, const char *title) {
        // Start each run from zeros. The second result is then independent of the
        // first, whether the callee overwrites out_grad or accumulates into it.
        out_grad = Float_zeros(num_features);
        phased_ERM_grad_check(D, Y, Di_indices, batch_sz, w, w_prev, CONST_REG_MULT, CONST_GRADSUM_MULT, CONST_GRADBOUND_SQUARED, out_grad, sz_outgrad);
        cout << title;
        print_Float_vec(out_grad);
        cout << "squared sum: " << Float_vec_squared_sum(out_grad).reveal<double>();
        cout << "       const: " << CONST_GRADBOUND_SQUARED.reveal<double>() << endl;
    };

    run_check(Y_mostly_correct, "grad F for mostly correct labels: \n");
    run_check(Y_mostly_incorrect, "grad F for mostly incorrect labels: \n");

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat)error("cheat!\n");
}

// End-to-end run of phased_ERM on a dummy 50 x 4 dataset with Bernoulli(0.9)
// ("mostly correct") labels. It uses the example node names / leaf values from
// ex_node_names_01() / ex_leaf_vals_01() and the first BoolIO channel.
// Calls error() if finalize_zk_bool reports cheating.
void test_phased_erm(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    // Dataset shape; num_params must match the weight vector length below.
    const int num_pts = 50;
    const int num_features = 4;

    vector<Float> w = init_Float_vec({3.4, 0.3, -0.5, 0.5});
    vector< vector<Float> > D = init_dummy_D(num_pts, num_features); // 50 pts, 4 features
    vector<Float> Y_mostly_correct = bernoulli_labels(num_pts, 0.9);

    double L = 2.0; // TODO: what is a reasonable value for L
    double step_size = 0.1; // TODO what is a reasonable value for step_size
    size_t num_params = num_features;
    size_t rand_bit_sz = 2;

    BoolIO<NetIO> *io = ios[0];

    vector<int> node_names = ex_node_names_01();
    unordered_map<int, int> leaf_vals = ex_leaf_vals_01();

    phased_ERM(D, Y_mostly_correct, L, w, step_size, num_params, rand_bit_sz, node_names, leaf_vals, io, party);

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat)error("cheat!\n");
}

int main(int argc, char **argv) {
    parse_party_and_port(argv, &party, &port);
    BoolIO<NetIO> *ios[threads];
    for (int i = 0; i < threads; ++i)
        ios[i] = new BoolIO<NetIO>(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + i), party == ALICE);
    //NetIO *randomness_io = new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + 1000);
    //test_circuit_zk(ios, party);
    //test_lr(ios, party);
    //test_single_data_grad(ios, party);
    //test_regularization(ios, party);
    //test_phased_ERM_grad_check(ios, party);
    test_phased_erm(ios, party);

    for (int i = 0; i < threads; ++i) {
        delete ios[i]->io;
        delete ios[i];
    }
    return 0;
}