#ifndef ZKDT_DT_BATCH_GADGET_H
#define ZKDT_DT_BATCH_GADGET_H

#include "DT/DT.h"
#include "gadgets/common.h"
#include "gadgets/tool_gadgets.h"
#include "gadgets/swifft.h"

#include <boost/math/distributions/students_t.hpp>

/**
 * Calls print_info() on every node of `nodes`, in order.
 *
 * Declared `inline` because this definition lives in a header: without it,
 * every translation unit that includes the header emits its own external
 * definition and the program fails to link.
 *
 * @param nodes nodes to print; each pointer is dereferenced, so none may be null
 */
inline void print_node_vector(const std::vector<DTNode*>& nodes) {
    for (const auto* node : nodes) {
        node->print_info();
    }
}

/**
 * Critical value of Student's t distribution for a two-tailed Welch test.
 *
 * The degrees of freedom come from the Welch–Satterthwaite equation
 *   df = (v0/n0 + v1/n1)^2 / ( (v0/n0)^2 / (n0 - 1) + (v1/n1)^2 / (n1 - 1) ).
 *
 * Declared `inline` because the definition lives in a header.
 *
 * @param alpha significance level; alpha / 2 is placed in the upper tail
 * @param v0    variance of the first sample
 * @param v1    variance of the second sample
 * @param n0    size of the first sample
 * @param n1    size of the second sample
 * @return the upper alpha / 2 quantile of the t distribution with df degrees of freedom
 *
 * NOTE(review): a sample size of 0 or 1, or two zero variances, makes df
 * zero, infinite or NaN; presumably Boost.Math then raises a domain error
 * under its default policy — confirm and validate at the call site if needed.
 */
inline double get_t_critical(double alpha, double v0, double v1, int n0, int n1) {
    // Per-sample squared standard errors, reused by numerator and denominator.
    const double se0 = v0 / n0;
    const double se1 = v1 / n1;

    // Welch–Satterthwaite equation for degrees of freedom
    const double numerator = (se0 + se1) * (se0 + se1);
    const double denominator = (se0 * se0) / (n0 - 1)
                             + (se1 * se1) / (n1 - 1);
    const double df = numerator / denominator;

    // Two-tailed test: use alpha/2
    boost::math::students_t dist(df);
    return quantile(complement(dist, alpha / 2));
}

/**
 * Welch's two-sample t-test on summary statistics.
 *
 * Computes t = (mean0 - mean1) / sqrt(var0/n0 + var1/n1), compares |t| with
 * the two-tailed critical value from get_t_critical() and prints both values
 * and the verdict to std::cout.
 *
 * Declared `inline` because the definition lives in a header.
 *
 * @param mean0 mean of the first sample
 * @param mean1 mean of the second sample
 * @param var0  variance of the first sample
 * @param var1  variance of the second sample
 * @param n0    size of the first sample
 * @param n1    size of the second sample
 * @param alpha significance level of the two-tailed test
 * @return true when NO significant difference is found (|t| <= t_crit),
 *         false when the means differ significantly
 *
 * NOTE(review): when var0/n0 + var1/n1 is zero the statistic is infinite or
 * NaN and get_t_critical() receives a degenerate input — confirm callers
 * never pass two zero variances.
 */
inline bool compare_means(const double mean0, const double mean1, 
    const double var0, const double var1, const int n0, const int n1,
    double alpha) {

    // t-statistic for two-sample (Welch's t-test)
    const double t_stat = (mean0 - mean1) / std::sqrt(var0 / n0 + var1 / n1);

    // Get critical t-value
    const double t_crit = get_t_critical(alpha, var0, var1, n0, n1);

    std::cout << "t-statistic: " << t_stat << std::endl;
    std::cout << "t-critical: " << t_crit << std::endl;

    if (std::abs(t_stat) > t_crit) {
        std::cout << "Means are significantly different at alpha = " << alpha << std::endl;
        return false;
    }

    std::cout << "No significant difference at alpha = " << alpha << std::endl;
    return true;
}

/**
 * Gadget for one decision-tree prediction along a single root-to-leaf path.
 *
 * The constructor runs `dt.predict(values_)` to obtain the path, allocates
 * every protoboard variable and builds the sub-gadgets:
 *   - a PairwiseMultiSetGadget over (raw_index, raw_vars) and
 *     (permuted_index, permuted_vars),
 *   - one ComparisonGadget per internal node of the path,
 *   - one DecompositionCheckGadget over the 32-bit decompositions of `diff`,
 *   - two EqualityCheckGadgets comparing the predicted class label with the
 *     target label (`correct`) and with the positive label (`positive`).
 *
 * Ownership: every array and sub-gadget created here is released in the
 * destructor.
 * NOTE(review): the class holds raw owning pointers and declares no copy
 * operations; if the base class permits copying, a copied instance would
 * double-free — confirm instances are never copied.
 */
template<typename FieldT>
class PathPredictionGadget : public gadget<FieldT> {

public:
    // public input
    pb_variable <FieldT> *raw_vars;
    pb_variable <FieldT> *raw_index;
    pb_variable <FieldT> coef, challenge_point, target_class_label, positive_class_label; // added positive label info

    // extended witness: permutation
    pb_variable <FieldT> *permuted_vars;
    pb_variable <FieldT> *permuted_index;

    // secret: v1, v2, ... vh
    pb_variable <FieldT> *node_id;
    pb_variable <FieldT> *variable_id;
    pb_variable <FieldT> *threshold;
    pb_variable <FieldT> *l_node_id;
    pb_variable <FieldT> *r_node_id;
    pb_variable <FieldT> class_label; // prediction outcome

    // helper variables
    PairwiseMultiSetGadget <FieldT> *pairwiseMultiSetGadget;
    pb_variable <FieldT> *frequency_in_bits;

    ComparisonGadget <FieldT> *comparisonGadgets;
    pb_variable <FieldT> *comparison_results;
    pb_variable <FieldT> *diff;
    pb_variable <FieldT> *diff_decomposition;

    // Single object created with `new` (not an array), despite the plural name.
    DecompositionCheckGadget <FieldT> *decompositionCheckGadgets;
    EqualityCheckGadget <FieldT> *equalityCheckGadget;
    EqualityCheckGadget <FieldT> *equalityCheckGadgetP;

    pb_variable <FieldT> correct;
    // stores whether the outcome is positive or not
    pb_variable <FieldT> positive;

    // Allocates every protoboard variable of this gadget. The allocation order
    // fixes the variable indices on the protoboard, so it must not be reordered.
    void _init_pb_vars() {
        auto &prefix = this->annotation_prefix;

        // raw_vars, raw_index
        _init_pb_array(this->pb, raw_vars, n_vars, prefix + std::string("raw_vars"));
        _init_pb_array(this->pb, raw_index, n_vars, prefix + std::string("raw_index"));

        // permuted_vars, permuted_index: one entry per internal node of the path
        _init_pb_array(this->pb, permuted_vars, path_length - 1, prefix + std::string("permuted_vars"));
        _init_pb_array(this->pb, permuted_index, path_length - 1, prefix + std::string("permuted_index"));

        // node_id covers the whole path; the other attributes exist only for internal nodes
        _init_pb_array(this->pb, node_id, path_length, prefix + std::string("node_id"));
        _init_pb_array(this->pb, variable_id, path_length - 1, prefix + std::string("variable_id"));
        _init_pb_array(this->pb, threshold, path_length - 1, prefix + std::string("threshold"));
        _init_pb_array(this->pb, l_node_id, path_length - 1, prefix + std::string("l_node_id"));
        _init_pb_array(this->pb, r_node_id, path_length - 1, prefix + std::string("r_node_id"));
        class_label.allocate(this->pb, prefix + std::string("class_label"));
        target_class_label.allocate(this->pb, prefix + std::string("target_class_label"));
        positive_class_label.allocate(this->pb, prefix + std::string("positive_class_label"));

        _init_pb_array(this->pb, frequency_in_bits, n_vars * _n_frequency_bits,
                       prefix + std::string("frequency_in_bits"));

        // comparison_results, diff and the 32-bit decomposition of each diff
        _init_pb_array(this->pb, comparison_results, path_length - 1, prefix + std::string("comparison_results"));
        _init_pb_array(this->pb, diff, path_length - 1, prefix + std::string("diff"));
        _init_pb_array(this->pb, diff_decomposition, (path_length - 1) * 32,
                       prefix + std::string("diff_decomposition"));

        correct.allocate(this->pb, prefix + std::string("correct"));
        positive.allocate(this->pb, prefix + std::string("positive"));
    }

    // Writes the 32 bits of `value` into bits[0..31], most significant bit first.
    void _decompose(unsigned value, pb_variable <FieldT> *bits) {
        for (int i = 0; i < 32; ++i) {
            eval(bits[i]) = (value >> (31 - i)) & 1U;
        }
    }

    void _multiset_generate_r1cs_constraints() {
        pairwiseMultiSetGadget->generate_r1cs_constraints();
    }

    // Assigns the (index, value) pairs read along the path and the per-variable
    // usage counts (as big-endian bits), then lets the multiset gadget fill its witness.
    void _multiset_generate_r1cs_witness() {
        for (int i = 0; i < path_length - 1; ++i) {
            unsigned _index = ((DTInternalNode *) path[i])->variable_id;
            eval(permuted_index[i]) = _index;
            eval(permuted_vars[i]) = values[_index];
        }

        for (int i = 0; i < n_vars; ++i) {
            for (int j = 0; j < _n_frequency_bits; ++j) {
                eval(frequency_in_bits[i * _n_frequency_bits + j]) =
                        (_variable_count[i] >> (_n_frequency_bits - j - 1)) & 1U;
            }
        }

        pairwiseMultiSetGadget->generate_r1cs_witness();
    }

    void _bit_decomposition_generate_r1cs_constraints() {
        decompositionCheckGadgets[0].generate_r1cs_constraints();
    }

    void _bit_decomposition_generate_r1cs_witness() {
        decompositionCheckGadgets[0].generate_r1cs_witness();
    }

    // Per internal node: the comparison constraints, the link between the
    // permuted index and the node's variable_id, and the selection of the next
    // node id. add_r1cs(a, b, c) is taken to enforce a * b = c.
    void _dt_prediction_generate_r1cs_constraints() {
        for (int i = 0; i < path_length - 1; ++i) {
            comparisonGadgets[i].generate_r1cs_constraints();
            add_r1cs(permuted_index[i], 1, variable_id[i]);
            // result * (l - r) = next - r: next is l when result == 1 and r when result == 0
            add_r1cs(comparison_results[i], l_node_id[i] - r_node_id[i], node_id[i + 1] - r_node_id[i]);
        }
        equalityCheckGadget->generate_r1cs_constraints();
        equalityCheckGadgetP->generate_r1cs_constraints();
    }

    // Assigns comparison outcomes and |value - threshold| (with its bits) for
    // every internal node, then fills the comparison and equality gadgets.
    void _dt_prediction_generate_r1cs_witness() {
        for (int i = 0; i < path_length - 1; ++i) {
            eval(comparison_results[i]) = path[i + 1]->is_left();
            unsigned x = values[((DTInternalNode *) path[i])->variable_id], y = ((DTInternalNode *) path[i])->threshold;
            unsigned d = (x <= y) ? (y - x) : (x - y);
            _decompose(d, diff_decomposition + 32 * i);
            eval(diff[i]) = d;
        }

        for (int i = 0; i < path_length - 1; ++i) {
            comparisonGadgets[i].generate_r1cs_witness();
        }

        equalityCheckGadget->generate_r1cs_witness(target_label, predicted_label);
        equalityCheckGadgetP->generate_r1cs_witness(positive_label, predicted_label);
    }

