#include <emp-zk/emp-zk.h>
#include <emp-tool/emp-tool.h>
#include <iostream>
#include "source/phased_ERM.cpp"
#include "source/utils.cpp"

using namespace emp;
using namespace std;


// Network endpoint configuration; populated from argv by
// parse_party_and_port() in main().
int port;
int party;

// Number of BoolIO<NetIO> channels opened on consecutive ports and handed
// to setup_zk_bool().
constexpr int threads = 12;

// MNIST images are 28x28 pixels, flattened to one feature per pixel.
constexpr size_t MNIST_N_FEATURES = 28 * 28;
//const size_t MNIST_N_DATA = 60000;

// Smoke test for the ZK boolean backend.
//
// ALICE supplies two 32-bit integers (3 and 2). Their difference is computed
// inside the proof and revealed to both parties, so the printed value should
// be 1. If finalize_zk_bool() reports cheating, error() is invoked.
void test_circuit_zk(BoolIO<NetIO> *ios[threads], int party) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    Integer minuend(32, 3, ALICE);
    Integer subtrahend(32, 2, ALICE);
    Integer difference = minuend - subtrahend;
    cout << difference.reveal<uint32_t>(PUBLIC) << endl;

    if (finalize_zk_bool<BoolIO<NetIO>>())
        error("cheat!\n");
}


// Benchmarks single_data_grad() inside the ZK boolean backend.
//
// Performs n_data per-example gradient evaluations against an
// n_features-wide weight vector, then prints the total elapsed time and the
// average time per gradient (as reported by emp's time_from()).
//
// To keep memory small, only one dummy data row and one label are
// materialized. The same row is reused on every iteration, so the amount of
// work still scales with n_data.
//
// Timing starts after setup_zk_bool() and includes finalize_zk_bool(), so
// the final cheat check is part of the measured cost.
void bench_gradient(BoolIO<NetIO> *ios[threads], int party, size_t n_data, size_t n_features) {
    setup_zk_bool<BoolIO<NetIO>>(ios, threads, party);

    auto start = clock_start();

    vector<Float> w = Float_ones(n_features);
    vector<Float> g_out = Float_zeros(n_features);
    // D small for memory, but we will still iterate the proper number of times
    vector< vector<Float> > D = init_dummy_D(1, n_features);
    vector<Float> Y_mostly_correct = bernoulli_labels(1, 0.9);

    // size_t index matches n_data's type: no signed/unsigned comparison and
    // no int overflow when n_data exceeds INT_MAX.
    for (size_t i = 0; i < n_data; ++i) {
        single_data_grad(D[0], w, Y_mostly_correct[0], g_out, n_features);
    }

    bool cheat = finalize_zk_bool<BoolIO<NetIO>>();
    if (cheat) error("cheat!\n");

    double t = time_from(start);
    // Guard the average so n_data == 0 reports 0 rather than inf/nan.
    double t_per_gradient = n_data > 0 ? t / static_cast<double>(n_data) : 0.0;
    cout << "n_data: " << n_data << "  n_features: " << n_features << "\t time: " << t << "\t time/gradient: " << t_per_gradient << " " << party << endl;
}

int main(int argc, char **argv) {
    parse_party_and_port(argv, &party, &port);
    BoolIO<NetIO> *ios[threads];
    for (int i = 0; i < threads; ++i)
        ios[i] = new BoolIO<NetIO>(new NetIO(party == ALICE ? nullptr : "127.0.0.1", port + i), party == ALICE);


    bench_gradient(ios, party, 500, MNIST_N_FEATURES);

    for (int i = 0; i < threads; ++i) {
        delete ios[i]->io;
        delete ios[i];
    }
    return 0;
}