#ifndef EMP_DPF_FLORAM_H__
#define EMP_DPF_FLORAM_H__

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <utility>

#include "2pc-backend/ccrh.h"
#include "2pc-backend/utils.h"
#include "2pc-backend/prg.h"
#include "emp-tool/emp-tool.h"
#include "emp-ot/emp-ot.h"
#include <emp-tool/utils/block.h>

namespace emp
{

    // Wire representation shared by the backends in this header: one wire is
    // exactly one `block` label.
    class GbWire
    {
    public:
        // The label currently carried by this wire.
        block L;

        // Builds a wire from `label`; with no argument the wire holds
        // zero_block. Left non-explicit on purpose so a block still converts
        // to a GbWire implicitly, as before.
        GbWire(const block &label = zero_block) : L(label) {}
    };

    // Backend for the side that holds `delta`.
    //
    // Construction draws two random blocks, sends them to the peer through
    // `io` as the labels of the two public constants, then XORs `delta` into
    // the local copy of constant[1] and hands constant[0] to mitccrh.setS().
    // Every AND gate sends a two-block table to the peer (see HalfGateGen);
    // XOR and NOT gates are local label XORs.
    //
    // `io` and `ot` are borrowed: this class never deletes them.
    template <typename T>
    class HalfGen : public Backend
    {
    public:
        T *io;
        COT<T> *ot;
        MITCCRH<8> mitccrh;

        block constant[2];
        block delta;
        uint64_t ands = 0; // number of and_gate() calls so far

        HalfGen(int party, T *io, COT<T> *ot, block delta) : Backend(party), io(io), ot(ot), delta(delta)
        {
            PRG().random_block(constant, 2);
            // The peer receives both constants before delta is mixed in.
            this->io->send_block(constant, 2);
            constant[1] = constant[1] ^ delta;
            mitccrh.setS(constant[0]);
        }

        // Computes the output label for one AND gate and sends the gate's
        // two-block table over `io`.
        //
        // LA0, LB0: this side's labels for the two input wires.
        // LW0:      receives this side's label for the output wire.
        inline void HalfGateGen(const block &LA0, const block &LB0, block *LW0)
        {
            block table[2];
            // The least significant bit of each input label drives the
            // select_mask lookups below.
            bool pa = getLSB(LA0);
            bool pb = getLSB(LB0);
            block HLA0, HA1, HLB0, HB1, W0;
            block tmp;
            block H[4];
            // Hash each input label and its delta-offset counterpart in one
            // batched call.
            H[0] = LA0;
            H[1] = LA0 ^ delta;
            H[2] = LB0;
            H[3] = LB0 ^ delta;
            mitccrh.hash<2, 2>(H);
            HLA0 = H[0];
            HA1 = H[1];
            HLB0 = H[2];
            HB1 = H[3];
            table[0] = HLA0 ^ HA1;
            table[0] = table[0] ^ (select_mask[pb] & delta);
            W0 = HLA0;
            W0 = W0 ^ (select_mask[pa] & table[0]);
            tmp = HLB0 ^ HB1;
            table[1] = tmp ^ LA0;
            W0 = W0 ^ HLB0;
            W0 = W0 ^ (select_mask[pb] & tmp);
            *LW0 = W0;
            io->send_block(table, 2);
        }

        // out = left XOR right, computed locally on the labels.
        void xor_gate(void *out, const void *left, const void *right) override
        {
            ((GbWire *)out)->L = ((GbWire *)left)->L ^ ((GbWire *)right)->L;
        }

        // out = in XOR constant[1] (constant[1] already includes delta here).
        void not_gate(void *out, const void *in) override
        {
            ((GbWire *)out)->L = ((GbWire *)in)->L ^ constant[1];
        }

        // Counts the gate and delegates to HalfGateGen.
        void and_gate(void *out, const void *left, const void *right) override
        {
            ands++;
            HalfGateGen(((GbWire *)left)->L, ((GbWire *)right)->L, &(((GbWire *)out)->L));
        }

        // Fills `nel` wires at `lbls` (an array of GbWire) with input labels.
        //
        // party == PUBLIC: wire i gets constant[b[i]].
        // Otherwise, when this backend runs as ALICE: `nel` blocks are drawn
        //   from ot->send_cot(); if the input belongs to this party, delta is
        //   XORed into every label whose bit b[i] is set. `b` is not read when
        //   the input belongs to the peer.
        // Otherwise (backend is not ALICE): nothing is written.
        void feed(void *lbls, int party, const bool *b, size_t nel) override
        {
            GbWire *out = (GbWire *)lbls;
            if (party == PUBLIC)
            {
                for (size_t i = 0; i < nel; ++i)
                {
                    out[i].L = constant[b[i]];
                }
                return;
            }
            if (this->party != ALICE)
            {
                return;
            }

            // Scratch buffer for the correlated-OT output; released below so
            // that repeated feeds do not leak.
            block *data = new block[nel];
            ot->send_cot(data, nel);
            const bool own_input = (this->party == party);
            for (size_t i = 0; i < nel; i++)
            {
                out[i].L = data[i] ^ ((own_input && b[i]) ? delta : zero_block);
            }
            delete[] data;
        }

