#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "src/utils/poly.h"
#include "src/utils/fields/fp59.h"

// Creates an initial batch of packed secret sharings.
//
// Builds k sharings. For each one, dplusone 64-bit words are drawn from `prg`,
// converted to T, and passed to poly.nttEvalK2N to obtain n share values.
//
//   secrets  resized to k rows of k entries; secrets[s][j] receives the j-th
//            defining point of sharing s for j < min(k, dplusone)
//   shares   resized to n rows of k entries; shares[p][s] receives the p-th
//            value produced by nttEvalK2N for sharing s
template<typename T>
void create_init_shares(PRG &prg, Poly<T> &poly, std::size_t n, std::size_t dplusone, std::size_t k,
        std::vector<std::vector<T>> &secrets, std::vector<std::vector<T>> &shares) {
    secrets.resize(k);
    shares.resize(n);
    for (auto &party_row : shares) {
        party_row.resize(k);
    }

    // Only defining points that exist can be recorded as secrets.
    const std::size_t recorded = (k < dplusone) ? k : dplusone;

    for (std::size_t s = 0; s < k; ++s) {
        secrets[s].resize(k);

        // Fresh randomness for every defining point of this sharing.
        std::vector<uint64_t> raw_words(dplusone);
        prg.random_data(raw_words.data(), dplusone * sizeof(uint64_t));

        std::vector<T> points(dplusone);
        for (std::size_t j = 0; j < dplusone; ++j) {
            points[j] = T(raw_words[j]);
        }
        for (std::size_t j = 0; j < recorded; ++j) {
            secrets[s][j] = points[j];
        }

        // Scatter the n evaluations so that each party's row holds one share per sharing.
        std::vector<T> evaluations(n);
        poly.nttEvalK2N(evaluations, points);
        for (std::size_t party = 0; party < n; ++party) {
            shares[party][s] = evaluations[party];
        }
    }
}

// Implements our packed resharing protocol from Section 3 for a given party.
//
// The party's first k input shares become the first k defining points; the
// remaining dplusone - k defining points are drawn from `prg`. The points are
// passed to poly.nttEvalK2N to obtain n values, and value i is stored in
// out_shares[i][id].
//
//   in_shares   must hold at least k elements
//   out_shares  must hold at least n rows; it is not resized here, so every
//               row must already contain index `id`
//
// Throws std::invalid_argument when k > dplusone (the count of random points
// would wrap around as an unsigned value), when in_shares has fewer than k
// elements, or when out_shares has fewer than n rows.
template<typename T>
void reshare(PRG &prg, Poly<T> &poly, std::size_t n, std::size_t dplusone, std::size_t k, std::size_t id,
        std::vector<T> &in_shares, std::vector<std::vector<T>> &out_shares) {
    if (k > dplusone) {
        throw std::invalid_argument("reshare: k must not exceed dplusone");
    }
    if (in_shares.size() < k) {
        throw std::invalid_argument("reshare: in_shares holds fewer than k elements");
    }
    if (out_shares.size() < n) {
        throw std::invalid_argument("reshare: out_shares holds fewer than n rows");
    }

    std::vector<T> defining_pts(in_shares.begin(), in_shares.begin() + k);
    defining_pts.resize(dplusone);

    // Fill the points beyond the first k with fresh randomness.
    const std::size_t num_random = dplusone - k;
    std::vector<uint64_t> buf(num_random);
    prg.random_data(buf.data(), num_random * sizeof(uint64_t));
    for (std::size_t i = 0; i < num_random; ++i) {
        defining_pts[k + i] = T(buf[i]);
    }

    std::vector<T> new_shares(n);
    poly.nttEvalK2N(new_shares, defining_pts);
    for (std::size_t i = 0; i < n; ++i) {
        out_shares[i][id] = new_shares[i];
    }
}

// Implements our Recover protocol from Section 3 for a given party.
//
// Runs poly.nttEvalN2K on `in_shares` into a buffer of dplusone values and
// stores the first k of those values in `out_shares`. The parameter `n` is
// accepted but not used by this function.
template<typename T>
void get_secret_shares(Poly<T> &poly, std::size_t n, std::size_t dplusone, std::size_t k,
        std::vector<T> &in_shares, std::vector<T> &out_shares) {
    out_shares.resize(k);

    std::vector<T> recovered(dplusone);
    poly.nttEvalN2K(recovered, in_shares);

    // Keep only the leading k recovered values.
    const auto first = recovered.begin();
    out_shares.assign(first, first + k);
}

// Implements our Parity Check protocol from Appendix B for a given party.
//
// Runs `in_shares` through poly.nttEvalN2K (dplusone values) and back through
// poly.nttEvalK2N (n values), then stores in out_shares[j] the difference
// between in_shares[j] and the round-tripped value. `out_shares` is resized to n.
template<typename T>
void get_parity_check_shares(Poly<T> &poly, std::size_t n, std::size_t dplusone, std::vector<T> &in_shares, std::vector<T> &out_shares) {
    out_shares.resize(n);

    std::vector<T> interpolated(dplusone);
    poly.nttEvalN2K(interpolated, in_shares);

    std::vector<T> reevaluated(n);
    poly.nttEvalK2N(reevaluated, interpolated);

    std::size_t idx = 0;
    for (T &difference : out_shares) {
        difference = in_shares[idx] - reevaluated[idx];
        ++idx;
    }
}

