#pragma once

#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <emp-zk/emp-zk.h>
#include <emp-tool/emp-tool.h>

#include "source/utils.cpp"

using namespace emp;
using namespace std;


// Reads whitespace-separated integer node names from `filename` and infers the
// depth of the tree they describe.
//
// Node names use the heap encoding (root = 1, children of i are 2i and 2i+1),
// so every node on level k has a name in [2^k, 2^(k+1)); the tree depth is
// therefore floor(log2(largest name)).
//
// filename       -- path to the node-name file
// tree_depth_out -- set to floor(log2(largest node name in the file))
// returns all node names, in the order they appear in the file
// throws std::runtime_error if the file yields no positive node name (e.g. it
//        is missing or empty), since no depth can be inferred in that case.
vector<int> parse_node_file(string filename, size_t & tree_depth_out) {
    cout << "parse_node_file: " << filename << endl;
    ifstream infile(filename);
    vector<int> ret;
    int greatest_node = 0;
    int x = -1;
    while (infile >> x) {
        ret.push_back(x);
        // Track the maximum directly rather than trusting the file to be sorted.
        if (x > greatest_node) {
            greatest_node = x;
        }
    }

    // Without a positive name, indexing/log2 below would be undefined behavior.
    if (greatest_node < 1) {
        throw std::runtime_error("parse_node_file: no positive node names in " + filename);
    }

    // floor(log2(greatest_node)) computed exactly with integer shifts.
    size_t depth = 0;
    for (int v = greatest_node; v > 1; v >>= 1) {
        ++depth;
    }
    tree_depth_out = depth;

    return ret;
}

// Reads two parallel files of whitespace-separated integers and pairs them up:
// the k-th name in `leafnames_file` maps to the k-th value in `leafvalues_file`.
//
// leafnames_file  -- path to the leaf node names
// leafvalues_file -- path to the matching support values
// returns name -> value for every complete (name, value) pair read. Reading
//         stops as soon as either file runs out of integers or fails to open,
//         so a trailing newline or a length mismatch never adds a bogus entry.
//         If a name repeats, the value paired with its last occurrence wins.
unordered_map<int, int> parse_leaf_files(string leafnames_file, string leafvalues_file) {
    cout << "parse_leaf_files: " << leafnames_file << "  " << leafvalues_file << endl;
    ifstream names_in(leafnames_file);
    ifstream vals_in(leafvalues_file);
    unordered_map<int, int> ret;
    int x = 0;
    int val = 0;
    // Test the extractions themselves, not eof(): eof() is only set after a
    // read has already failed (which used to insert one junk pair), and a
    // stream that failed to open never reaches eof at all (infinite loop).
    while ((names_in >> x) && (vals_in >> val)) {
        ret[x] = val;
    }
    return ret;
}


// Derives a rand_sz-bit seed from contributions by both parties.
// ALICE's bits are input first as Bit(..., ALICE). BOB's bits are then sent to
// ALICE in the clear over `io` and embedded by both sides as PUBLIC Bits. The
// seed is the bitwise XOR of the two.
//
// io               -- channel used to ship BOB's bits to ALICE
// party            -- the caller's role (ALICE receives; any other party sends)
// alice_randomness -- rand_sz bits, passed to Bit(..., ALICE) by every party
// bob_randomness   -- rand_sz bits; only read when party != ALICE
// rand_sz          -- number of seed bits
// out_seed         -- the first rand_sz entries receive the seed bits; the
//                     vector is grown to rand_sz if it is shorter
void interactive_seed_generation(BoolIO<NetIO> *io, int party, bool * alice_randomness, bool * bob_randomness, size_t rand_sz, vector<Bit> & out_seed) {
    // ALICE's bits are authenticated before BOB's bits are exchanged.
    // NOTE(review): the `com` naming suggests this acts as ALICE's commitment;
    // keep this ordering ahead of the send/recv below.
    vector<Bit> alice_com_r;
    alice_com_r.reserve(rand_sz);
    for (size_t i = 0; i < rand_sz; ++i) {
        alice_com_r.push_back(Bit(alice_randomness[i], ALICE));
    }

    // Heap buffer instead of a variable-length array: VLAs are not standard
    // C++, and a large rand_sz could overflow the stack.
    std::unique_ptr<bool[]> s_in_bits(new bool[rand_sz]());

    if (party == ALICE) {
        io->recv_bool(s_in_bits.get(), rand_sz);
    } else {
        for (size_t i = 0; i < rand_sz; ++i) {
            s_in_bits[i] = bob_randomness[i];
        }
        io->send_bool(bob_randomness, rand_sz);
        io->flush();
    }

    vector<Bit> bob_com_r;
    bob_com_r.reserve(rand_sz);
    for (size_t i = 0; i < rand_sz; ++i) {
        bob_com_r.push_back(Bit(s_in_bits[i], PUBLIC));
    }

    // Guard against a caller-provided vector that is too short to hold the seed.
    if (out_seed.size() < rand_sz) {
        out_seed.resize(rand_sz);
    }
    for (size_t i = 0; i < rand_sz; ++i) {
        out_seed[i] = alice_com_r[i] ^ bob_com_r[i];
    }
}

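/* TODO: FINISH batched_interactive_seed_generation. Known gaps in the draft below:
 *   - BOB copies and sends only rand_sz bits, while ALICE receives rand_sz*n_rand.
 *   - alice_com_r and bob_com_r are built from only the first rand_sz bits.
 *   - out_seed is undeclared here; results belong in out_seeds (one seed per n_rand).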
void batched_interactive_seed_generation(BoolIO<NetIO> *io, int party, bool * alice_randomness, bool * bob_randomness, size_t rand_sz, size_t n_rand, vector< vector<Bit> > & out_seeds) {
    vector<Bit> alice_com_r;
    for (int i=0; i<rand_sz; ++i) {
        alice_com_r.push_back(Bit(alice_randomness[i],ALICE));
    }
    bool s_in_bits[rand_sz*n_rand];
    
    if (party == ALICE) {
        io->recv_bool(s_in_bits, rand_sz*n_rand);
    } else {
        for (int i=0; i<rand_sz; ++i) {
            s_in_bits[i] = bob_randomness[i];
        }
        io->send_bool(bob_randomness, rand_sz);
        io->flush();
    }
    vector<Bit> bob_com_r;
    for (int i=0; i<rand_sz; ++i) {
        bob_com_r.push_back(Bit(s_in_bits[i], PUBLIC));
    }
    for (int i=0; i<rand_sz; ++i) {
        out_seed[i] = alice_com_r[i] ^ bob_com_r[i];
    }
}
*/



// initialize select_sequence given:
// node_names -- node names in sorted order
// tree_sz -- the number of nodes in the tree
// leaf_vals -- dictionary mapping node names to support values of the distribution
// returns select_sequence, a vector of nodes s.t. if selects are performed on these nodes in order it results in knuth-yao sampling
// Produces the order in which tree_sample must perform its selects.
//
// node_names -- node names in topological (sorted, parent-before-child) order
// tree_sz    -- how many leading entries of node_names to consider
// leaf_vals  -- leaf name -> support value; only its keys are consulted here
// returns every non-leaf name among the first tree_sz entries, listed from the
//         back of node_names to the front. This is the reverse of the
//         topological order, so every internal node comes after its children.
vector<int> init_select_sequence(vector<int> & node_names, size_t tree_sz, unordered_map<int, int> & leaf_vals) {
    vector<int> order;
    // Walking the prefix back-to-front yields the reversed order directly,
    // instead of collecting forward and reversing afterwards.
    for (size_t remaining = tree_sz; remaining > 0; --remaining) {
        const int name = node_names[remaining - 1];
        const bool is_leaf = leaf_vals.find(name) != leaf_vals.end();
        if (!is_leaf) {
            order.push_back(name);
        }
    }
    return order;
}

// returns a vector of authenticated Integers such that leaf values are positioned properly for tree_sample
vector<Integer> init_val(size_t tree_sz, unordered_map<int, int> & leaf_vals, unordered_map<int, int> & m, size_t INT_LEN=32) {
    vector<Integer> val;
    // first initialize all values to publicly authenticated -1
    // add an extra space since we want to be 1-indexing with our node naming scheme
    #ifdef DEBUG
    cout << "init_val: starting initialization\n";
    #endif
    for (int i=0; i<=tree_sz; ++i) {
        val.push_back(Integer(INT_LEN,-1, PUBLIC));
    }
    #ifdef DEBUG
    cout << "init_val: starting loop through leaf_vals\n";
    cout << "leaf_vals: ";
    print_map(leaf_vals);
    size_t i = 0;
    #endif
    // for each
    for (auto& it: leaf_vals) {
        #ifdef DEBUG
        cout << "init_val: leaf_vals iteration -- " << i << "\n";
        i++;
        #endif
        val[m[it.first]] = Integer(INT_LEN, leaf_vals[it.first], PUBLIC);
    }
    
    return val;
}

// returns a mapping of node names to indices in val
// Maps node names to their slot in the value vector.
//
// node_names -- node names in sorted order
// tree_sz    -- how many leading entries of node_names to index
// returns name -> 1-based position among the first tree_sz entries (index 0 is
//         left unused). If a name repeats, its last position wins.
unordered_map<int, int> init_m(vector<int> & node_names, size_t tree_sz) {
    unordered_map<int, int> slot_of;
    int slot = 0;
    for (size_t k = 0; k < tree_sz; ++k) {
        ++slot;  // entry k gets slot k + 1
        slot_of[node_names[k]] = slot;
    }
    return slot_of;
}

// given a knuth-yao tree in the form: 
// node_names -- integer node 'names' in sorted order (such that the root is named 1, and \forall i, left child is 2i, right child is 2i+1)
// leaf_vals -- maps node names of leaves to the support values they hold in the knuth-yao tree (expressed as integers/fixed point numbers)
// and references to empty containers to hold outputs
// initialize:
// val -- an array the size of the tree
// select_sequence -- a valid sequence of nodes in the tree to select over in order to execute sampling
// m -- maps node names to indices of val
// Prepares everything tree_sample needs from a Knuth-Yao tree description.
//
// node_names          -- node names in sorted order (root 1; children of i are 2i, 2i+1)
// leaf_vals           -- leaf name -> support value
// val_out             -- receives the value vector (one slot per node, plus slot 0)
// select_sequence_out -- receives the non-leaf nodes in the order to select over
// m_out               -- receives node name -> index into val_out
// INT_LEN             -- bit width of the Integers placed in val_out
void ts_init(vector<int> & node_names, unordered_map<int, int> & leaf_vals, vector<Integer> & val_out, vector<int> & select_sequence_out, unordered_map<int, int> & m_out, size_t INT_LEN=32) {
    const size_t tree_sz = node_names.size();
    // Debug banner now precedes the work it announces, and the size line is
    // newline-terminated so it no longer runs into the next message.
    #ifdef DEBUG
    cout << "ts_init start\n";
    cout << "node_names.size(): " << tree_sz << "\n";
    #endif

    select_sequence_out = init_select_sequence(node_names, tree_sz, leaf_vals);
    #ifdef DEBUG
    cout << "init_select_sequence finished\n";
    #endif

    // m_out must be built before init_val, which uses it to place leaf values.
    m_out = init_m(node_names, tree_sz);
    #ifdef DEBUG
    cout << "init_m finished\n";
    #endif

    val_out = init_val(tree_sz, leaf_vals, m_out, INT_LEN);
    #ifdef DEBUG
    cout << "init_val finished\n";
    #endif
}

// initialized by ts_init()
// val -- initialized vector of values in the tree
// r -- random bit string for sampling
// select_sequence -- in order set of selects required for the sampling process
// num_selects -- size of select_sequence
// m -- maps node names to indices in val
// Runs the select chain prepared by ts_init and returns the sampled value.
//
// val             -- value vector from ts_init; internal-node slots are overwritten
// r               -- random bits; r[k] is shared by every node on layer k, so it
//                    needs at least (deepest internal layer + 1) entries
// select_sequence -- nodes to select over, children before parents
// num_selects     -- number of leading entries of select_sequence to process
//                    (must not exceed select_sequence.size())
// m               -- node name -> index into val
// returns the value left in the root's slot, val[m[1]].
//
// For node n on layer floor(log2(n)), its slot becomes
// val[m[2n]].select(r[layer], val[m[2n+1]]).
// NOTE(review): emp's select appears to yield the right child when the bit is
// 1 -- confirm against emp-tool. A child name missing from m is default-inserted
// by operator[] as index 0. That is the extra slot that init_val fills with -1.
Integer tree_sample(vector<Integer> & val, vector<Bit> & r, vector<int> & select_sequence, size_t num_selects, unordered_map<int, int> & m) {
    for (size_t i = 0; i < num_selects; ++i) {
        const int n = select_sequence[i]; // node name of each select

        // layer = floor(log2(n)), computed exactly with integer shifts
        size_t layer = 0;
        for (int v = n; v > 1; v >>= 1) {
            ++layer;
        }

        // Bind by const reference: copying an Integer copies all of its Bits.
        // References into val stay valid because val is never resized here.
        const Integer & left_child = val[m[n * 2]];
        const Integer & right_child = val[m[n * 2 + 1]];
        val[m[n]] = left_child.select(r[layer], right_child);
    }

    return val[m[1]];
}