        // Opens `nel` wires: each side uses the least significant bit of its
        // label as its share, the shares are exchanged over `io` (ALICE sends
        // first, the other side receives first), and out[i] is the XOR of the
        // two shares. The `party` argument is not consulted; both sides end
        // up with the opened bits.
        void reveal(bool *out, int party, const void *lbls, size_t nel) override
        {
            const GbWire *in = (const GbWire *)lbls;
            // Heap buffers instead of variable-length arrays: VLAs are a
            // compiler extension in C++ and put `nel` bytes on the stack.
            std::unique_ptr<bool[]> shr(new bool[nel]);
            std::unique_ptr<bool[]> ret(new bool[nel]);
            for (size_t i = 0; i < nel; ++i)
            {
                shr[i] = getLSB(in[i].L);
            }
            if (this->party == ALICE)
            {
                io->send_bool(shr.get(), nel);
                io->recv_bool(ret.get(), nel);
            }
            else
            {
                io->recv_bool(ret.get(), nel);
                io->send_bool(shr.get(), nel);
            }
            for (size_t i = 0; i < nel; ++i)
            {
                out[i] = shr[i] ^ ret[i];
            }
        }

        // Number of AND gates processed by this backend.
        uint64_t num_and() override
        {
            return ands;
        }
    };

    template <typename T>
    class HalfEva : public Backend
    {
    public:
        T *io;
        COT<T> *ot;
        MITCCRH<8> mitccrh;

        block constant[2];
        uint64_t ands = 0;

        HalfEva(int party, T *io, COT<T> *ot) : Backend(party), io(io), ot(ot)
        {
            this->io->recv_block(constant, 2);
            mitccrh.setS(constant[0]);
        }

        inline void HalfGateEva(const block &A, const block &B, block *W)
        {
            block table[2];
            io->recv_block(table, 2);
            block HA, HB;
            int sa, sb;
            sa = getLSB(A);
            sb = getLSB(B);
            block H[2];
            H[0] = A;
            H[1] = B;
            mitccrh.hash<2, 1>(H);
            HA = H[0];
            HB = H[1];
            *W = HA ^ HB;
            *W = *W ^ (select_mask[sa] & table[0]);
            *W = *W ^ (select_mask[sb] & table[1]);
            *W = *W ^ (select_mask[sb] & A);
        }

        void xor_gate(void *out, const void *left, const void *right) override
        {
            ((GbWire *)out)->L = ((GbWire *)left)->L ^ ((GbWire *)right)->L;
        }

        void not_gate(void *out, const void *in) override
        {
            ((GbWire *)out)->L = ((GbWire *)in)->L ^ constant[1];
        }

        void and_gate(void *out, const void *left, const void *right) override
        {
            ands++;
            HalfGateEva(((GbWire *)left)->L, ((GbWire *)right)->L, &(((GbWire *)out)->L));
        }

        void feed(void *lbls, int party, const bool *b, size_t nel) override
        {
            if (party == PUBLIC)
            {
                for (size_t i = 0; i < nel; ++i)
                {
                    ((GbWire *)lbls)[i].L = constant[b[i]];
                }
            }
            else
            {
                GbWire *out = (GbWire *)lbls;
                if (party == PUBLIC)
                {
                    for (size_t i = 0; i < nel; ++i)
                    {
                        ((GbWire *)lbls)[i].L = constant[b[i]];
                    }
                }
                else
                {
                    block *data = new block[nel];
                    bool *bb = new bool[nel];
                    for (size_t i = 0; i < nel; i++)
                        bb[i] = false;
                    if (this->party == party)
                    {
                        ot->recv_cot(data, b, nel);
                    }
                    else
                    {
                        ot->recv_cot(data, bb, nel);
                    }
                    for (size_t i = 0; i < nel; i++)
                    {
                        out[i].L = data[i];
                    }
                    delete[] bb;
                }
            }
        }

        void reveal(bool *out, int party, const void *lbls, size_t nel) override
        {
            const GbWire *in = (const GbWire *)lbls;
            bool shr[nel], ret[nel];
            for (size_t i = 0; i < nel; ++i)
            {
                shr[i] = getLSB(in[i].L);
            }
            if (this->party == ALICE)
            {
                io->send_bool(shr, nel);
                io->recv_bool(ret, nel);
            }
            else
            {
                io->recv_bool(ret, nel);
                io->send_bool(shr, nel);
            }
            for (size_t i = 0; i < nel; ++i)
            {
                out[i] = shr[i] ^ ret[i];
            }
        }