// Takes a random linear combination over secret shares, specified by beta.
//
// Visits in_shares[batch][i] for batch in [0, num_batches) and i in [0, n), in
// that order, and weights the j-th visited share (j counted from 1) by beta^j:
//
//   comb_share = sum_j beta^j * share_j
//
// The running coefficient is advanced by multiplying by beta. Squaring it
// instead would produce the exponents 1, 2, 4, 8, ... rather than successive
// powers.
//
//   in_shares   at least num_batches rows, each with at least n entries
//   comb_share  overwritten with the result; T(0) when there is nothing to sum
template<typename T>
void get_random_comb(std::size_t num_batches, std::size_t n, T beta, std::vector<std::vector<T>> &in_shares, T &comb_share) {
    comb_share = T(0);
    T beta_power = beta;
    for (std::size_t batch = 0; batch < num_batches; ++batch) {
        const std::vector<T> &batch_shares = in_shares[batch];
        for (std::size_t i = 0; i < n; ++i) {
            comb_share = comb_share + beta_power * batch_shares[i];
            beta_power = beta_power * beta;
        }
    }
}

int main(int argc, char** argv) {
    if (argc != 5) {
        std:: cout << "usage: [committee size n] [sharings degree d] "
        "[\\eps for corruption threshold t < (1/2-\\eps)n] [num_batches b]" << std::endl;
        return 0;
    }
    std::size_t n = std::stoul(argv[1]);
    std::size_t d = std::stoul(argv[2]);
    float eps = std::stof(argv[3]);
    std::size_t num_batches = std::stoul(argv[4]);
    
    std::size_t dplusone = d+1;
    float floatd = static_cast<float>(d);
    float floatn  = static_cast<float>(n);
    float floatk = floatd + 1 - (0.5-eps)*floatn;
    std::size_t k = static_cast<std::size_t>(floatk);
    //std::cout << log2(k) << std::endl;
    // std::cout << (log2(k) == ceil(log2(k))) << std::endl;

    PRG prg;

    Poly<FP59> poly(log2(dplusone), log2(n));

    std::vector<std::vector<std::vector<FP59>>> secrets(num_batches), shares(num_batches),
        new_shares(num_batches), shares_of_secrets(num_batches), parity_check_shares(n);

    for (std::size_t i = 0; i < n; ++i) {
        parity_check_shares[i].resize(num_batches);
    }

    std::vector<double> reshare_times, get_secret_shares_times, get_parity_check_shares_times;
    reshare_times.reserve(num_batches * n);
    get_secret_shares_times.reserve(num_batches * n);
    get_parity_check_shares_times.reserve(num_batches * n);

    for (std::size_t batch = 0; batch < num_batches; ++batch) {
        secrets[batch].resize(k);
        shares[batch].resize(n);
        // generate an initial batch of packed secret sharings
        create_init_shares<FP59>(prg, poly, n, dplusone, k, secrets[batch], shares[batch]);

        new_shares[batch].resize(n);
        for (std::size_t i = 0; i < n; ++i) {
            new_shares[batch][i].resize(n);
        }

        // Use our Resharing protocol from Section 3 for each party
        for (std::size_t i = 0; i < n; ++i) {
            auto start = emp::clock_start();
            reshare<FP59>(prg, poly, n, dplusone, k, i, shares[batch][i], new_shares[batch]);
            // std::cout << emp::time_from(start) << std::endl;
            reshare_times.push_back(emp::time_from(start));
        }

        // Use our Recover protocol from Section 3 for each party
        shares_of_secrets[batch].resize(n);
        for (std::size_t i = 0; i < n; ++i) {
            auto start = emp::clock_start();
            get_secret_shares<FP59>(poly, n, dplusone, k, new_shares[batch][i], shares_of_secrets[batch][i]);
            get_secret_shares_times.push_back(emp::time_from(start));
        }

        // Use our Parity Check protocol from Appendix B for each party
        for (std::size_t i = 0; i < n; ++ i) {
            auto start = emp::clock_start();
            get_parity_check_shares<FP59>(poly, n, dplusone, new_shares[batch][i], parity_check_shares[i][batch]);
            get_parity_check_shares_times.push_back(emp::time_from(start));
        }
    }

    std::vector<double> rand_comb_times;
    rand_comb_times.reserve(n);
    std::vector<uint64_t> buf(1);
    prg.random_data(buf.data(), sizeof(uint64_t));
    FP59 beta = FP59(buf[0]);

    // take a random linear combination of the parity check shares for each party, specified by beta
    std::vector<FP59> random_comb(n);
    for (std::size_t i = 0; i < n; ++i) {
        auto start = emp::clock_start();
        get_random_comb<FP59>(num_batches, n, beta, parity_check_shares[i], random_comb[i]);
        rand_comb_times.push_back(emp::time_from(start));
    }

    // open the combined parity check sharing
    std::vector<FP59> opened_rand_comb(k);
    poly.nttEvalN2K(opened_rand_comb, random_comb);
    for (std::size_t i = 0; i < k; ++i) {
        // std::cout << opened_rand_comb[i].val << std::endl;
        if (opened_rand_comb[i] != FP59(0)) {
            std::cout << "oops!" << std::endl;
        }
    }

    double avg_reshare_time = accumulate(reshare_times.begin(), reshare_times.end(), 0.0) / reshare_times.size();
    double avg_get_secret_share_time = accumulate(get_secret_shares_times.begin(), get_secret_shares_times.end(), 0.0) /
        get_secret_shares_times.size();
    double avg_get_parity_check_shares_time = accumulate(get_parity_check_shares_times.begin(), get_parity_check_shares_times.end(), 0.0) /
        get_parity_check_shares_times.size();
    double avg_rand_comb_time = accumulate(rand_comb_times.begin(), rand_comb_times.end(), 0.0) / rand_comb_times.size();
    std::cout << "time per reshare " << avg_reshare_time << " time per getting new secret share: " 
        << avg_get_secret_share_time << " time per getting parity check shares: "
        << avg_get_parity_check_shares_time << " time per rand comb: "
        << avg_rand_comb_time << std::endl;

   return 0;
}