#pragma once

#include <emp-zk/emp-zk.h>
#include <iostream>
#include <emp-tool/emp-tool.h>
#include "utils.cpp"
#include "sample.cpp"

using namespace emp;
using namespace std;

// Logistic sigmoid evaluated inside the circuit: 1 / (1 + e^{-z}).
// The constant 1 is created as a PUBLIC Float.
Float _sigmoid(Float z) {
    Float one(1, PUBLIC);
    Float denom = one + (-z).exp();
    return one / denom;
}

// Dot product <weights, input>, accumulated from a PUBLIC zero in weights order
// (acc = acc + weights[idx] * input[idx]).
// NOTE(review): the loop runs over weights.size(); `input` is assumed to have at least
// that many elements -- this is not checked here.
Float _linear(vector<Float> & input, vector<Float> & weights) {
    Float acc(0, PUBLIC);
    for (size_t idx = 0; idx < weights.size(); ++idx) {
        acc = acc + weights[idx] * input[idx];
    }
    return acc;
}

// Binary logistic-regression prediction: sigmoid(<weights, input>).
// The result lies in [0, 1].
Float binary_LR_inference(vector<Float> & input, vector<Float> & weights) {
    return _sigmoid(_linear(input, weights));
}

// Logistic-regression logit, i.e. the pre-sigmoid score <weights, input>.
Float LR_logit(vector<Float> & input, vector<Float> & weights) {
    return _linear(input, weights);
}

// Placeholder for authenticating the weights produced by local training: it ignores any
// real update and replaces `out_weights` with Float_ones(num_params) (all-ones weights).
// Per the original author, the runtime of the ZKP parts of the verification does not vary
// with the parameter values, so benchmarks using this placeholder remain accurate.
void auth_updated_weights(vector<Float> & out_weights, size_t num_params) {
    vector<Float> ones = Float_ones(num_params);
    out_weights.swap(ones);
}


// Per-example gradient of the binary logistic loss:
//   \nabla f(w, (x, y)) = (y~ - y) x,
// where y~ = binary_LR_inference(x, w) is the predicted label.
// The label is assumed to lie in [0, 1]. That can be checked elsewhere in the program,
// but this function does not check it.
// The result is written into the first `sz_outgrad` entries of `out_grad`. Both `out_grad`
// and `x` must already hold at least that many elements.
void single_data_grad(vector<Float> & x, vector<Float> & weights, Float & label, vector<Float> & out_grad, size_t sz_outgrad) {
    Float residual = binary_LR_inference(x, weights) - label;
    for (size_t j = 0; j < sz_outgrad; ++j) {
        out_grad[j] = x[j] * residual;
    }
}

// Element-wise regularization gradient: out_reg[j] = CONST_REG_MULT * (w~_i[j] - w_{i-1}[j]).
// CONST_REG_MULT = 1 / (\nu_i n_i) is computed in the clear by the caller and passed in as a
// constant, so the division never happens inside the ZKP.
// Writes the first `out_sz` entries of `out_reg`, which must already be at least that large.
void regularization_term(vector<Float> & w_tilde_i, vector<Float> & w_i_minus_one, Float CONST_REG_MULT, vector<Float> & out_reg, size_t out_sz) {
    for (size_t j = 0; j < out_sz; ++j) {
        Float diff = w_tilde_i[j] - w_i_minus_one[j];
        out_reg[j] = CONST_REG_MULT * diff;
    }
}

