/*
 * LowGearKeyGen.cpp
 *
 */

#include "LowGearKeyGen.h"
#include "FHE/Rq_Element.h"

#include "Tools/benchmarking.h"

#include "Machines/SPDZ.hpp"
#include "ShareVector.hpp"

template<int L>
LowGearKeyGen<L>::LowGearKeyGen(Player& P, PairwiseMachine& machine,
        FHE_Params& params) :
        KeyGenProtocol<5, L>(P, params), P(P), machine(machine)
{
}

template<int X, int L>
KeyGenProtocol<X, L>::KeyGenProtocol(Player& P, const FHE_Params& params,
        int level) :
        P(P), params(params), fftd(params.FFTD().at(level)), usage(P)
{
    open_type::init_field(params.FFTD().at(level).get_prD().pr);
    typename share_type::mac_key_type alphai;

    if (OnlineOptions::singleton.live_prep)
    {
        prep = new MascotDabitOnlyPrep<share_type>(0, usage);
        alphai.randomize(G);
    }
    else
    {
        prep = new Sub_Data_Files<share_type>(P.N,
                get_prep_sub_dir<share_type>(P.num_players()), usage);
        read_mac_key(get_prep_sub_dir<share_type>(P.num_players()), P.N, alphai);
    }

    MC = new MAC_Check_<share_type>(alphai);
    proc = new SubProcessor<share_type>(*MC, *prep, P);
}

template<int X, int L>
KeyGenProtocol<X, L>::~KeyGenProtocol()
{
    MC->Check(P);

    usage.print_cost();

    delete proc;
    delete prep;
    delete MC;
}

template<int X, int L>
void KeyGenProtocol<X, L>::input(vector<vector_type>& shares, const Rq_Element& secret)
{
    assert(secret.level() == 0);
    auto s = secret.get(0);
    s.change_rep(evaluation);
    auto& FFTD = s.get_FFTD();
    auto& inputter = this->proc->input;
    inputter.reset_all(P);
    for (int i = 0; i < FFTD.num_slots(); i++)
        inputter.add_from_all(s.get_element(i));
    inputter.exchange();
    shares.clear();
    shares.resize(P.num_players());
    for (int i = 0; i < FFTD.num_slots(); i++)
        for (int j = 0; j < P.num_players(); j++)
            shares[j].push_back(inputter.finalize(j));
}

template<int X, int L>
template<class T>
void KeyGenProtocol<X, L>::binomial(vector_type& shares, T& prep)
{
    shares.resize(params.phi_m());
    RunningTimer timer, total;
    for (int i = 0; i < params.phi_m(); i++)
    {
#ifdef VERBOSE
        if (timer.elapsed() > 10)
        {
            cerr << i << "/" << params.phi_m() << ", throughput " <<
                    i / total.elapsed() << endl;
            timer.reset();
        }
#endif

        auto& share = shares[i];
        share = {};
        for (int i = 0; i < params.get_DG().get_NewHopeB(); i++)
        {
            share += prep.get_bit();
            share -= prep.get_bit();
        }
    }
    shares.fft(fftd);
}

template<int X, int L>
template<class T>
void KeyGenProtocol<X, L>::hamming(vector_type& shares, T& prep)
{
    shares.resize(params.phi_m());
    int h = params.get_h();
    assert(h > 0);
    assert(shares.size() / h * h == shares.size());
    int n_bits = log(shares.size() / h) / log(2);
//    assert(size_t(h << n_bits) == shares.size());

    for (auto& share : shares)
        share = prep.get_bit();

    auto& protocol = proc->protocol;
    for (int i = 0; i < n_bits - 1; i++)
    {
        protocol.init_mul(proc);
        for (auto& share : shares)
            protocol.prepare_mul(share, prep.get_bit());
        protocol.exchange();
        for (auto& share : shares)
            share = protocol.finalize_mul();
    }

    protocol.init_mul(proc);
    auto one = share_type::constant(1, P.my_num(), MC->get_alphai());
    for (auto& share : shares)
        protocol.prepare_mul(share, prep.get_bit() * 2 - one);
    protocol.exchange();
    for (auto& share : shares)
        share = protocol.finalize_mul();

    shares.fft(fftd);
}

template<int X, int L>
template<class T>
void KeyGenProtocol<X, L>::secret_key(vector_type& shares, T& prep)
{
    assert(params.get_h() != 0);
    cerr << "Generate secret key by ";
    if (params.get_h() > 0)
    {
        cerr << "Hamming weight"  << endl;
        hamming(shares, prep);
    }
    else
    {
        cerr << "binomial" << endl;
        binomial(shares, prep);
    }
}