    // Assigns the raw inputs, the attributes of every node on the path and the
    // predicted class label.
    void _general_generate_r1cs_witness() {
        for (int i = 0; i < n_vars; ++i) {
            eval(raw_vars[i]) = values[i];
            eval(raw_index[i]) = i;
        }

        for (int i = 0; i < path_length - 1; ++i) {
            DTInternalNode *node = (DTInternalNode *) path[i];
            eval(node_id[i]) = node->node_id;
            eval(variable_id[i]) = node->variable_id;
            eval(threshold[i]) = node->threshold;
            eval(l_node_id[i]) = node->l->node_id;
            eval(r_node_id[i]) = node->r->node_id;
        }

        eval(node_id[path_length - 1]) = ((DTLeaf *) path[path_length - 1])->node_id;
        eval(class_label) = predicted_label;
    }

    unsigned *_variable_count;   // how many times each input variable is read on the path
    unsigned _n_frequency_bits;  // bits used to encode one entry of _variable_count

    // Counts how often each variable is tested along the path and derives the
    // number of bits needed to represent the largest count (at least one bit).
    void _set_frequency_bits() {
        _variable_count = new unsigned[n_vars];
        memset(_variable_count, 0, sizeof(unsigned) * n_vars);
        for (DTNode *node : path) {
            if (!node->is_leaf) {
                _variable_count[((DTInternalNode *) node)->variable_id]++;
            }
        }
        // An empty range has no maximum; treat it as a count of 0.
        unsigned max_occurence =
                n_vars > 0 ? *std::max_element(_variable_count, _variable_count + n_vars) : 0;
        // Integer bit width: floor(log2(x)) + 1 for x >= 1, and 1 for x == 0.
        // The previous floating-point log() form was undefined for 0 (leaf-only
        // path) and exposed to rounding.
        unsigned bits = 1;
        for (unsigned rest = max_occurence >> 1; rest != 0; rest >>= 1) {
            ++bits;
        }
        _n_frequency_bits = bits;
    }

public:

    DT &dt;
    std::vector<unsigned int> &values;
    std::vector<DTNode *> path;
    int n_vars, path_length;
    unsigned target_label, predicted_label, positive_label;

    /**
     * @param pb               protoboard that receives the variables and constraints
     * @param dt_              decision tree used for the prediction (kept by reference)
     * @param values_          input sample (kept by reference; must outlive the gadget)
     * @param target_label_    expected label, compared with the prediction into `correct`
     * @param positive_label_  label counted as positive, compared into `positive`
     * @param coef_            variable handed to the multiset gadget as `coef`
     * @param challenge_point_ variable handed to the multiset gadget as `challenge_point`
     * @param annotation       annotation prefix for this gadget and its sub-gadgets
     */
    PathPredictionGadget(protoboard <FieldT> &pb, DT &dt_, std::vector<unsigned int> &values_,
                         unsigned int target_label_, unsigned int positive_label_,
                         pb_variable <FieldT> &coef_, pb_variable <FieldT> &challenge_point_,
                         const std::string &annotation = "")
            : gadget<FieldT>(pb, annotation), dt(dt_), values(values_), target_label(target_label_), positive_label(positive_label_) {
        path = dt.predict(values_);
        predicted_label = ((DTLeaf*) path.back())->class_id;
        n_vars = values.size();
        path_length = path.size();

        // The frequency bit width sizes an array allocated in _init_pb_vars().
        _set_frequency_bits();
        _init_pb_vars();

        // assign some public inputs and randomness here.
        coef = coef_;
        challenge_point = challenge_point_;
        pb.val(target_class_label) = target_label;
        pb.val(positive_class_label) = positive_label;
        pb.val(correct) = unsigned(predicted_label == target_label);
        pb.val(positive) = unsigned(predicted_label == positive_label);

        pairwiseMultiSetGadget = new PairwiseMultiSetGadget<FieldT>(pb, raw_index, raw_vars, permuted_index,
                                                                    permuted_vars,
                                                                    frequency_in_bits, coef, challenge_point, n_vars,
                                                                    path_length - 1, _n_frequency_bits,
                                                                    "pairwise_multiset_gadget");

        // Raw storage + placement new: ComparisonGadget is constructed element by
        // element with per-element arguments.
        comparisonGadgets = (ComparisonGadget <FieldT> *) malloc(
                sizeof(ComparisonGadget < FieldT > ) * (path_length - 1));
        for (int i = 0; i < path_length - 1; ++i) {
            auto gadget_name = annotation + std::string("comparison_gadget_") + std::to_string(i);
            new(comparisonGadgets + i) ComparisonGadget<FieldT>(pb, permuted_vars[i],
                                                                threshold[i],
                                                                comparison_results[i], diff[i], gadget_name);
        }

        decompositionCheckGadgets = new DecompositionCheckGadget<FieldT>(pb, diff, diff_decomposition,
                                                                         path_length - 1, 32,
                                                                         annotation + "decomposition_check_gadget");
        equalityCheckGadget = new EqualityCheckGadget<FieldT>(pb, target_class_label, class_label,
                                                              correct, 32, annotation + "equality_check_gadget");
        // Distinct annotation so the two equality gadgets can be told apart.
        equalityCheckGadgetP = new EqualityCheckGadget<FieldT>(pb, positive_class_label, class_label,
                                                              positive, 32, annotation + "equality_check_gadget_p");
    }

    // Releases everything the constructor allocated, pairing each allocation
    // with its matching deallocation form.
    ~PathPredictionGadget() {

        delete[] raw_vars;
        delete[] raw_index;

        // extended witness: permutation
        delete[] permuted_vars;
        delete[] permuted_index;

        // secret: v1, v2, ... vh
        delete[] node_id;
        delete[] variable_id;
        delete[] threshold;
        delete[] l_node_id;
        delete[] r_node_id;

        // helper variables
        delete pairwiseMultiSetGadget;
        delete[] _variable_count;
        delete[] frequency_in_bits;

        // malloc + placement new  ->  explicit destructor calls + free
        for (int i = 0; i < path_length - 1; ++i) {
            (comparisonGadgets + i)->~ComparisonGadget();
        }
        free(comparisonGadgets);

        delete[] comparison_results;
        // Allocated through _init_pb_array like the arrays above, but previously leaked.
        delete[] diff;
        delete[] diff_decomposition;

        // A single object created with `new`: plain delete. The previous code
        // ran 8 destructor calls and free() on it, which is undefined behavior.
        delete decompositionCheckGadgets;
        delete equalityCheckGadget;
        delete equalityCheckGadgetP;
    }

    // Emits the constraints of all sub-gadgets.
    void generate_r1cs_constraints() {
        this->_multiset_generate_r1cs_constraints();
        this->_bit_decomposition_generate_r1cs_constraints();
        this->_dt_prediction_generate_r1cs_constraints();
    }

    // Fills the witness; the call order below is kept from the original code.
    void generate_r1cs_witness() {
        this->_general_generate_r1cs_witness();
        this->_multiset_generate_r1cs_witness();
        this->_bit_decomposition_generate_r1cs_witness();
        this->_dt_prediction_generate_r1cs_witness();
    }
};


template<typename FieldT>
class DTBatchGadget : public gadget<FieldT> {
    const static unsigned N_BITS_NODE_ATTR = 32;
private:
    std::map<unsigned, unsigned> leaf_id_index;
    std::map<unsigned, unsigned> non_leaf_id_index;

    pb_variable <FieldT> *node_id;
    pb_variable <FieldT> *variable_id;
    pb_variable <FieldT> *threshold;
    pb_variable <FieldT> *l_node_id;
    pb_variable <FieldT> *r_node_id;
    pb_variable <FieldT> *class_label;

    pb_variable <FieldT> *node_id_decomposition;
    pb_variable <FieldT> *variable_id_decomposition;
    pb_variable <FieldT> *threshold_decomposition;
    pb_variable <FieldT> *l_node_id_decomposition;
    pb_variable <FieldT> *r_node_id_decomposition;
    pb_variable <FieldT> *class_label_decomposition;

    pb_variable <FieldT> *hash_inputs_1;
    pb_variable <FieldT> *hash_inputs_2;
    pb_variable <FieldT> *hash_outputs_1;
    pb_variable <FieldT> *hash_outputs_2;

    pb_variable <FieldT> *commitment;

    pb_variable <FieldT> zero_var;

    pb_variable <FieldT> coef, challenge_point;

    pb_variable <FieldT> *coef_array;

    pb_variable <FieldT> **tree_nodes_values;
    pb_variable <FieldT> *tree_nodes_terms;
    pb_variable <FieldT> **path_nodes_values;
    pb_variable <FieldT> *path_nodes_terms;

    pb_variable <FieldT> *target_labels;
    pb_variable <FieldT> n_correct_var;

    // variables for fairness metrics
    pb_variable <FieldT> thr_scaled_var;
    pb_variable <FieldT> n_pos_zero_var;
    pb_variable <FieldT> n_pos_one_var;
    pb_variable <FieldT> n_pos_zero_scaled_var;
    pb_variable <FieldT> n_pos_one_scaled_var;

    pb_variable <FieldT> thr_scaled_tpos_var;
    pb_variable <FieldT> n_pos_zero_tpos_var;
    pb_variable <FieldT> n_pos_one_tpos_var;
    pb_variable <FieldT> n_pos_zero_tpos_scaled_var;
    pb_variable <FieldT> n_pos_one_tpos_scaled_var;
    pb_variable <FieldT> thr_scaled_tneg_var;
    pb_variable <FieldT> n_pos_zero_tneg_var;
    pb_variable <FieldT> n_pos_one_tneg_var;
    pb_variable <FieldT> n_pos_zero_tneg_scaled_var;
    pb_variable <FieldT> n_pos_one_tneg_scaled_var;
    pb_variable <FieldT> sum_res_zero_var;
    pb_variable <FieldT> sum_res_one_var;
    pb_variable <FieldT> var_res_zero_var;
    pb_variable <FieldT> var_res_one_var;
    pb_variable <FieldT> *temp_var;


    // variables for absolute value check
    pb_variable <FieldT> is_lesseq;
    pb_variable <FieldT> abs_var;
    pb_variable <FieldT> is_lesseq_tpos;
    pb_variable <FieldT> abs_tpos_var;
    pb_variable <FieldT> is_lesseq_tneg;
    pb_variable <FieldT> abs_tneg_var;
    // variables for fairness check
    pb_variable <FieldT> is_fair;
    pb_variable <FieldT> diff;
    pb_variable <FieldT> is_fair_tpos;
    pb_variable <FieldT> diff_tpos;
    pb_variable <FieldT> is_fair_tneg;
    pb_variable <FieldT> diff_tneg;

    // bit decompositions
    pb_variable <FieldT> *thr_scaled_var_decomposition;
    pb_variable <FieldT> *thr_scaled_tpos_var_decomposition;
    pb_variable <FieldT> *thr_scaled_tneg_var_decomposition;
    pb_variable <FieldT> *n_pos_zero_scaled_var_decomposition;
    pb_variable <FieldT> *n_pos_one_scaled_var_decomposition;
    pb_variable <FieldT> *n_pos_zero_tpos_scaled_var_decomposition;
    pb_variable <FieldT> *n_pos_one_tpos_scaled_var_decomposition;
    pb_variable <FieldT> *n_pos_zero_tneg_scaled_var_decomposition;
    pb_variable <FieldT> *n_pos_one_tneg_scaled_var_decomposition;
    pb_variable <FieldT> *abs_decomposition;
    pb_variable <FieldT> *diff_decomposition;
    pb_variable <FieldT> *abs_tpos_decomposition;
    pb_variable <FieldT> *diff_tpos_decomposition;
    pb_variable <FieldT> *abs_tneg_decomposition;
    pb_variable <FieldT> *diff_tneg_decomposition;