        uint64_t num_and() override
        {
            return ands;
        }
    };


    // Top-level backend that owns the OT instances, the per-party gate
    // backend (HalfGen on ALICE, HalfEva otherwise) and the DPF scratch
    // buffers.
    //
    // Gate calls made directly on this object throw; circuits are meant to be
    // run through execute() (or between switch_to_gt() / switch_back()),
    // which points emp::backend at the gate backend.
    //
    // NOTE(review): the class owns raw pointers and releases them in its
    // destructor but does not disable copying; copying an instance would
    // double-free. Treat instances as non-copyable.
    template <typename T>
    class MPCBackend : public Backend
    {
    public:
        using Bit = Bit_T<GbWire>;
        using Integer = Integer_T<GbWire>;

        // party:    role of this process (compared against ALICE below).
        // threads:  number of IO channels in `ios`; also sizes the per-thread
        //           prp / ccrh arrays. The pool is created with threads - 1.
        // ios:      array of IO channels, borrowed (never deleted here).
        // delta:    passed to the sending OT setup and, on ALICE, to HalfGen.
        // max_size: tree / next get 2^(max_size + 1) blocks and tree_bit gets
        //           2^max_size flags.
        MPCBackend(int party, int threads, T **ios, block delta, int max_size = 24) : Backend(party), ios(ios), threads(threads), delta(delta)
        {
            this->og_thread = threads;
            this->max_size = max_size;
            this->pool = new ThreadPool(threads - 1);
            this->ot_send = new FerretCOT<T>(ALICE, 1, ios, false, false, ferret_b11);
            this->ot_recv = new FerretCOT<T>(BOB, 1, ios, false, false, ferret_b11);
            // The two sides run the setups in opposite order so that a sending
            // setup on one side always faces a receiving setup on the other.
            if (this->party == ALICE)
            {
                this->ot_send->setup(delta, "./data/pre_ot");
                this->ot_recv->setup("./data/pre_ot");
                this->gate_gen = new HalfGen<T>(party, ios[0], ot_send, delta);
            }
            else
            {
                this->ot_recv->setup("./data/pre_ot");
                this->ot_send->setup(delta, "./data/pre_ot");
                this->gate_eva = new HalfEva<T>(party, ios[0], ot_recv);
            }

            // Initialize DPF
            prp = new DPFPRG *[threads];
            ccrh = new DualCCRH[threads];
            for (int i = 0; i < threads; i++)
            {
                ccrh[i] = DualCCRH(zero_block);
                prp[i] = new DPFPRG(zero_block, makeBlock(0, 1));
            }
            // Sizes are computed in size_t so that the shift cannot overflow
            // int for large max_size.
            const size_t leaves = static_cast<size_t>(1) << max_size;
            tree = new block[2 * leaves];
            next = new block[2 * leaves];
            tree_bit = new bool[leaves](); // value-initialised: all false
        }

        // Points emp::backend at this party's gate backend.
        void switch_to_gt()
        {
            emp::backend = (this->party == ALICE) ? this->gate_gen : this->gate_eva;
        }

        // Points emp::backend back at this object.
        void switch_back()
        {
            emp::backend = this;
        }

        // Gates are never evaluated by this object itself.
        void xor_gate(void *out, const void *left, const void *right) override
        {
            throw std::runtime_error("gates execution should be done by exec function");
        }

        // Gates are never evaluated by this object itself.
        void not_gate(void *out, const void *in) override
        {
            throw std::runtime_error("gates execution should be done by exec function");
        }

        // Gates are never evaluated by this object itself.
        void and_gate(void *out, const void *left, const void *right) override
        {
            throw std::runtime_error("gates execution should be done by exec function");
        }