template<int X, int L>
typename KeyGenProtocol<X, L>::vector_type KeyGenProtocol<X, L>::schur_product(
        const vector_type& x, const vector_type& y)
{
    vector_type res;
    assert(x.size() == y.size());
    auto& protocol = proc->protocol;
    protocol.init_mul(proc);
    for (size_t i = 0; i < x.size(); i++)
        protocol.prepare_mul(x[i], y[i]);
    protocol.exchange();
    for (size_t i = 0; i < x.size(); i++)
        res.push_back(protocol.finalize_mul());
    return res;
}

template<int X, int L>
void KeyGenProtocol<X, L>::output_to(int player, vector<open_type>& opened,
        vector<share_type>& shares)
{
    PrivateOutput<share_type> po(*proc);
    vector<share_type> masked;
    for (auto& share : shares)
        masked.push_back(po.start(player, share));
    MC->POpen(opened, masked, P);
    for (auto& x : opened)
        x = po.stop(player, x);
}

template<int L>
void LowGearKeyGen<L>::generate_keys(FHE_Params& params)
{
    RunningTimer timer;
    auto& pk = machine.pk;

    GlobalPRNG global_prng(P);
    auto& FFTD = pk.get_params().FFTD()[0];

    for (int i = 0; i < P.num_players(); i++)
    {
        vector_type sk;
        this->secret_key(sk, *this);
        vector<open_type> open_sk;
        this->output_to(i, open_sk, sk);
        if (P.my_num() == i)
            machine.sk.assign(Ring_Element(FFTD, evaluation, open_sk));
        vector_type e0;
        this->binomial(e0, *this);
        AddableVector<open_type> a0(pk.get_params().phi_m());
        a0.randomize(global_prng);
        vector<open_type> b0;
        assert(machine.sk.p() != 0);
        this->MC->POpen(b0, sk * a0 + e0 * machine.sk.p(), P);
        machine.other_pks[i] = FHE_PK(params, machine.sk.p());
        machine.other_pks[i].assign(Ring_Element(FFTD, evaluation, a0),
                Ring_Element(FFTD, evaluation, b0));
    }

    this->MC->Check(P);

    cerr << "Key generation took " << timer.elapsed() << " seconds" << endl;
}

template<int L>
template<class FD>
void LowGearKeyGen<L>::run(PairwiseSetup<FD>& setup)
{
    generate_keys(setup.params);
    machine.sk.check(machine.pk, setup.FieldD);

    RunningTimer timer;

    auto mac_key = SeededPRNG().get<typename FD::T>();

    PairwiseGenerator<FD> generator(0, machine, &P);
    map<string, Timer> timers;
    MultiEncCommit<FD> EC(P, machine.other_pks, setup.FieldD,
            timers, machine, generator, true);
    assert(EC.proof.get_diagonal());
    vector<Plaintext_<FD>> m(EC.proof.U, setup.FieldD);
    for (auto& mm : m)
        mm.assign_constant(mac_key);

    AddableVector<Ciphertext> C;
    octetStream ciphertexts, cleartexts;
    EC.generate_proof(C, m, ciphertexts, cleartexts);

    AddableVector<Ciphertext> others_ciphertexts;
    others_ciphertexts.resize(EC.proof.U, machine.pk.get_params());
    Verifier<FD> verifier(EC.proof, setup.FieldD);
    verifier.NIZKPoK(others_ciphertexts, ciphertexts,
            cleartexts, machine.pk, false);

    machine.enc_alphas.clear();
    for (int i = 0; i < P.num_players(); i++)
        machine.enc_alphas.push_back(machine.other_pks[i]);

    for (int i = 1; i < P.num_players(); i++)
    {
        int player = P.get_player(-i);
#ifdef VERBOSE_HE
        cerr << "Sending proof with " << 1e-9 * ciphertexts.get_length() << "+"
                << 1e-9 * cleartexts.get_length() << " GB" << endl;
#endif
        timers["Sending"].start();
        P.pass_around(ciphertexts);
        P.pass_around(cleartexts);
        timers["Sending"].stop();
#ifdef VERBOSE_HE
        cerr << "Checking proof of player " << i << endl;
#endif
        timers["Verifying"].start();
        verifier.NIZKPoK(others_ciphertexts, ciphertexts,
                cleartexts, machine.other_pks[player], false);
        timers["Verifying"].stop();
        machine.enc_alphas.at(player) = others_ciphertexts.at(0);
    }

    setup.set_alphai(mac_key);
    machine.enc_alphas.at(P.my_num()) = C.at(0);

    auto test = machine.sk.decrypt(C[0], setup.FieldD);
    for (int i = 0; i < setup.FieldD.num_slots(); i++)
        assert(test.element(i) == mac_key);

    cerr << "MAC key generation took " << timer.elapsed() << " seconds" << endl;
}