    unsigned n_frequency_bits;
    std::vector<unsigned> nodes_count;
    pb_variable <FieldT> *frequency_in_bits;

private:
    DecompositionCheckGadget <FieldT> *decompositionCheckGadgets;
    SwifftGadget <FieldT> *swifftGadget;
    PathPredictionGadget<FieldT> *pathPredictionGadget;
    LinearCombinationGadget <FieldT> *treeLinearCombinationGadget;
    LinearCombinationGadget <FieldT> *pathLinearCombinationGadget;
    MultiSetGadget <FieldT> *multisetGadget;
    // fairness
    ComparisonGadget <FieldT> *comparisonGadgetA;
    ComparisonGadget <FieldT> *comparisonGadgetF;
    ComparisonGadget <FieldT> *comparisonGadgetAtneg;
    ComparisonGadget <FieldT> *comparisonGadgetFtneg;
    DecompositionCheckGadget <FieldT> *decompositionCheckGadgetsF;


private:

    void _init_id_map() {
        std::vector < DTNode * > nodes = dt.get_all_nodes();
        unsigned non_leaf_count = 0, leaf_count = 0;
        for (DTNode *node : nodes) {
            if (node->is_leaf) {
                leaf_id_index[node->node_id] = leaf_count++;
            } else {
                non_leaf_id_index[node->node_id] = non_leaf_count++;
            }
        }
    }

    // Counts how many times each tree node appears over all prediction paths
    // and derives the bit width used to encode those counts.
    //
    // nodes_count is indexed with internal nodes first (via non_leaf_id_index)
    // and leaves after them, offset by dt.root->non_leaf_size.
    void _count() {
        nodes_count.resize(dt.n_nodes, 0);
        for (auto &path_nodes: all_paths) {
            for (auto &node: path_nodes) {
                unsigned index;
                if (node->is_leaf) {
                    index = leaf_id_index[node->node_id] + dt.root->non_leaf_size;
                } else {
                    index = non_leaf_id_index[node->node_id];
                }
                nodes_count[index]++;
            }
        }
        // No node can occur more often than there are paths, so the bit width of
        // all_paths.size() is enough. Computed with integer shifts:
        // floor(log2(n)) + 1 for n >= 1, and 1 for n == 0. The previous
        // floating-point log() form was undefined for an empty batch and
        // exposed to rounding.
        n_frequency_bits = 1;
        for (auto rest = all_paths.size() >> 1; rest != 0; rest >>= 1) {
            ++n_frequency_bits;
        }
    }

    // Allocates every protoboard variable owned by the batch gadget.
    //
    // The allocation order determines the variable indices on the protoboard
    // and is unchanged from the original. Only the annotation strings of the
    // tpos/tneg variables differ: they used to repeat the annotations of the
    // corresponding base variables ("thr_scaled_var", "is_lesseq", "abs_var",
    // "is_fair", "diff"), which made them indistinguishable in debug output.
    void _init_pb_vars() {

        std::string prefix = this->annotation_prefix;
        zero_var.allocate(this->pb, prefix + "zero_var");
        coef.allocate(this->pb, prefix + "coef");
        challenge_point.allocate(this->pb, prefix + "challenge_point");

        // Per-node attributes: node_id for all nodes, split attributes for
        // internal nodes, class_label for the remaining (leaf) nodes.
        _init_pb_array(this->pb, node_id, dt.n_nodes, prefix + "node_id");
        _init_pb_array(this->pb, variable_id, dt.root->non_leaf_size, prefix + "variable_id");
        _init_pb_array(this->pb, threshold, dt.root->non_leaf_size, prefix + "threshold");
        _init_pb_array(this->pb, l_node_id, dt.root->non_leaf_size, prefix + "l_node_id");
        _init_pb_array(this->pb, r_node_id, dt.root->non_leaf_size, prefix + "r_node_id");
        _init_pb_array(this->pb, class_label, dt.n_nodes - dt.root->non_leaf_size, prefix + "class_label");

        // no need to allocate inputs, because they are placeholders, and are assigned other variables
        hash_inputs_1 = new pb_variable<FieldT>[dt.root->non_leaf_size * _hash_input_size];
        hash_inputs_2 = new pb_variable<FieldT>[dt.root->non_leaf_size * _hash_input_size];
        _init_pb_array(this->pb, hash_outputs_1, dt.root->non_leaf_size * _hash_output_size, prefix + "hash_outputs_1");
        _init_pb_array(this->pb, hash_outputs_2, dt.root->non_leaf_size * _hash_output_size, prefix + "hash_outputs_2");
        _init_pb_array(this->pb, commitment, _hash_output_size, prefix + "commitment");

        // Bit decompositions of the node attributes, N_BITS_NODE_ATTR bits each.
        _init_pb_array(this->pb, node_id_decomposition, dt.n_nodes * N_BITS_NODE_ATTR,
                       prefix + "node_id_decomposition");
        _init_pb_array(this->pb, variable_id_decomposition, dt.root->non_leaf_size * N_BITS_NODE_ATTR,
                       prefix + "variable_id_decomposition");
        _init_pb_array(this->pb, threshold_decomposition, dt.root->non_leaf_size * N_BITS_NODE_ATTR,
                       prefix + "threshold_decomposition");
        _init_pb_array(this->pb, l_node_id_decomposition, dt.root->non_leaf_size * N_BITS_NODE_ATTR,
                       prefix + "l_node_id_decomposition");
        _init_pb_array(this->pb, r_node_id_decomposition, dt.root->non_leaf_size * N_BITS_NODE_ATTR,
                       prefix + "r_node_id_decomposition");
        _init_pb_array(this->pb, class_label_decomposition, (dt.n_nodes - dt.root->non_leaf_size) * N_BITS_NODE_ATTR,
                       prefix + "class_label_decomposition");

        _init_pb_array(this->pb, coef_array, dt.n_nodes, prefix + "coef_array");

        // The *_values tables only hold pointers here; their rows are created elsewhere.
        tree_nodes_values = new pb_variable <FieldT> *[dt.n_nodes];
        path_nodes_values = new pb_variable <FieldT> *[n_path_nodes];
        _init_pb_array(this->pb, tree_nodes_terms, dt.n_nodes, prefix + "tree_nodes_terms");
        _init_pb_array(this->pb, path_nodes_terms, n_path_nodes, prefix + "path_nodes_terms");

        _init_pb_array(this->pb, frequency_in_bits, n_frequency_bits * dt.n_nodes, prefix + "frequency_in_bits");

        _init_pb_array(this->pb, target_labels, data.size(), prefix + "target_labels");
        n_correct_var.allocate(this->pb, prefix + "n_correct_var");

        // fairness
        thr_scaled_var.allocate(this->pb, prefix + "thr_scaled_var");
        n_pos_zero_var.allocate(this->pb, prefix + "n_pos_zero_var");
        n_pos_one_var.allocate(this->pb, prefix + "n_pos_one_var");
        n_pos_zero_scaled_var.allocate(this->pb, prefix + "n_pos_zero_scaled_var");
        n_pos_one_scaled_var.allocate(this->pb, prefix + "n_pos_one_scaled_var");

        thr_scaled_tpos_var.allocate(this->pb, prefix + "thr_scaled_tpos_var");
        n_pos_zero_tpos_var.allocate(this->pb, prefix + "n_pos_zero_tpos_var");
        n_pos_one_tpos_var.allocate(this->pb, prefix + "n_pos_one_tpos_var");
        n_pos_zero_tpos_scaled_var.allocate(this->pb, prefix + "n_pos_zero_tpos_scaled_var");
        n_pos_one_tpos_scaled_var.allocate(this->pb, prefix + "n_pos_one_tpos_scaled_var");

        thr_scaled_tneg_var.allocate(this->pb, prefix + "thr_scaled_tneg_var");
        n_pos_zero_tneg_var.allocate(this->pb, prefix + "n_pos_zero_tneg_var");
        n_pos_one_tneg_var.allocate(this->pb, prefix + "n_pos_one_tneg_var");
        n_pos_zero_tneg_scaled_var.allocate(this->pb, prefix + "n_pos_zero_tneg_scaled_var");
        n_pos_one_tneg_scaled_var.allocate(this->pb, prefix + "n_pos_one_tneg_scaled_var");
        // MRD
        sum_res_zero_var.allocate(this->pb, prefix + "sum_res_zero_var");
        sum_res_one_var.allocate(this->pb, prefix + "sum_res_one_var");
        var_res_zero_var.allocate(this->pb, prefix + "var_res_zero_var");
        var_res_one_var.allocate(this->pb, prefix + "var_res_one_var");
        _init_pb_array(this->pb, temp_var, data.size(), prefix + "temp_var");

        // absolute-value and fairness check variables
        is_lesseq.allocate(this->pb, prefix + "is_lesseq");
        abs_var.allocate(this->pb, prefix + "abs_var");

        is_fair.allocate(this->pb, prefix + "is_fair");
        diff.allocate(this->pb, prefix + "diff");

        is_lesseq_tpos.allocate(this->pb, prefix + "is_lesseq_tpos");
        abs_tpos_var.allocate(this->pb, prefix + "abs_tpos_var");

        is_fair_tpos.allocate(this->pb, prefix + "is_fair_tpos");
        diff_tpos.allocate(this->pb, prefix + "diff_tpos");

        is_lesseq_tneg.allocate(this->pb, prefix + "is_lesseq_tneg");
        abs_tneg_var.allocate(this->pb, prefix + "abs_tneg_var");

        is_fair_tneg.allocate(this->pb, prefix + "is_fair_tneg");
        diff_tneg.allocate(this->pb, prefix + "diff_tneg");

        // decompositions (32 bits each)
        _init_pb_array(this->pb, thr_scaled_var_decomposition, 32, prefix + "thr_scaled_var_decomposition");
        _init_pb_array(this->pb, n_pos_zero_scaled_var_decomposition, 32, prefix + "n_pos_zero_scaled_var_decomposition");
        _init_pb_array(this->pb, n_pos_one_scaled_var_decomposition, 32, prefix + "n_pos_one_scaled_var_decomposition");
        _init_pb_array(this->pb, abs_decomposition, 32, prefix + "abs_decomposition");
        _init_pb_array(this->pb, diff_decomposition, 32, prefix + "diff_decomposition");

        _init_pb_array(this->pb, thr_scaled_tpos_var_decomposition, 32, prefix + "thr_scaled_tpos_var_decomposition");
        _init_pb_array(this->pb, n_pos_zero_tpos_scaled_var_decomposition, 32, prefix + "n_pos_zero_tpos_scaled_var_decomposition");
        _init_pb_array(this->pb, n_pos_one_tpos_scaled_var_decomposition, 32, prefix + "n_pos_one_tpos_scaled_var_decomposition");
        _init_pb_array(this->pb, abs_tpos_decomposition, 32, prefix + "abs_tpos_decomposition");
        _init_pb_array(this->pb, diff_tpos_decomposition, 32, prefix + "diff_tpos_decomposition");

        _init_pb_array(this->pb, thr_scaled_tneg_var_decomposition, 32, prefix + "thr_scaled_tneg_var_decomposition");
        _init_pb_array(this->pb, n_pos_zero_tneg_scaled_var_decomposition, 32, prefix + "n_pos_zero_tneg_scaled_var_decomposition");
        _init_pb_array(this->pb, n_pos_one_tneg_scaled_var_decomposition, 32, prefix + "n_pos_one_tneg_scaled_var_decomposition");
        _init_pb_array(this->pb, abs_tneg_decomposition, 32, prefix + "abs_tneg_decomposition");
        _init_pb_array(this->pb, diff_tneg_decomposition, 32, prefix + "diff_tneg_decomposition");

        // NOTE(review): presumably marks the first allocated variable as the only
        // primary input — that is zero_var if nothing was allocated on this
        // protoboard earlier; confirm against the prover/verifier setup.
        this->pb.set_input_sizes(1);
    }