        // Fills `nel` wires at `lbls` (an array of GbWire) with input labels.
        //
        // party == PUBLIC: forwarded to this party's gate backend.
        // On ALICE: `nel` blocks are drawn from ot_send; if the input belongs
        //   to this party, delta is XORed into every label whose bit b[i] is
        //   set.
        // Otherwise: the labels come from ot_recv, with `b` as choice bits
        //   when the input belongs to this party and an all-false choice
        //   vector when it belongs to the peer.
        // `b` is only read for PUBLIC inputs and for this party's own inputs.
        void feed(void *lbls, int party, const bool *b, size_t nel) override
        {
            if (party == PUBLIC)
            {
                Backend *gt = (this->party == ALICE) ? this->gate_gen : this->gate_eva;
                gt->feed(lbls, party, b, nel);
                return;
            }

            GbWire *out = (GbWire *)lbls;
            const bool own_input = (this->party == party);
            block *data = new block[nel];
            if (this->party == ALICE)
            {
                ot_send->send_cot(data, nel);
                for (size_t i = 0; i < nel; i++)
                {
                    out[i].L = data[i] ^ ((own_input && b[i]) ? delta : zero_block);
                }
            }
            else
            {
                if (own_input)
                {
                    ot_recv->recv_cot(data, b, nel);
                }
                else
                {
                    // Value-initialised, i.e. every choice bit is false.
                    std::unique_ptr<bool[]> no_choice(new bool[nel]());
                    ot_recv->recv_cot(data, no_choice.get(), nel);
                }
                for (size_t i = 0; i < nel; i++)
                {
                    out[i].L = data[i];
                }
            }
            delete[] data;
        }

        // Opens `nel` wires: each side uses the least significant bit of its
        // label as its share, the shares are exchanged over ios[0] (ALICE
        // sends first, the other side receives first), and out[i] is the XOR
        // of the two shares. The `party` argument is not consulted; both
        // sides end up with the opened bits.
        void reveal(bool *out, int party, const void *lbls, size_t nel) override
        {
            const GbWire *in = (const GbWire *)lbls;
            // Heap buffers instead of variable-length arrays: VLAs are a
            // compiler extension in C++ and put `nel` bytes on the stack.
            std::unique_ptr<bool[]> shr(new bool[nel]);
            std::unique_ptr<bool[]> ret(new bool[nel]);
            for (size_t i = 0; i < nel; ++i)
            {
                shr[i] = getLSB(in[i].L);
            }
            if (this->party == ALICE)
            {
                ios[0]->send_bool(shr.get(), nel);
                ios[0]->recv_bool(ret.get(), nel);
            }
            else
            {
                ios[0]->recv_bool(ret.get(), nel);
                ios[0]->send_bool(shr.get(), nel);
            }
            for (size_t i = 0; i < nel; ++i)
            {
                out[i] = shr[i] ^ ret[i];
            }
        }

        // AND-gate count reported by this party's gate backend.
        uint64_t num_and() override
        {
            Backend *gt = (this->party == ALICE) ? this->gate_gen : this->gate_eva;
            return gt->num_and();
        }

        // Points emp::backend at this party's gate backend and then calls
        // func(args...). The global backend is NOT restored afterwards; call
        // switch_back() if this object should become the backend again.
        template <typename Func, typename... Args>
        void execute(Func &&func, Args &&...args)
        {
            emp::backend = (this->party == ALICE) ? this->gate_gen : this->gate_eva;
            func(std::forward<Args>(args)...);
        }

        // Releases everything allocated by the constructor and removes the
        // "./data/pre_ot" file. Only one of gate_gen / gate_eva is ever
        // created; the other stays nullptr, which is safe to delete.
        ~MPCBackend()
        {
            delete this->ot_send;
            delete this->ot_recv;
            delete this->gate_gen;
            delete this->gate_eva;
            delete[] ccrh;
            for (int i = 0; i < this->threads; i++)
                delete prp[i];
            delete[] prp;
            delete[] this->tree;
            delete[] this->next;
            delete[] this->tree_bit;
            delete this->pool;

            std::remove("./data/pre_ot");
        }

        // private:
        // Execution & OT
        T **ios;      // borrowed IO channels, `threads` entries
        int threads;
        int og_thread = 0; // thread count passed to the constructor
        ThreadPool *pool = nullptr;
        block delta;
        // Exactly one of these is created by the constructor (gate_gen on
        // ALICE, gate_eva otherwise); the other must stay nullptr because
        // the destructor deletes both.
        Backend *gate_gen = nullptr;
        Backend *gate_eva = nullptr;
        FerretCOT<T> *ot_send = nullptr;
        FerretCOT<T> *ot_recv = nullptr;
        // DPF
        DPFPRG **prp = nullptr; // `threads` owned DPFPRG instances
        DualCCRH *ccrh = nullptr; // `threads` entries
        block *tree = nullptr;  // 2^(max_size + 1) blocks
        block *next = nullptr;  // 2^(max_size + 1) blocks
        bool *tree_bit = nullptr; // 2^max_size flags, initially all false
        int DPF_counter = 0;
        int max_size = 0;

        // Flushes the IO channel with index `seq`.
        inline void io_flush(int seq)
        {
            ios[seq]->flush();
        }

        // Flushes every IO channel.
        inline void all_flush()
        {
            for (int i = 0; i < this->threads; i++)
                ios[i]->flush();
        }

    };


}

#endif