#include <emp-zk/emp-zk.h>
#include <emp-tool/emp-tool.h>
#include <iostream>
#include "source/phased_ERM.cpp"
#include "source/utils.cpp"

using namespace emp;
using namespace std;


// Connection settings for this process; filled in by parse_party_and_port() in main().
int port, party;
// Number of BoolIO<NetIO> channels: sizes every ios[] array in this file and is the
// count handed to setup_zk_bool. main() opens one connection per channel on port + i.
const int threads = 12;

// Feature count used by the MNIST benchmark (28 * 28 = 784).
const size_t MNIST_N_FEATURES = 28*28;
// Forwarded to phased_ERM_multinoise by bench_phased_erm_MNIST; presumably the number of
// data points the run iterates over -- TODO confirm against phased_ERM_multinoise.
const size_t MNIST_N_DATA = 60;

/**
 * Minimal smoke test of the ZK boolean session: subtracts two 32-bit ALICE-owned
 * integers (3 and 2), reveals the difference publicly and prints it.
 *
 * @param ios    array of `threads` channels used to set up the session
 * @param party  party identifier of this process
 *
 * Calls error() if finalize_zk_bool reports cheating.
 */
void test_circuit_zk(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    Integer minuend(32, 3, ALICE);
    Integer subtrahend(32, 2, ALICE);
    Integer difference = minuend - subtrahend;
    cout << difference.reveal<uint32_t>(PUBLIC) << endl;

    // A true result from finalize means the proof did not check out.
    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }
}


/**
 * Times one phased_ERM run on synthetic data inside a ZK boolean session.
 *
 * The data set is n_data x n_features (from init_dummy_D), the initial weights are
 * all ones, and labels come from bernoulli_labels(n_data, 0.9). The noise tree is the
 * built-in example (ex_node_names_01 / ex_leaf_vals_01).
 *
 * @param ios         array of `threads` channels; ios[0] is handed to phased_ERM
 * @param party       party identifier of this process
 * @param n_data      number of synthetic data points
 * @param n_features  number of features per point (also the parameter count)
 *
 * Prints the elapsed time in seconds; the clock starts after session setup and stops
 * after finalize. Calls error() if finalize_zk_bool reports cheating.
 */
void bench_phased_erm(BoolIO<NetIO> *ios[threads], int party, size_t n_data, size_t n_features) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    auto t_begin = clock_start();

    // Hyper-parameters.
    double L = 2.0;           // TODO: what is a reasonable value for L
    double step_size = 0.1;   // TODO what is a reasonable value for step_size
    size_t rand_bit_sz = 2;
    size_t num_params = n_features;

    BoolIO<NetIO> *channel = ios[0];

    // Synthetic inputs, created in the same order as before.
    vector<Float> weights = Float_ones(n_features);
    vector< vector<Float> > dataset = init_dummy_D(n_data, n_features);
    vector<Float> labels = bernoulli_labels(n_data, 0.9);

    // Example noise-sampling tree.
    vector<int> tree_nodes = ex_node_names_01();
    unordered_map<int, int> tree_leaves = ex_leaf_vals_01();

    phased_ERM(dataset, labels, L, weights, step_size, num_params, rand_bit_sz,
               tree_nodes, tree_leaves, channel, party);

    if (finalize_zk_bool<BoolIO<NetIO>>()) {
        error("cheat!\n");
    }

    // time_from is scaled by 1e6 to report seconds.
    cout << "n_data: " << n_data
         << "  n_features: " << n_features
         << "\t time (sec): " << time_from(t_begin) / (1000000.0)
         << " " << party << endl;
}

/**
 * Times phased_ERM_multinoise at MNIST dimensions inside a ZK boolean session.
 *
 * The three file lists are parallel: entry i of each describes the i-th noise tree
 * (node names, leaf names, leaf values). They must have the same length.
 *
 * @param ios              array of `threads` channels; ios[0] is handed to the ERM routine
 * @param party            party identifier of this process
 * @param node_files       paths parsed with parse_node_file (also yields the tree depth)
 * @param leaf_name_files  paths of leaf-name files, parallel to node_files
 * @param leaf_val_files   paths of leaf-value files, parallel to node_files
 *
 * Calls error() if the file lists differ in length or if finalize_zk_bool reports
 * cheating. The elapsed time is printed as the raw time_from() value (not scaled to
 * seconds, unlike bench_phased_erm); the clock starts after the trees are loaded.
 */