    void _init_sub_gadgets() {

        std::string prefix = this->annotation_prefix;

        auto &pb = this->pb;
        decompositionCheckGadgets = (DecompositionCheckGadget <FieldT> *) malloc(
                sizeof(DecompositionCheckGadget < FieldT > ) * 6);

        for (int i = 0; i < 6; ++i) {
            new(decompositionCheckGadgets + 0) DecompositionCheckGadget<FieldT>(pb, node_id, node_id_decomposition,
                                                                                dt.n_nodes, N_BITS_NODE_ATTR,
                                                                                prefix +
                                                                                "decomposition_check_gadget_0");
            new(decompositionCheckGadgets + 1) DecompositionCheckGadget<FieldT>(pb, variable_id,
                                                                                variable_id_decomposition,
                                                                                dt.root->non_leaf_size,
                                                                                N_BITS_NODE_ATTR,
                                                                                prefix +
                                                                                "decomposition_check_gadget_1");
            new(decompositionCheckGadgets + 2) DecompositionCheckGadget<FieldT>(pb, threshold, threshold_decomposition,
                                                                                dt.root->non_leaf_size,
                                                                                N_BITS_NODE_ATTR,
                                                                                prefix +
                                                                                "decomposition_check_gadget_2");
            new(decompositionCheckGadgets + 3) DecompositionCheckGadget<FieldT>(pb, l_node_id, l_node_id_decomposition,
                                                                                dt.root->non_leaf_size,
                                                                                N_BITS_NODE_ATTR,
                                                                                prefix +
                                                                                "decomposition_check_gadget_3");
            new(decompositionCheckGadgets + 4) DecompositionCheckGadget<FieldT>(pb, r_node_id, r_node_id_decomposition,
                                                                                dt.root->non_leaf_size,
                                                                                N_BITS_NODE_ATTR,
                                                                                prefix +
                                                                                "decomposition_check_gadget_4");
            new(decompositionCheckGadgets + 5) DecompositionCheckGadget<FieldT>(pb, class_label,
                                                                                class_label_decomposition,
                                                                                dt.n_nodes - dt.root->non_leaf_size,
                                                                                N_BITS_NODE_ATTR,
                                                                                prefix +
                                                                                "decomposition_check_gadget_5");
        }

        swifftGadget = (SwifftGadget <FieldT> *) malloc(sizeof(SwifftGadget < FieldT > ) * dt.root->non_leaf_size * 2);
        for (int i = 0; i < dt.root->non_leaf_size; ++i) {
            auto gadget_name = prefix + std::string("hash_gadget_") + std::to_string(i);
            new(swifftGadget + i * 2) SwifftGadget<FieldT>(pb, hash_inputs_1 + i * _hash_input_size,
                                                           hash_outputs_1 + i * _hash_output_size,
                                                           gadget_name + "first");
            new(swifftGadget + i * 2 + 1) SwifftGadget<FieldT>(pb, hash_inputs_2 + i * _hash_input_size,
                                                               hash_outputs_2 + i * _hash_output_size,
                                                               gadget_name + "second");
        }

        pathPredictionGadget = (PathPredictionGadget<FieldT> *) malloc(
                sizeof(PathPredictionGadget<FieldT>) * data.size());

        for (int i = 0; i < data.size(); ++i) {
            std::vector<unsigned> &single_data = data[i];
            new(pathPredictionGadget + i) PathPredictionGadget<FieldT>(this->pb, dt, single_data,
                                                                       labels[i], pos_label,
                                                                       coef, challenge_point, this->annotation_prefix +
                                                                                              "path_prediction_gadget" +
                                                                                              std::to_string(i));
        }

        unsigned index, length;
        treeLinearCombinationGadget =
                (LinearCombinationGadget <FieldT> *) malloc(sizeof(LinearCombinationGadget < FieldT > ) * dt.n_nodes);
        for (index = 0; index < dt.n_nodes; ++index) {
            if (index < dt.root->non_leaf_size) {
                tree_nodes_values[index] = new pb_variable<FieldT>[5];

                tree_nodes_values[index][0] = node_id[index];
                tree_nodes_values[index][1] = variable_id[index];
                tree_nodes_values[index][2] = threshold[index];
                tree_nodes_values[index][3] = l_node_id[index];
                tree_nodes_values[index][4] = r_node_id[index];

                length = 5;
            } else {
                tree_nodes_values[index] = new pb_variable<FieldT>[2];
                tree_nodes_values[index][0] = class_label[index - dt.root->non_leaf_size];
                tree_nodes_values[index][1] = node_id[index];
                length = 2;
            }

            new(treeLinearCombinationGadget + index) LinearCombinationGadget<FieldT>(this->pb, tree_nodes_values[index],
                                                                                     coef_array,
                                                                                     tree_nodes_terms[index], length,
                                                                                     this->annotation_prefix +
                                                                                     "tree_linear_combination_gadgets" +
                                                                                     std::to_string(index));
        }

        pathLinearCombinationGadget =
                (LinearCombinationGadget <FieldT> *) malloc(sizeof(LinearCombinationGadget < FieldT > ) * n_path_nodes);
        index = 0;
        for (int i = 0; i < all_paths.size(); ++i) {
            for (int j = 0; j < all_paths[i].size(); ++j) {
                if (j == all_paths[i].size() - 1) {
                    path_nodes_values[index] = new pb_variable<FieldT>[2];
                    path_nodes_values[index][0] = pathPredictionGadget[i].class_label;
                    path_nodes_values[index][1] = pathPredictionGadget[i].node_id[j];
                    length = 2;
                } else {
                    path_nodes_values[index] = new pb_variable<FieldT>[5];

                    path_nodes_values[index][0] = pathPredictionGadget[i].node_id[j];
                    path_nodes_values[index][1] = pathPredictionGadget[i].variable_id[j];
                    path_nodes_values[index][2] = pathPredictionGadget[i].threshold[j];
                    path_nodes_values[index][3] = pathPredictionGadget[i].l_node_id[j];
                    path_nodes_values[index][4] = pathPredictionGadget[i].r_node_id[j];
                    length = 5;
                }
                new(pathLinearCombinationGadget + index) LinearCombinationGadget<FieldT>(this->pb,
                                                                                         path_nodes_values[index],
                                                                                         coef_array,
                                                                                         path_nodes_terms[index],
                                                                                         length,
                                                                                         this->annotation_prefix +
                                                                                         "path_linear_combination_gadgets" +
                                                                                         std::to_string(index));
                index++;
            }
        }

        multisetGadget = new MultiSetGadget<FieldT>(this->pb, tree_nodes_terms, path_nodes_terms,
                                                    frequency_in_bits, dt.n_nodes, n_path_nodes, n_frequency_bits,
                                                    challenge_point, prefix + "multisetGadget");

        // initialize comparison and decomposition gadgets
        switch (fairness_mode) {
            case 1:
                comparisonGadgetA = new ComparisonGadget<FieldT>(this->pb, n_pos_zero_scaled_var, n_pos_one_scaled_var,
                                                            is_lesseq, abs_var, prefix + "comparisonGadget");
                comparisonGadgetF = new ComparisonGadget<FieldT>(this->pb, abs_var, thr_scaled_var,
                                                            is_fair, diff, prefix + "comparisonGadget");
                decompositionCheckGadgetsF = (DecompositionCheckGadget <FieldT> *) malloc(
                        sizeof(DecompositionCheckGadget < FieldT > ) * 5);
                new(decompositionCheckGadgetsF + 0) DecompositionCheckGadget<FieldT>(pb, &n_pos_zero_scaled_var, n_pos_zero_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 1) DecompositionCheckGadget<FieldT>(pb, &n_pos_one_scaled_var, n_pos_one_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 2) DecompositionCheckGadget<FieldT>(pb, &abs_var, abs_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 3) DecompositionCheckGadget<FieldT>(pb, &thr_scaled_var, thr_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 4) DecompositionCheckGadget<FieldT>(pb, &diff, diff_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                break;
            case 2:
                comparisonGadgetA = new ComparisonGadget<FieldT>(this->pb, n_pos_zero_tpos_scaled_var, n_pos_one_tpos_scaled_var,
                                                            is_lesseq_tpos, abs_tpos_var, prefix + "comparisonGadget");
                comparisonGadgetF = new ComparisonGadget<FieldT>(this->pb, abs_tpos_var, thr_scaled_tpos_var,
                                                            is_fair_tpos, diff_tpos, prefix + "comparisonGadget");
                comparisonGadgetAtneg = new ComparisonGadget<FieldT>(this->pb, n_pos_zero_tneg_scaled_var, n_pos_one_tneg_scaled_var,
                                                            is_lesseq_tneg, abs_tneg_var, prefix + "comparisonGadget");
                comparisonGadgetFtneg = new ComparisonGadget<FieldT>(this->pb, abs_tneg_var, thr_scaled_tneg_var,
                                                            is_fair_tneg, diff_tneg, prefix + "comparisonGadget");    
                decompositionCheckGadgetsF = (DecompositionCheckGadget <FieldT> *) malloc(
                        sizeof(DecompositionCheckGadget < FieldT > ) * 10);
                new(decompositionCheckGadgetsF + 0) DecompositionCheckGadget<FieldT>(pb, &n_pos_zero_tpos_scaled_var, n_pos_zero_tpos_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 1) DecompositionCheckGadget<FieldT>(pb, &n_pos_one_tpos_scaled_var, n_pos_one_tpos_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 2) DecompositionCheckGadget<FieldT>(pb, &abs_tpos_var, abs_tpos_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 3) DecompositionCheckGadget<FieldT>(pb, &thr_scaled_tpos_var, thr_scaled_tpos_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 4) DecompositionCheckGadget<FieldT>(pb, &diff_tpos, diff_tpos_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 5) DecompositionCheckGadget<FieldT>(pb, &n_pos_zero_tneg_scaled_var, n_pos_zero_tneg_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 6) DecompositionCheckGadget<FieldT>(pb, &n_pos_one_tneg_scaled_var, n_pos_one_tneg_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 7) DecompositionCheckGadget<FieldT>(pb, &abs_tneg_var, abs_tneg_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 8) DecompositionCheckGadget<FieldT>(pb, &thr_scaled_tneg_var, thr_scaled_tneg_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 9) DecompositionCheckGadget<FieldT>(pb, &diff_tneg, diff_tneg_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                
                break;
            case 3:
                /*
                comparisonGadgetF = new ComparisonGadget<FieldT>(this->pb, thr_scaled_var, n_pos_one_scaled_var,
                                                            is_fair, diff, prefix + "comparisonGadget");
                decompositionCheckGadgetsF = (DecompositionCheckGadget <FieldT> *) malloc(
                        sizeof(DecompositionCheckGadget < FieldT > ) * 3);
                new(decompositionCheckGadgetsF + 0) DecompositionCheckGadget<FieldT>(pb, &n_pos_one_scaled_var, n_pos_one_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 1) DecompositionCheckGadget<FieldT>(pb, &thr_scaled_var, thr_scaled_var_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                new(decompositionCheckGadgetsF + 2) DecompositionCheckGadget<FieldT>(pb, &diff, diff_decomposition,
                                                                                1, 32,
                                                                                prefix + "decomposition_check_gadget");
                */
                break;
         }
    }


private:

    void _decompose(unsigned value, pb_variable <FieldT> *bits) {
        for (int i = 0; i < N_BITS_NODE_ATTR; ++i) {
            eval(bits[i]) = (value >> (N_BITS_NODE_ATTR - 1 - i)) & 1U;
        }
    }

    // Adds the batch-level constraints: first the accuracy count, then the
    // statistics belonging to the selected fairness_mode (1, 2 or 3). For any
    // other fairness_mode only the accuracy constraint is added.
    // Comments below assume add_r1cs(a, b, c) enforces a * b = c -- TODO confirm
    // against its definition.
    void _general_constraints() {
        std::string prefix = this->annotation_prefix; // NOTE(review): not used in this function
        auto sum = linear_combination<FieldT>();
        auto sum_pos_zero = linear_combination<FieldT>();
        auto sum_pos_one = linear_combination<FieldT>();
        auto sum_pos_zero_tneg = linear_combination<FieldT>();
        auto sum_pos_one_tneg = linear_combination<FieldT>();
        auto sum_res_zero = linear_combination<FieldT>(); // sum of residuals conditioned on sensitive attribute = 0
        auto sum_res_one = linear_combination<FieldT>(); // sum of residuals conditioned on sensitive attribute = 1
        auto var_res_zero = linear_combination<FieldT>(); // sum of temp_var[i] over samples with sensitive attribute = 0
        auto var_res_one = linear_combination<FieldT>(); // sum of temp_var[i] over samples with sensitive attribute = 1

        // n_correct_var must equal the sum of the per-sample `correct` variables.
        for (int i = 0; i < data.size(); ++i) {
            sum = sum + pathPredictionGadget[i].correct;
        }
        // Printed before the first constraint is added, despite the wording.
        std::cout << "computed _general_constraints()" << std::endl;
        add_r1cs(sum, 1, n_correct_var);
        
        switch (fairness_mode) {
            case 1:
                // Samples are split by data[i][sensitive] <= thr_sensitive (group zero)
                // versus > thr_sensitive (group one). Each group's `positive` sum is
                // weighted by the OTHER group's size times scaling_thr -- presumably so
                // the two positive rates can be compared without division; confirm.
                for (int i = 0; i < data.size(); ++i) {
                    if (data[i][sensitive] <= thr_sensitive ) {
                        sum_pos_zero = sum_pos_zero + pathPredictionGadget[i].positive * n_sens_one * scaling_thr;
                    } else {
                        sum_pos_one = sum_pos_one + pathPredictionGadget[i].positive * n_sens_zero * scaling_thr;
                    }
                }
                // check correctness of intermediate statistics
                add_r1cs(sum_pos_zero, 1, n_pos_zero_scaled_var); 
                add_r1cs(sum_pos_one, 1, n_pos_one_scaled_var);
                // conditional assignment:
                //   is_lesseq == 0  =>  abs_var = zero - one
                //   is_lesseq == 1  =>  abs_var = one - zero
                add_r1cs(is_lesseq, 2*n_pos_one_scaled_var - 2*n_pos_zero_scaled_var, 
                                    abs_var - (n_pos_zero_scaled_var - n_pos_one_scaled_var));
                comparisonGadgetA->generate_r1cs_constraints();
                comparisonGadgetF->generate_r1cs_constraints();
                break;
            case 2:
                // Same split as mode 1, done separately for samples whose label equals
                // pos_label (tpos) and for all other samples (tneg).
                for (int i = 0; i < data.size(); ++i) {
                    if (pos_label == labels[i]) {
                        if (data[i][sensitive] <= thr_sensitive) {
                            sum_pos_zero = sum_pos_zero + pathPredictionGadget[i].positive * n_sens_one_tpos * scaling_thr;
                        } else {
                            sum_pos_one = sum_pos_one + pathPredictionGadget[i].positive * n_sens_zero_tpos * scaling_thr;
                        }
                    } else {
                        if (data[i][sensitive] <= thr_sensitive) {
                            sum_pos_zero_tneg = sum_pos_zero_tneg + pathPredictionGadget[i].positive * n_sens_one_tneg * scaling_thr;
                        } else {
                            sum_pos_one_tneg = sum_pos_one_tneg + pathPredictionGadget[i].positive * n_sens_zero_tneg * scaling_thr;
                        }
                    }
                }
                // NOTE(review): the four sums above are never used, because the checks
                // below are commented out. Within this function the *_tpos_/_tneg_
                // scaled variables are therefore not tied to the per-sample `positive`
                // values. The commented-out tneg lines also name the tpos variables.
                // Confirm whether this is intentional.
                // check correctness of intermediate statistics
                //add_r1cs(sum_pos_zero, 1, n_pos_zero_tpos_scaled_var);
                //std::cout << "sum_pos_zero_int" << sum_pos_zero_int << std::endl;
                //add_r1cs(sum_pos_one, 1, n_pos_one_tpos_scaled_var);
                //add_r1cs(sum_pos_zero_tneg, 1, n_pos_zero_tpos_scaled_var); 
                //add_r1cs(sum_pos_one_tneg, 1, n_pos_one_tpos_scaled_var);
                // conditional assignment (same form as mode 1, once per label group)
                add_r1cs(is_lesseq_tpos, 2*n_pos_one_tpos_scaled_var - 2*n_pos_zero_tpos_scaled_var, 
                                    abs_tpos_var - (n_pos_zero_tpos_scaled_var - n_pos_one_tpos_scaled_var));
                add_r1cs(is_lesseq_tneg, 2*n_pos_one_tneg_scaled_var - 2*n_pos_zero_tneg_scaled_var, 
                                    abs_tneg_var - (n_pos_zero_tneg_scaled_var - n_pos_one_tneg_scaled_var));
                comparisonGadgetA->generate_r1cs_constraints();
                comparisonGadgetF->generate_r1cs_constraints();
                comparisonGadgetAtneg->generate_r1cs_constraints();
                comparisonGadgetFtneg->generate_r1cs_constraints();
                break;
            case 3:
                // Per-sample residual is target_class_label - class_label; sum it per
                // sensitive group and bind the sums to sum_res_{zero,one}_var.
                for (int i = 0; i < data.size(); ++i) {
                    if (data[i][sensitive] <= thr_sensitive ) {
                        sum_res_zero = sum_res_zero + pathPredictionGadget[i].target_class_label - pathPredictionGadget[i].class_label;
                    } else {
                        sum_res_one = sum_res_one + pathPredictionGadget[i].target_class_label - pathPredictionGadget[i].class_label;
                    }
                }
                add_r1cs(sum_res_zero, 1, sum_res_zero_var); 
                add_r1cs(sum_res_one, 1, sum_res_one_var);
                
                // For each sample: temp_sd = (group size) * residual - (group residual sum),
                // and temp_var[i] is constrained to temp_sd squared. The temp_var[i] are
                // summed per group into var_res_{zero,one}.
                for (int i = 0; i < data.size(); i++) {
                    //std::vector<unsigned> &single_data = data[i];
                    auto temp_sd = linear_combination<FieldT>();                    
                    if (data[i][sensitive] <= thr_sensitive ) {
                        temp_sd = n_sens_zero * (pathPredictionGadget[i].target_class_label 
                                                            - pathPredictionGadget[i].class_label) - sum_res_zero;
                        var_res_zero = var_res_zero + temp_var[i];
                    } else {
                        temp_sd = n_sens_one * (pathPredictionGadget[i].target_class_label 
                                                            - pathPredictionGadget[i].class_label) - sum_res_one;
                        var_res_one = var_res_one + temp_var[i];
                    }
                    add_r1cs(temp_sd, temp_sd, temp_var[i]);
                }

                // Bind the per-group sums of squares to var_res_{zero,one}_var.
                add_r1cs(var_res_zero, 1, var_res_zero_var); 
                add_r1cs(var_res_one, 1, var_res_one_var); 
                //std::cout << "var_res_one = " << eval(var_res_one) << std::endl;
                //std::cout << "var_res_one_var = " << eval(var_res_one_var) << std::endl;
                /*
                // conditional assignment
                //add_r1cs(is_lesseq, 2*n_pos_one_scaled_var - 2*n_pos_zero_scaled_var, 
                //                    abs_var - (n_pos_zero_scaled_var - n_pos_one_scaled_var));
                //comparisonGadgetA->generate_r1cs_constraints();
                comparisonGadgetF->generate_r1cs_constraints();
                */
                break;
        }
    }

    void _general_witness() {
        std::vector < DTNode * > nodes = dt.get_all_nodes();
        for (DTNode *node: nodes) {
            if (node->is_leaf) {
                unsigned index = leaf_id_index[node->node_id];
                DTLeaf *leaf = ((DTLeaf *) node);

                eval(class_label[index]) = leaf->class_id;
                eval(node_id[dt.root->non_leaf_size + index]) = leaf->node_id;
                // move these to sub gadgets...
                _decompose(leaf->class_id, class_label_decomposition + N_BITS_NODE_ATTR * index);
                _decompose(leaf->node_id, node_id_decomposition + N_BITS_NODE_ATTR * (dt.root->non_leaf_size + index));

            } else {
                unsigned index = non_leaf_id_index[node->node_id];

                DTInternalNode *internalNode = (DTInternalNode *) node;
                eval(node_id[index]) = internalNode->node_id;
                eval(variable_id[index]) = internalNode->variable_id;
                eval(threshold[index]) = internalNode->threshold;
                eval(l_node_id[index]) = internalNode->l->node_id;
                eval(r_node_id[index]) = internalNode->r->node_id;

                // move these to sub gadgets...
                _decompose(internalNode->node_id, node_id_decomposition + index * N_BITS_NODE_ATTR);
                _decompose(internalNode->variable_id, variable_id_decomposition + +index * N_BITS_NODE_ATTR);
                _decompose(internalNode->threshold, threshold_decomposition + index * N_BITS_NODE_ATTR);
                _decompose(internalNode->l->node_id, l_node_id_decomposition + index * N_BITS_NODE_ATTR);
                _decompose(internalNode->r->node_id, r_node_id_decomposition + index * N_BITS_NODE_ATTR);
            }
        }
        if (fairness_mode > 0) {
            if (fairness_mode < 3) {
                comparisonGadgetA->generate_r1cs_witness();
                comparisonGadgetF->generate_r1cs_witness();
            }
            if (fairness_mode == 2) {
                comparisonGadgetAtneg->generate_r1cs_witness();
            }
        }
    }

    // Generates the constraints of the six base decomposition-check gadgets
    // and of the fairness decomposition-check gadgets that exist for the
    // active fairness_mode.
    void _decomposition_constraints() {
        for (int i = 0; i < 6; ++i) {
            decompositionCheckGadgets[i].generate_r1cs_constraints();
        }

        // Number of entries constructed in decompositionCheckGadgetsF:
        // 5 in mode 1, 10 in mode 2, none otherwise. Defaulting to 0 keeps an
        // unexpected fairness_mode from reading an uninitialized bound.
        unsigned bound = 0;
        switch (fairness_mode) {
            case 1:
                bound = 5;
                break;
            case 2:
                bound = 10;
                break;
            default:
                break;
        }
        for (unsigned i = 0; i < bound; ++i) {
            decompositionCheckGadgetsF[i].generate_r1cs_constraints();
        }
    }

    // Generates the witness of the six base decomposition-check gadgets and
    // of the fairness decomposition-check gadgets that exist for the active
    // fairness_mode.
    void _decomposition_witness() {
        for (int i = 0; i < 6; ++i) {
            decompositionCheckGadgets[i].generate_r1cs_witness();
        }

        // Number of entries constructed in decompositionCheckGadgetsF:
        // 5 in mode 1, 10 in mode 2, none otherwise. Defaulting to 0 keeps an
        // unexpected fairness_mode from reading an uninitialized bound.
        unsigned bound = 0;
        switch (fairness_mode) {
            case 1:
                bound = 5;
                break;
            case 2:
                bound = 10;
                break;
            default:
                break;
        }
        for (unsigned i = 0; i < bound; ++i) {
            decompositionCheckGadgetsF[i].generate_r1cs_witness();
        }
    }

    // Copies `size` protoboard variables from `source` to `target`, front to
    // back (N_BITS_NODE_ATTR of them when no size is given).
    void _copy(pb_variable <FieldT> *target, pb_variable <FieldT> *source, unsigned size = N_BITS_NODE_ATTR) {
        pb_variable <FieldT> *const source_end = source + size;
        while (source != source_end) {
            *target = *source;
            ++target;
            ++source;
        }
    }

    // Sets a[start], ..., a[_hash_output_size - 1] to zero_var.
    void _zero_var_assign(pb_variable <FieldT> *a, int start = 0) {
        int pos = start;
        while (pos < _hash_output_size) {
            a[pos] = zero_var;
            ++pos;
        }
    }

    // Fills one child slot of a first-hash input with the variables that
    // stand for `node`:
    //  - leaf: class-label bits, then node-id bits, then zero_var up to
    //    _hash_output_size;
    //  - internal node: the node's _hash_output_size second-hash outputs.
    void _fill_in_input_1(pb_variable <FieldT> *inputs, DTNode *node) {
        if (node->is_leaf) {
            const unsigned index = leaf_id_index[node->node_id];

            _copy(inputs, class_label_decomposition + index * N_BITS_NODE_ATTR);
            // Leaf node ids are stored after the non_leaf_size internal-node slots.
            _copy(inputs + N_BITS_NODE_ATTR,
                  node_id_decomposition + (dt.root->non_leaf_size + index) * N_BITS_NODE_ATTR);
            _zero_var_assign(inputs, N_BITS_NODE_ATTR * 2);
        } else {
            const unsigned index = non_leaf_id_index[node->node_id];
            _copy(inputs, hash_outputs_2 + index * _hash_output_size, _hash_output_size);
        }
    }

    // Fills the second-hash input of an internal node: its first-hash output,
    // then the bit decompositions of variable id, threshold, node id, left
    // child id and right child id, then zero_var padding.
    void _fill_in_input_2(pb_variable <FieldT> *inputs, DTInternalNode *internalNode) {
        const unsigned slot = non_leaf_id_index[internalNode->node_id];
        const auto attr_offset = slot * N_BITS_NODE_ATTR;

        _copy(inputs, hash_outputs_1 + slot * _hash_output_size, _hash_output_size);

        // The attribute section starts right after the first-hash output.
        pb_variable <FieldT> *attrs = inputs + _hash_output_size;
        pb_variable <FieldT> *const attr_bits[5] = {
                variable_id_decomposition + attr_offset,
                threshold_decomposition + attr_offset,
                node_id_decomposition + attr_offset,
                l_node_id_decomposition + attr_offset,
                r_node_id_decomposition + attr_offset,
        };
        for (int field = 0; field < 5; ++field) {
            _copy(attrs + N_BITS_NODE_ATTR * field, attr_bits[field]);
        }

        // Pad the rest of the attribute section (indices are relative to attrs).
        _zero_var_assign(attrs, N_BITS_NODE_ATTR * 5);
    }

    // For every internal node, wires both hash inputs and generates the
    // constraints of its two SWIFFT gadgets; then constrains the root's
    // second-hash output to equal `commitment`.
    void _hash_constraints() {
        std::vector < DTNode * > nodes = dt.get_all_nodes();
        for (DTNode *node : nodes) {
            if (node->is_leaf) {
                continue;
            }
            DTInternalNode *inner = static_cast<DTInternalNode *>(node);
            const unsigned slot = non_leaf_id_index[inner->node_id];

            // First hash input: left child slot followed by right child slot.
            pb_variable <FieldT> *first_input = hash_inputs_1 + _hash_input_size * slot;
            _fill_in_input_1(first_input, inner->l);
            _fill_in_input_1(first_input + _hash_output_size, inner->r);
            _fill_in_input_2(hash_inputs_2 + _hash_input_size * slot, inner);

            swifftGadget[slot * 2].generate_r1cs_constraints();
            swifftGadget[slot * 2 + 1].generate_r1cs_constraints();
        }

        const unsigned root_slot = non_leaf_id_index[dt.root->node_id];
        for (int k = 0; k < _hash_output_size; ++k) {
            add_r1cs(hash_outputs_2[root_slot * _hash_output_size + k], 1, commitment[k]);
        }
    }

    // Assigns the node's first_hash values as the witness of output[0 .. _hash_output_size).
    void _fill_in_output_1(pb_variable <FieldT> *output, DTInternalNode *internalNode) {
        int k = 0;
        while (k < _hash_output_size) {
            eval(output[k]) = internalNode->first_hash[k];
            ++k;
        }
    }

    // Assigns the node's hash values as the witness of output[0 .. _hash_output_size).
    void _fill_in_output_2(pb_variable <FieldT> *output, DTInternalNode *internalNode) {
        int k = 0;
        while (k < _hash_output_size) {
            eval(output[k]) = internalNode->hash[k];
            ++k;
        }
    }

    void _hash_witness() {
        eval(zero_var) = 0;
        std::vector < DTNode * > nodes = dt.get_all_nodes();
        for (DTNode *node : nodes) {
            if (!node->is_leaf) {
                DTInternalNode *internalNode = (DTInternalNode *) node;
                unsigned index = non_leaf_id_index[internalNode->node_id];

                _fill_in_output_1(hash_outputs_1 + index * _hash_output_size, internalNode);
                _fill_in_output_2(hash_outputs_2 + index * _hash_output_size, internalNode);

                swifftGadget[index * 2].generate_r1cs_witness(internalNode->intermediate_linear_combination[0]);
                swifftGadget[index * 2 + 1].generate_r1cs_witness(internalNode->intermediate_linear_combination[1]);
            }
        }

        unsigned root_index = non_leaf_id_index[dt.root->node_id];
        for (int i = 0; i < _hash_output_size; ++i) {
            eval(commitment[i]) = dt.root->hash[i];
        }
    }

    // Generates the constraints of every per-sample path-prediction gadget,
    // in sample order.
    void _path_constraints() {
        for (std::size_t sample = 0; sample < data.size(); ++sample) {
            pathPredictionGadget[sample].generate_r1cs_constraints();
        }
    }

    // Generates the witness of every per-sample path-prediction gadget, in
    // sample order.
    void _path_witness() {
        for (std::size_t sample = 0; sample < data.size(); ++sample) {
            pathPredictionGadget[sample].generate_r1cs_witness();
        }
    }

    // Constrains coef_array to the chain coef_array[0] = coef,
    // coef_array[k + 1] = coef_array[k] * coef, then generates the constraints
    // of the tree and path linear-combination gadgets and of the multiset gadget.
    void _nodes_multiset_constraints() {
        add_r1cs(coef_array[0], 1, coef);

        const auto chain_links = dt.n_nodes - 1;
        for (int k = 0; k < chain_links; ++k) {
            add_r1cs(coef_array[k], coef, coef_array[k + 1]);
        }

        int tree_idx = 0;
        while (tree_idx < dt.n_nodes) {
            treeLinearCombinationGadget[tree_idx].generate_r1cs_constraints();
            ++tree_idx;
        }

        int path_idx = 0;
        while (path_idx < n_path_nodes) {
            pathLinearCombinationGadget[path_idx].generate_r1cs_constraints();
            ++path_idx;
        }

        multisetGadget->generate_r1cs_constraints();
    }

    // Assigns the coef_array chain, generates the witness of the tree and
    // path linear-combination gadgets, writes each nodes_count[i] into
    // frequency_in_bits (n_frequency_bits per node, most significant bit
    // first), and finally generates the multiset gadget's witness.
    void _nodes_multiset_witness() {
        eval(coef_array[0]) = eval(coef);
        const auto chain_links = dt.n_nodes - 1;
        for (int k = 0; k < chain_links; ++k) {
            eval(coef_array[k + 1]) = eval(coef) * eval(coef_array[k]);
        }

        int tree_idx = 0;
        while (tree_idx < dt.n_nodes) {
            treeLinearCombinationGadget[tree_idx].generate_r1cs_witness();
            ++tree_idx;
        }

        int path_idx = 0;
        while (path_idx < n_path_nodes) {
            pathLinearCombinationGadget[path_idx].generate_r1cs_witness();
            ++path_idx;
        }

        for (int node = 0; node < dt.n_nodes; ++node) {
            for (int bit = 0; bit < n_frequency_bits; ++bit) {
                eval(frequency_in_bits[node * n_frequency_bits + bit]) =
                        (nodes_count[node] >> (n_frequency_bits - bit - 1)) & 1U;
            }
        }

        multisetGadget->generate_r1cs_witness();
    }

public:

    // Batch statistics filled in by the constructor.
    unsigned n_path_nodes; // total number of nodes over all prediction paths
    unsigned n_correct;    // samples whose label equals the predicted class
    // fairness variables
    // "zero" = samples with data[i][sensitive] <= thr_sensitive, "one" = the rest.
    unsigned n_sens_zero;
    unsigned n_sens_one;
    unsigned n_pos;      // samples predicted as pos_label
    unsigned n_pos_zero; // ... within group zero
    unsigned n_pos_one;  // ... within group one
    unsigned n_pos_zero_scaled;
    unsigned n_pos_one_scaled;
    int64_t residual_one;  // sum of (label - prediction) over group one
    int64_t residual_zero; // sum of (label - prediction) over group zero
    int64_t residual_one_scaled;
    int64_t residual_zero_scaled;
    uint64_t var_residual_zero;
    uint64_t var_residual_one;
    // Y = 1 (samples whose label equals pos_label)
    unsigned n_sens_zero_tpos;
    unsigned n_sens_one_tpos;
    unsigned n_pos_zero_tpos;
    unsigned n_pos_one_tpos;
    unsigned n_pos_zero_tpos_scaled;
    unsigned n_pos_one_tpos_scaled;
    // Y = 0 (samples whose label differs from pos_label)
    unsigned n_sens_zero_tneg;
    unsigned n_sens_one_tneg;
    unsigned n_pos_zero_tneg;
    unsigned n_pos_one_tneg;
    unsigned n_pos_zero_tneg_scaled;
    unsigned n_pos_one_tneg_scaled;
    
    float thr;
    unsigned thr_sensitive; // cut-off on the sensitive feature value
    unsigned thr_scaled, thr_scaled_tpos, thr_scaled_tneg;
    unsigned sensitive;     // index of the sensitive feature inside a sample
    unsigned fairness_mode;
    unsigned pos_label;
    unsigned scaling_thr;
    unsigned scaling_cls;
    float alpha;
    bool verbose;

    DT &dt;
    std::vector <std::vector<unsigned>> &data;
    // Owned copy, not a reference: the constructor receives `labels_` by
    // value, so a reference member would dangle as soon as the constructor
    // returns.
    std::vector <unsigned> labels;
    std::vector <unsigned> predictions = {};
    std::vector <int64_t> residuals = {};
    std::vector <std::vector<DTNode *>> all_paths;

    DTBatchGadget(protoboard <FieldT> &pb, DT &dt_, std::vector <std::vector<unsigned>> &data_,
                  std::vector<unsigned> labels_,
                  FieldT &coef_,
                  FieldT &challenge_point_, 
                  float &thr_,
                  unsigned &sensitive_,
                  unsigned &thr_sensitive_,
                  unsigned &fairness_mode_,
                  unsigned &pos_label_,
                  unsigned &scaling_cls_,
                  float &alpha_,
                  bool &verbose_,
                  const std::string &annotation = "")
            : gadget<FieldT>(pb, annotation), dt(dt_), data(data_), labels(labels_), 
                            thr(thr_), sensitive(sensitive_), thr_sensitive(thr_sensitive_), 
                            fairness_mode(fairness_mode_), pos_label(pos_label_), 
                            scaling_cls(scaling_cls_), alpha(alpha_), verbose(verbose_) {

        n_path_nodes = 0;
        n_correct = 0;
        n_sens_zero = 0; 
        n_sens_zero_tpos = 0; 
        n_sens_zero_tneg = 0; 
        n_sens_one = 0;
        n_sens_one_tpos = 0;
        n_sens_one_tneg = 0;
        n_pos = 0;
        n_pos_zero = 0;
        n_pos_one = 0;
        n_pos_zero_tpos = 0;
        n_pos_zero_tneg = 0;
        n_pos_one_tpos = 0;
        n_pos_one_tneg = 0;
        residual_zero = 0; // bar_r_0 * n_sens_zero
        residual_one = 0; //  bar_r_1 * n_sens_one
        var_residual_zero = 0; // v_0 * n_sens_zero^2 * (n_sens_zero - 1)
        var_residual_one = 0; // v_1 * n_sens_one^ 2 * (n_sens_one - 1)
        scaling_thr = 10; // scale everything by 10

        for (int i = 0; i < data.size(); i++) {
            std::vector<unsigned> &single_data = data[i];
            std::vector < DTNode * > path_nodes = dt.predict(single_data);
            //std::cout << "sensitive  = " << single_data[sensitive]  << std::endl;
            /*
            if (i == 0) {
                for (auto v : single_data) std::cout << v << std::endl;
                print_node_vector(path_nodes);
            }
            */

            n_path_nodes += path_nodes.size();
            all_paths.push_back(path_nodes);
            
            // if (i < 10) {
            //    printf("label %d: %d\n", i, labels[i]);
            //    printf("predi %d: %d\n", i, ((DTLeaf*) (path_nodes.back()))->class_id);
            // }
            
            //std::cout << "class_id = " << ((DTLeaf*) (path_nodes.back()))->class_id << std::endl;
            // std::cout << "counting" << std::endl;
            n_correct += labels[i] == ((DTLeaf*) (path_nodes.back()))->class_id; // acc
            int64_t pred = ((DTLeaf*) (path_nodes.back()))->class_id; 
            int64_t truth = labels[i]; 
            // fair
            bool is_positive = pos_label == ((DTLeaf*) (path_nodes.back()))->class_id;
            bool is_tpositive = pos_label == labels[i]; // check true positive
            predictions.push_back(pred);
            residuals.push_back(truth - pred);
            if (is_positive) n_pos += 1; 
            if (single_data[sensitive] <= thr_sensitive) {
                n_sens_zero++;
                residual_zero += truth - pred;
                if (is_tpositive) n_sens_zero_tpos++;
                if (!is_tpositive) n_sens_zero_tneg++;
                if (is_positive) n_pos_zero += 1; 
                if (is_positive && is_tpositive) n_pos_zero_tpos += 1; 
                if (is_positive && !is_tpositive) n_pos_zero_tneg += 1; 
            } else {
                n_sens_one++;
                residual_one += truth - pred;
                if (is_tpositive) n_sens_one_tpos++;
                if (!is_tpositive) n_sens_one_tneg++;
                if (is_positive) n_pos_one += 1;
                if (is_positive && is_tpositive) n_pos_one_tpos += 1; 
                if (is_positive && !is_tpositive) n_pos_one_tneg += 1; 
            }
        }
        // Print labels and predictions
        /*
        std::cout << "Labels: ";
        for (size_t i = 0; i < labels.size(); ++i) {
            std::cout << labels[i] << " ";
        }
        std::cout << std::endl;
        std::cout << "Predictions: ";
        for (size_t i = 0; i < predictions.size(); ++i) {
            std::cout << predictions[i] << " ";
        }
        std::cout << std::endl;
        */

        _init_id_map();
        _count();

        _init_pb_vars();
        _init_sub_gadgets();

        // compute variances
        //std::cout << "computing variances" << std::endl;
        for (int i = 0; i < data.size(); i++) {
            std::vector<unsigned> &single_data = data[i];
            if (single_data[sensitive] <= thr_sensitive) {
                int64_t sd = n_sens_zero * residuals[i] - residual_zero;
                eval(temp_var[i]) = sd * sd;
                var_residual_zero += sd * sd;
            } else {
                int64_t sd = n_sens_one * residuals[i] - residual_one;
                eval(temp_var[i]) = sd * sd;
                var_residual_one += sd * sd;
            }
        }
        eval(coef) = coef_;
        eval(challenge_point) = challenge_point_;
        eval(n_correct_var) = n_correct;
        
        // fairness code below
        thr_scaled = thr * n_sens_zero * n_sens_one * scaling_thr;
        eval(n_pos_zero_var) = n_pos_zero;
        eval(n_pos_one_var) = n_pos_one;
        eval(n_pos_zero_tpos_var) = n_pos_zero_tpos;
        eval(n_pos_one_tpos_var) = n_pos_one_tpos;
        eval(n_pos_zero_tneg_var) = n_pos_zero_tneg;
        eval(n_pos_one_tneg_var) = n_pos_one_tneg;
        // MRD
        eval(sum_res_zero_var) = residual_zero;
        eval(sum_res_one_var) = residual_one;
        eval(var_res_zero_var) = var_residual_zero;
        eval(var_res_one_var) = var_residual_one;

        n_pos_zero_scaled = n_pos_zero * n_sens_one * scaling_thr;
        n_pos_one_scaled = n_pos_one * n_sens_zero * scaling_thr;
        n_pos_zero_tpos_scaled = n_pos_zero_tpos * n_sens_one_tpos * scaling_thr;
        n_pos_one_tpos_scaled = n_pos_one_tpos * n_sens_zero_tpos * scaling_thr;
        n_pos_zero_tneg_scaled = n_pos_zero_tneg * n_sens_one_tneg * scaling_thr;
        n_pos_one_tneg_scaled = n_pos_one_tneg * n_sens_zero_tneg * scaling_thr;
        residual_zero_scaled = residual_zero * n_sens_one * scaling_thr;
        residual_one_scaled = residual_one * n_sens_zero * scaling_thr;
        if (fairness_mode == 2) { // EqOd
            thr_scaled_tpos = thr * n_sens_zero_tpos * n_sens_one_tpos * scaling_thr;
            thr_scaled_tneg = thr * n_sens_zero_tneg * n_sens_one_tneg * scaling_thr;
        }
        if (fairness_mode == 3) { // MRD
            thr_scaled_tpos = thr * n_sens_zero_tpos * n_sens_one_tpos * scaling_thr;
            thr_scaled_tneg = thr * n_sens_zero_tneg * n_sens_one_tneg * scaling_thr;
        }
        if (fairness_mode == 4) { // MRD-t
            if (verbose) {
                std::cout << "residual_zero: " << residual_zero << std::endl;
                std::cout << "residual_one : " << residual_one << std::endl;
                std::cout << "var_residual_zero: " << var_residual_zero << std::endl;
                std::cout << "var_residual_one: " << var_residual_one << std::endl;
            }
        }

        eval(n_pos_zero_scaled_var) = n_pos_zero_scaled;
        eval(n_pos_one_scaled_var) = n_pos_one_scaled;
        eval(n_pos_zero_tpos_scaled_var) = n_pos_zero_tpos_scaled;
        eval(n_pos_one_tpos_scaled_var) = n_pos_one_tpos_scaled;
        eval(n_pos_zero_tneg_scaled_var) = n_pos_zero_tneg_scaled;
        eval(n_pos_one_tneg_scaled_var) = n_pos_one_tneg_scaled;
        eval(thr_scaled_var) = thr_scaled;
        eval(thr_scaled_tpos_var) = thr_scaled_tpos;
        eval(thr_scaled_tneg_var) = thr_scaled_tneg;
        
        // check the result first
        bool is_lesseq_tmp, is_lesseq_tpos_tmp, is_lesseq_tneg_tmp;
        bool is_fair_tmp, is_fair_tpos_tmp, is_fair_tneg_tmp;
        unsigned d_abs, d_abs_tpos, d_abs_tneg;
        unsigned d_fair, d_fair_tpos, d_fair_tneg;
        switch (fairness_mode) {
            case 1: 
                is_lesseq_tmp = (n_pos_zero_scaled <= n_pos_one_scaled);
                d_abs = (is_lesseq_tmp) ? (n_pos_one_scaled - n_pos_zero_scaled) : (n_pos_zero_scaled - n_pos_one_scaled);
                _decompose(n_pos_zero_scaled, n_pos_zero_scaled_var_decomposition);
                _decompose(n_pos_one_scaled, n_pos_one_scaled_var_decomposition);
                eval(is_lesseq) = is_lesseq_tmp;        
                eval(abs_var) = d_abs;
                is_fair_tmp = (d_abs <= thr_scaled);
                eval(is_fair) = is_fair_tmp;
                d_fair = (is_fair_tmp) ? (thr_scaled - d_abs) : (d_abs - thr_scaled);
                eval(diff) = d_fair;
                _decompose(d_abs, abs_decomposition);
                _decompose(thr_scaled, thr_scaled_var_decomposition);
                _decompose(d_fair, diff_decomposition);
                eval(is_lesseq) = is_lesseq_tmp;
                std::cout << "E[hat_Y = 1 | A = 0] = " << std::fixed << std::setprecision(4) << (double(n_pos_zero) / double(n_sens_zero)) << std::endl;
                std::cout << "E[hat_Y = 1 | A = 1] = " << std::fixed << std::setprecision(4) << (double(n_pos_one) / double(n_sens_one)) << std::endl;
                // Print Demographic Parity difference
                std::cout << "Demographic Parity difference: "
                          << std::fixed << std::setprecision(4)
                          << std::abs(double(n_pos_zero) / double(n_sens_zero) - double(n_pos_one) / double(n_sens_one))
                          << std::endl;
                break;
            case 2: 
                // TPR: Y = 1
                is_lesseq_tpos_tmp = (n_pos_zero_tpos_scaled <= n_pos_one_tpos_scaled);
                d_abs_tpos = (is_lesseq_tpos_tmp) ? (n_pos_one_tpos_scaled - n_pos_zero_tpos_scaled) : (n_pos_zero_tpos_scaled - n_pos_one_tpos_scaled);
                is_fair_tpos_tmp = (d_abs_tpos <= thr_scaled_tpos);
                d_fair_tpos = (is_fair_tpos_tmp) ? (thr_scaled_tpos - d_abs_tpos) : (d_abs_tpos - thr_scaled_tpos);
                
                _decompose(n_pos_zero_tpos_scaled, n_pos_zero_tpos_scaled_var_decomposition);
                _decompose(n_pos_one_tpos_scaled, n_pos_one_tpos_scaled_var_decomposition);
                _decompose(d_abs_tpos, abs_tpos_decomposition);
                _decompose(thr_scaled_tpos, thr_scaled_tpos_var_decomposition);
                _decompose(d_fair_tpos, diff_tpos_decomposition);

                eval(is_lesseq_tpos) = is_lesseq_tpos_tmp;
                eval(abs_tpos_var) = d_abs_tpos;
                eval(is_fair_tpos) = is_fair_tpos_tmp;
                eval(diff_tpos) = d_fair_tpos;
                eval(is_lesseq_tpos) = is_lesseq_tpos_tmp;
                std::cout << "EqOd:TPR" << std::endl;
                std::cout << "    E[hat_Y = 1 | A = 0, Y = 1] = " << std::fixed << std::setprecision(4) << (double(n_pos_zero_tpos) / double(n_sens_zero_tneg)) << std::endl;
                std::cout << "    E[hat_Y = 1 | A = 1, Y = 1] = " << std::fixed << std::setprecision(4) << (double(n_pos_one_tpos) / double(n_sens_one_tpos)) << std::endl;
                // Print TPR difference
                std::cout << "    TPR difference: "
                          << std::fixed << std::setprecision(4)
                          << std::abs(double(n_pos_zero_tpos) / double(n_sens_zero_tpos) - double(n_pos_one_tpos) / double(n_sens_one_tpos))
                          << std::endl;

                // FPR: Y = 0
                is_lesseq_tneg_tmp = (n_pos_zero_tneg_scaled <= n_pos_one_tneg_scaled);
                d_abs_tneg = (is_lesseq_tneg_tmp) ? (n_pos_one_tneg_scaled - n_pos_zero_tneg_scaled) : (n_pos_zero_tneg_scaled - n_pos_one_tneg_scaled);
                is_fair_tneg_tmp = (d_abs_tneg <= thr_scaled_tneg);
                d_fair_tneg = (is_fair_tneg_tmp) ? (thr_scaled_tneg - d_abs_tneg) : (d_abs_tneg - thr_scaled_tneg);
                
                _decompose(n_pos_zero_tneg_scaled, n_pos_zero_tneg_scaled_var_decomposition);
                _decompose(n_pos_one_tneg_scaled, n_pos_one_tneg_scaled_var_decomposition);
                _decompose(d_abs_tneg, abs_tneg_decomposition);
                _decompose(thr_scaled_tneg, thr_scaled_tneg_var_decomposition);
                _decompose(d_fair_tneg, diff_tneg_decomposition);

                eval(is_lesseq_tneg) = is_lesseq_tneg_tmp;
                eval(abs_tneg_var) = d_abs_tneg;
                eval(is_fair_tneg) = is_fair_tneg_tmp;
                eval(diff_tneg) = d_fair_tneg;
                eval(is_lesseq_tneg) = is_lesseq_tneg_tmp;
                std::cout << "EqOd:FPR" << std::endl;
                std::cout << "    E[hat_Y = 1 | A = 0, Y = 0] = " << std::fixed << std::setprecision(4) << (double(n_pos_zero_tneg) / double(n_sens_zero_tneg)) << std::endl;
                std::cout << "    E[hat_Y = 1 | A = 1, Y = 0] = " << std::fixed << std::setprecision(4) << (double(n_pos_one_tneg) / double(n_sens_one_tneg)) << std::endl;
                // Print FPR difference
                std::cout << "    FPR difference: "
                          << std::fixed << std::setprecision(4)
                          << std::abs(double(n_pos_zero_tneg) / double(n_sens_zero_tneg) - double(n_pos_one_tneg) / double(n_sens_one_tneg))
                          << std::endl;

                // result
                is_fair_tmp = is_fair_tpos_tmp && is_fair_tneg_tmp; 
                eval(is_fair) = is_fair_tmp;
                break;
            case 3:
                is_lesseq_tmp = (residual_zero_scaled-residual_one_scaled <= 0);
                d_abs = (is_lesseq_tmp) ? (residual_one_scaled-residual_zero_scaled) : (residual_zero_scaled-residual_one_scaled);
                _decompose(n_pos_zero_scaled, n_pos_zero_scaled_var_decomposition);
                _decompose(n_pos_one_scaled, n_pos_one_scaled_var_decomposition);
                eval(is_lesseq) = is_lesseq_tmp;        
                eval(abs_var) = d_abs;
                is_fair_tmp = (d_abs <= thr_scaled);
                eval(is_fair) = is_fair_tmp;
                d_fair = (is_fair_tmp) ? (thr_scaled - d_abs) : (d_abs - thr_scaled);
                eval(diff) = d_fair;
                _decompose(d_abs, abs_decomposition);
                _decompose(thr_scaled, thr_scaled_var_decomposition);
                _decompose(d_fair, diff_decomposition);
                eval(is_lesseq) = is_lesseq_tmp;
                std::cout << "E[Y - hat_Y | A = 0] = " << std::fixed << std::setprecision(4) << (double(residual_zero) / double(n_sens_zero)) << std::endl;
                std::cout << "E[Y - hat_Y | A = 1] = " << std::fixed << std::setprecision(4) << (double(residual_one) / double(n_sens_one)) << std::endl;
                // Print Mean Residual Difference
                std::cout << "Mean Residual Difference: "
                          << std::fixed << std::setprecision(4)
                          << std::abs(double(residual_zero) / double(n_sens_zero) - double(residual_one) / double(n_sens_one))
                          << std::endl;
                break;
            case 4: 
                /*
                _decompose(n_pos_one_scaled, n_pos_one_scaled_var_decomposition);
                is_fair_tmp = (thr_scaled <= n_pos_one_scaled);
                eval(is_fair) = is_fair_tmp;
                d_fair = (is_fair_tmp) ? (n_pos_one_scaled - thr_scaled) : (thr_scaled - n_pos_one_scaled);
                eval(diff) = d_fair;
                _decompose(thr_scaled, thr_scaled_var_decomposition);
                _decompose(d_fair, diff_decomposition);
                */
                double mean_0 = double(residual_zero) / double(n_sens_zero);
                double mean_1 = double(residual_one) / double(n_sens_one);
                double var_0 = double(var_residual_zero) / double(n_sens_zero * n_sens_zero * (n_sens_zero - 1));
                double var_1 = double(var_residual_one) / double(n_sens_one * n_sens_one * (n_sens_one - 1));
                std::cout << "mean_0: " << mean_0 << ", var_0: " << var_0 << std::endl;
                std::cout << "mean_1: " << mean_1 << ", var_1: " << var_1 << std::endl;
                is_fair_tmp = compare_means(mean_0, mean_1, var_0, var_1, n_sens_zero, n_sens_one, alpha);
                eval(is_fair) = is_fair_tmp;
                break;
        }       
        
        
        
        // debug
        if(verbose) {
        std::cout << "pos_label = " << pos_label << std::endl;
        std::cout << "data size = " << data.size() << std::endl;
        std::cout << "  n_sens_zero = " << n_sens_zero << std::endl;
        std::cout << "  n_sens_one  = " << n_sens_one << std::endl;
        std::cout << "  n_sens_zero_tpos = " << n_sens_zero_tpos << std::endl;
        std::cout << "  n_sens_zero_tneg = " << n_sens_zero_tneg << std::endl;
        std::cout << "  n_sens_one_tpos  = " << n_sens_one_tpos << std::endl;
        std::cout << "  n_sens_one_tneg  = " << n_sens_one_tneg << std::endl;
        std::cout << "n_correct = " << n_correct << std::endl;
        std::cout << "n_pos = " << n_pos << std::endl;
        std::cout << "  n_pos_zero = " << n_pos_zero << std::endl;
        std::cout << "  n_pos_zero_scaled = " << n_pos_zero * n_sens_one << std::endl;
        std::cout << "  n_pos_zero_tpos = " << n_pos_zero_tpos << std::endl;
        std::cout << "  n_pos_zero_tpos_scaled = " << n_pos_zero_tpos_scaled << std::endl;
        std::cout << "  n_pos_one  = " << n_pos_one << std::endl;
        std::cout << "  n_pos_one_scaled = " << n_pos_one * n_sens_zero << std::endl;
        std::cout << "  n_pos_one_tpos = " << n_pos_one_tpos << std::endl;
        std::cout << "  n_pos_one_tpos_scaled = " << n_pos_one_tpos_scaled << std::endl;
        std::cout << "thr_scaled = " << thr_scaled  << std::endl;
        std::cout << "thr_scaled_tpos = " << thr_scaled_tpos  << std::endl;
        std::cout << "thr_scaled_tneg = " << thr_scaled_tneg  << std::endl;
        std::cout << "is_fair  = " << is_fair_tmp  << std::endl;
        std::cout << "is_fair_tpos  = " << is_fair_tpos_tmp  << std::endl;
        }
        float fair_val, fair_val_tneg;
        switch (fairness_mode) {
            case 1: // Demographic Parity
                fair_val = float(d_abs) / float(n_sens_one * n_sens_zero * scaling_thr);
                break;
            case 2: // EqOd
                fair_val = float(d_abs_tpos) / float(n_sens_one_tpos * n_sens_zero_tpos * scaling_thr);
                fair_val_tneg = float(d_abs_tneg) / float(n_sens_one_tneg * n_sens_zero_tneg * scaling_thr);
                break;
            case 3: // MRD
                fair_val = float(d_abs) / float(n_sens_one * n_sens_zero * scaling_thr);
                fair_val_tneg = 0; // not used
                break;
            case 4:
                fair_val = abs(float(residual_zero/n_sens_zero) - float(residual_one/n_sens_one));
                break;
        }
        if (is_fair_tmp) {
            std::cout << "proving fairness"  << std::endl;
        } else {
            std::cout << "result is unfair" << std::endl;
        }
        /*
        if (is_fair_tmp) {
            std::cout << "proving fairness: " << fair_val << std::endl;
            if (fairness_mode == 2) {
                std::cout << "proving fairness: " << fair_val_tneg << std::endl;
            }
        } else {
            std::cout << "proving unfairness: " << fair_val  << std::endl;
            if (fairness_mode == 2) {
                std::cout << "proving unfairness: " << fair_val_tneg << std::endl;
            }
        }*/
        std::cout << "generating a circuit and assigning wire values..." << std::endl;
    }

    /**
     * Runs this gadget's five constraint helpers one after another:
     * general, decomposition, hash, path, and node-multiset.
     * generate_r1cs_witness() invokes the matching *_witness helpers in the
     * same sequence.
     *
     * Before doing any work, a trace line with the method name is written
     * to stdout.
     */
    void generate_r1cs_constraints() {
        const char *const trace_msg = "generate_r1cs_constraints()";
        std::cout << trace_msg << std::endl;  // std::endl also flushes the stream

        // Call order is kept exactly as originally written.
        // NOTE(review): presumably the order in which constraints are added
        // matters downstream — confirm before reordering these calls.
        _general_constraints();
        _decomposition_constraints();
        _hash_constraints();
        _path_constraints();
        _nodes_multiset_constraints();
    }

    /**
     * Runs this gadget's five witness helpers one after another:
     * general, decomposition, hash, path, and node-multiset.
     * generate_r1cs_constraints() invokes the matching *_constraints helpers
     * in the same sequence.
     *
     * Before doing any work, a trace line with the method name is written
     * to stdout.
     */
    void generate_r1cs_witness() {
        const char *const trace_msg = "generate_r1cs_witness()";
        std::cout << trace_msg << std::endl;  // std::endl also flushes the stream

        // Call order is kept exactly as originally written.
        // NOTE(review): later helpers may depend on values assigned by earlier
        // ones — confirm before reordering these calls.
        _general_witness();
        _decomposition_witness();
        _hash_witness();
        _path_witness();
        _nodes_multiset_witness();
    }
};


#endif //ZKDT_DT_BATCH_GADGET_H