// Computes the regularized Phased-ERM gradient
//   \nabla F_i(w~_i, D_i) = 1/n_i \sum_{x \in D_i} \nabla f(w~_i, x) + 1/(\nu_i n_i) (w~_i - w_{i-1})
// over the batch D_i = { (D[Di_indices[b]], Y[Di_indices[b]]) : b < batch_sz }.
// The result goes into the first `sz_outgrad` entries of `out_grad`, which must already hold
// that many elements. `out_grad` is reset to zero first, so callers may reuse it across rounds.
//
// The intended check is ||\nabla F_i(w~_i, D_i)||_2 <= 2/n_i L. It is evaluated in squared form
// (||.||_2^2 <= (2/n_i L)^2) so no square root is needed. The final comparison is currently
// commented out, so only the squared norm is computed.
//
// Constants computed in the clear and passed in, so no division happens inside the ZKP:
//   CONST_REG_MULT:          1 / (\nu_i n_i)
//   CONST_GRADSUM_MULT:      1 / n_i
//   CONST_GRADBOUND_SQUARED: (2/n_i L)^2
void phased_ERM_grad_check(vector< vector<Float> > & D, vector<Float> & Y, vector<int> & Di_indices, size_t batch_sz, vector<Float> & w_tilde_i, vector<Float> & w_i_minus_one, Float CONST_REG_MULT, Float CONST_GRADSUM_MULT, Float CONST_GRADBOUND_SQUARED, vector<Float> & out_grad, size_t sz_outgrad) {
    // Scratch buffers for the per-example gradient and the regularization term.
    vector<Float> temp_grad(sz_outgrad, Float(0.0, PUBLIC));
    vector<Float> temp_reg(sz_outgrad, Float(0.0, PUBLIC));

    // Start the sum from zero. Callers reuse `out_grad` across Phased-ERM rounds, and without
    // this reset each round's gradient would also contain every previous round's gradient.
    for (size_t j = 0; j < sz_outgrad; ++j) {
        out_grad[j] = Float(0.0, PUBLIC);
    }

    // summation of \nabla f (w~_i, x) for each x in D_i
    for (size_t idx = 0; idx < batch_sz; ++idx) {
        if (idx == 0 || idx == 5 || idx == 50 || idx == 250 || idx == 1000 || idx % 2500 == 0 || idx == batch_sz-1) {
            cout << "\t phased_ERM_grad_check -- i: " << idx << "/" << batch_sz << endl;
        }
        const int Di_ind = Di_indices[idx];
        // Bind by reference: avoids copying the committed datapoint and label on every example.
        vector<Float> & x = D[Di_ind]; // datapoint x from committed dataset D
        Float & y = Y[Di_ind];         // its label y
        single_data_grad(x, w_tilde_i, y, temp_grad, sz_outgrad);
        for (size_t j = 0; j < sz_outgrad; ++j) {
            out_grad[j] = out_grad[j] + temp_grad[j]; // summing up the outputs
        }
    }

    // element-wise multiply by 1/n_i
    for (size_t j = 0; j < sz_outgrad; ++j) {
        out_grad[j] = out_grad[j] * CONST_GRADSUM_MULT;
    }

    // add regularization term
    regularization_term(w_tilde_i, w_i_minus_one, CONST_REG_MULT, temp_reg, sz_outgrad);
    for (size_t j = 0; j < sz_outgrad; ++j) {
        out_grad[j] = out_grad[j] + temp_reg[j];
    }

    // Squared L2 norm || \nabla F_i(w~_i, w_i-1, D_i) ||_2 ^2, to be compared against
    // (2/n_i L)^2 so that no square root is needed.
    Float sum(0.0, PUBLIC);
    for (size_t j = 0; j < sz_outgrad; ++j) {
        sum = sum + (out_grad[j] * out_grad[j]);
    }
    cout << "\t phased_ERM_grad_check -- finished" << endl;

    // NOTE(review): the comparison below is disabled, so `sum` is computed but not checked,
    // and CONST_GRADBOUND_SQUARED is currently unused.
    //Bit b = (sum.less_equal(CONST_GRADBOUND_SQUARED));
    //cout << "Gradient Check: " << b.reveal() << endl;
}