void bench_phased_erm_MNIST(BoolIO<NetIO> *ios[threads], int party, vector<string> node_files, vector<string> leaf_name_files, vector<string> leaf_val_files) {
    // The lists are indexed in lockstep below, so a length mismatch would read out of bounds.
    const size_t n_noise_files = node_files.size();
    if (leaf_name_files.size() != n_noise_files || leaf_val_files.size() != n_noise_files) {
        error("bench_phased_erm_MNIST: node/leaf file lists differ in length\n");
    }

    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    // load noise trees from disk
    vector< vector<int> > node_nameses;
    vector<size_t> tree_depths;
    vector< unordered_map<int, int> > leaf_valses;
    node_nameses.reserve(n_noise_files);
    tree_depths.reserve(n_noise_files);
    leaf_valses.reserve(n_noise_files);
    for (size_t i = 0; i < n_noise_files; ++i) {
        // Out-parameter of parse_node_file; zero-initialised so a parse that leaves it
        // untouched cannot push an indeterminate depth.
        size_t depth = 0;
        node_nameses.push_back(parse_node_file(node_files[i], depth));
        tree_depths.push_back(depth);
        leaf_valses.push_back(parse_leaf_files(leaf_name_files[i], leaf_val_files[i]));
    }
    print_vec(tree_depths);

    auto start = clock_start();
    BoolIO<NetIO> *io = ios[0];
    #ifdef DEBUG
    cout << "bench_phased_erm_MNIST -- initializing data" << endl;
    #endif

    vector<Float> w = Float_ones(MNIST_N_FEATURES);
    vector< vector<Float> > D = init_dummy_D(1, MNIST_N_FEATURES); // D small for memory, but we will still iterate the proper number of times
    vector<Float> Y_mostly_correct = bernoulli_labels(1, 0.9);

    double L = static_cast<double>(MNIST_N_FEATURES);
    double step_size = 0.1; // TODO what is a reasonable value for step_size

    phased_ERM_multinoise(D, Y_mostly_correct, L, w, step_size, MNIST_N_FEATURES, tree_depths, node_nameses, leaf_valses, MNIST_N_DATA, io, party);

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat) error("cheat!\n");
    cout << "n_data: " << MNIST_N_DATA << "  n_features: " << MNIST_N_FEATURES << "\t time: " << time_from(start) << " " << party << endl;
}

int main(int argc, char **argv) {
    parse_party_and_port(argv, &party, &port);
    BoolIO<NetIO> *ios[threads];
    for (int i = 0; i < threads; ++i)
        ios[i] = new BoolIO<NetIO>(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + i), party == ALICE);
    //NetIO *randomness_io = new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + 1000);
    //test_circuit_zk(ios, party);
    //test_lr(ios, party);
    //test_single_data_grad(ios, party);
    //test_regularization(ios, party);
    //test_phased_ERM_grad_check(ios, party);
    vector<string> node_files;
    vector<string> leaf_name_files;
    vector<string> leaf_val_files;
    string sigma = "1";
    for (int i=1; i<=16; ++i) {
        string s = "noise_presets/MNIST_";
        s = s + sigma + "/";
        string nf = s + "tree_names_sigma_" + to_string(i) + ".txt";
        node_files.push_back(nf);
        string ln = s + "leaf_names_sigma_" + to_string(i) + ".txt";
        leaf_name_files.push_back(ln);
        string lv = s +  "leaf_elements_sigma_" + to_string(i) + ".txt";
        leaf_val_files.push_back(lv);
    }

    bench_phased_erm_MNIST(ios, party, node_files, leaf_name_files, leaf_val_files);
    cout << "bench_phased_erm_MNIST finish -- sigma: " << sigma << endl;

    for (int i = 0; i < threads; ++i) {
        delete ios[i]->io;
        delete ios[i];
    }
    return 0;
}