// Phased ERM with a single noise distribution.
// Runs k = ceil(log2 n) rounds, where n = D.size(). Each round:
//   1. takes the next disjoint batch of ceil(remaining / 2) examples;
//   2. runs phased_ERM_grad_check on it;
//   3. samples one noise value per parameter with the tree sampler;
//   4. sets w_{i-1} = w~_i + noise via float_to_fp / fp_to_float (presumably a fixed-point
//      encoding -- see utils).
// The per-parameter seed comes from interactive_seed_generation over `rand_io` between the
// two parties.
// NOTE(review): `w_init` is currently unused. Weights start at zero and are then replaced by
// the auth_updated_weights placeholder.
void phased_ERM(vector< vector<Float> > & D, vector<Float> & Y, double L, vector<Float> & w_init, double step_size, size_t num_params, size_t rand_bits_sz, vector<int> tree_node_names, unordered_map<int, int> leaf_vals, BoolIO<NetIO> *rand_io, int party) {
    const size_t n = D.size();
    // ceil(log2(0)) is -inf, and converting it to size_t is undefined: run no rounds if D is empty.
    const size_t k = (n == 0) ? 0 : static_cast<size_t>(ceil(log2(n * 1.0)));
    vector<Float> w_tilde_i;
    vector<Float> w_i_minus_one;
    vector<Float> out_grad;
    vector<Integer> fp_noise_vec;
    vector<Bit> r(rand_bits_sz, Bit(0, PUBLIC));

    // NOTE(review): variable-length arrays are a GCC/Clang extension, not standard C++.
    bool alice_local_rand[rand_bits_sz];
    bool bob_local_rand[rand_bits_sz];

    // initialize params to 0
    for (size_t j = 0; j < num_params; ++j) {
        w_tilde_i.push_back(Float(0.0, PUBLIC));
        w_i_minus_one.push_back(Float(0.0, PUBLIC));
        out_grad.push_back(Float(0.0, PUBLIC));
        fp_noise_vec.push_back(Integer(32, 0, PUBLIC));
    }

    // initialize tree sampling
    vector<Integer> tree_val;
    vector<int> tree_select_seq;
    unordered_map<int, int> tree_m;
    ts_init(tree_node_names, leaf_vals, tree_val, tree_select_seq, tree_m);

    // prev_ni = index of the first example not yet used. It advances by each round's batch
    // size, so the batches are disjoint and shrink by roughly half each round.
    size_t prev_ni = 0;
    for (size_t i = 0; i < k; ++i) {
        // ceil((n - prev_ni) / 2). Same value as the former (n-(prev_ni-1))/2, but without
        // relying on unsigned wraparound when prev_ni == 0.
        const size_t n_i = (n + 1 - prev_ni) / 2;
        const double nu_i = step_size / pow(4.0, static_cast<double>(i));
        vector<int> Di_indices = indices_range(prev_ni, prev_ni+n_i);
        const double reg_mult = 1.0 / (n_i * nu_i);
        // Squared gradient-norm bound (2/n_i L)^2. It must be squared explicitly:
        // exp2(x) would compute 2^x instead.
        const double grad_bound = 2.0 / n_i * L;
        const double gradbound_squared = grad_bound * grad_bound;
        Float Float_recip_ni(1.0/n_i, PUBLIC);
        Float Float_rm(reg_mult, PUBLIC);
        Float Float_gbs(gradbound_squared, PUBLIC);

        phased_ERM_grad_check(D, Y, Di_indices, Di_indices.size(), w_tilde_i, w_i_minus_one, Float_rm, Float_recip_ni, Float_gbs, out_grad, num_params);

        // one noise sample per parameter, each from a freshly generated shared seed
        for (size_t j = 0; j < num_params; ++j) {
            random_bits(alice_local_rand, rand_bits_sz);
            random_bits(bob_local_rand, rand_bits_sz);
            interactive_seed_generation(rand_io, party, alice_local_rand, bob_local_rand, rand_bits_sz, r);
            fp_noise_vec[j] = tree_sample(tree_val, r, tree_select_seq, tree_select_seq.size(), tree_m);
        }

        // Advance past this round's batch (cumulative, so the next round uses fresh examples).
        prev_ni += n_i;
        for (size_t j = 0; j < num_params; ++j) {
            w_i_minus_one[j] = fp_to_float(float_to_fp(w_tilde_i[j]) + fp_noise_vec[j]);
        }
        auth_updated_weights(w_tilde_i, num_params); // placeholder for local training
    }
}

// Phased ERM in which each round i (1-based) samples its noise from its own distribution.
// `n` is the dataset size (number of training examples). It is independent of how many
// training examples are committed at a time, and it sets the number of rounds
// k = ceil(log2 n).
// `rand_bits_szes`, `tree_node_nameses` and `leaf_valses` are indexed by the round number
// i in [1, k], so index 0 is unused and each must have more than k entries. Otherwise an
// error is printed and the function returns when it reaches the first round without inputs.
// NOTE(review): `w_init` is currently unused. Weights start at zero and are then replaced by
// the auth_updated_weights placeholder.
void phased_ERM_multinoise(vector< vector<Float> > & D, vector<Float> & Y, double L, vector<Float> & w_init, double step_size, size_t num_params, vector<size_t> rand_bits_szes, vector< vector<int> > tree_node_nameses, vector< unordered_map<int, int> > leaf_valses, size_t n, BoolIO<NetIO> *rand_io, int party) {
    // ceil(log2(0)) is -inf, and converting it to size_t is undefined: run no rounds if n == 0.
    const size_t k = (n == 0) ? 0 : static_cast<size_t>(ceil(log2(n * 1.0)));
    vector<Float> w_tilde_i;
    vector<Float> w_i_minus_one;
    vector<Float> out_grad;
    vector<Integer> fp_noise_vec;

    // initialize params to 0
    for (size_t j = 0; j < num_params; ++j) {
        w_tilde_i.push_back(Float(0.0, PUBLIC));
        w_i_minus_one.push_back(Float(0.0, PUBLIC));
        out_grad.push_back(Float(0.0, PUBLIC));
        fp_noise_vec.push_back(Integer(32, 0, PUBLIC));
    }

    // tree-sampling state, re-initialized by ts_init for each round's noise distribution
    vector<Integer> tree_val;
    vector<int> tree_select_seq;
    unordered_map<int, int> tree_m;

    // prev_ni starts at 1 and advances by each round's batch size, so n + 1 - prev_ni is
    // the number of examples not used by earlier rounds.
    size_t prev_ni = 1;
    for (size_t i = 1; i <= k; ++i) {
        if (i >= tree_node_nameses.size() || i >= leaf_valses.size() || i >= rand_bits_szes.size()) {
            cout << "phased_ERM_multinoise -- batch: " << i << " ERROR: mismatch in phased ERM rounds and noise inputs." << endl;
            return;
        }
        ts_init(tree_node_nameses[i], leaf_valses[i], tree_val, tree_select_seq, tree_m);
        // half of the remaining examples, rounded down (same value as (n-(prev_ni-1))/2)
        const size_t n_i = (n + 1 - prev_ni) / 2;
        const double nu_i = step_size / pow(4.0, static_cast<double>(i));
        vector<int> Di_indices = pseudo_indices_range(prev_ni, prev_ni+n_i);
        #ifdef DEBUG
        cout << "phased_ERM_multinoise -- batch " << i << ":  Di_indices.size() " << Di_indices.size() << endl;
        cout << "prev_ni: " << prev_ni << "  n_i: " << n_i << endl;
        #endif
        const double reg_mult = 1.0 / (n_i * nu_i);
        // Squared gradient-norm bound (2/n_i L)^2. It must be squared explicitly:
        // exp2(x) would compute 2^x instead.
        const double grad_bound = 2.0 / n_i * L;
        const double gradbound_squared = grad_bound * grad_bound;
        cout << "phased_ERM_multinoise -- batch: " << i << "/" << k << "(" << Di_indices.size() << " examples)" << endl;
        Float Float_recip_ni(1.0/n_i, PUBLIC);
        Float Float_rm(reg_mult, PUBLIC);
        Float Float_gbs(gradbound_squared, PUBLIC);

        phased_ERM_grad_check(D, Y, Di_indices, Di_indices.size(), w_tilde_i, w_i_minus_one, Float_rm, Float_recip_ni, Float_gbs, out_grad, num_params);

        cout << "phased_ERM_multinoise -- batch: " << i << "/" << k << "  noise sampling start" << endl;
        const size_t rand_bits_sz = rand_bits_szes[i];
        vector<Bit> r(rand_bits_sz, Bit(0, PUBLIC));
        // NOTE(review): variable-length arrays are a GCC/Clang extension, not standard C++.
        bool alice_local_rand[rand_bits_sz];
        bool bob_local_rand[rand_bits_sz];
        // one noise sample per parameter, each from a freshly generated shared seed
        for (size_t j = 0; j < num_params; ++j) {
            random_bits(alice_local_rand, rand_bits_sz);
            random_bits(bob_local_rand, rand_bits_sz);
            interactive_seed_generation(rand_io, party, alice_local_rand, bob_local_rand, rand_bits_sz, r);
            fp_noise_vec[j] = tree_sample(tree_val, r, tree_select_seq, tree_select_seq.size(), tree_m);
        }
        cout << "phased_ERM_multinoise -- batch: " << i << "/" << k << "  noise sampling end" << endl;

        prev_ni = prev_ni + n_i;
        for (size_t j = 0; j < num_params; ++j) {
            w_i_minus_one[j] = fp_to_float(float_to_fp(w_tilde_i[j]) + fp_noise_vec[j]);
        }
        auth_updated_weights(w_tilde_i, num_params); // placeholder for local training
    }
}