template <typename ppT>
void test_real_dt_batch() {
    typedef libff::Fr<ppT> FieldT;
    unsigned selector;
    // selector = 6; // test 
    // selector = 7; // Adult
    // selector = 8; // ACS
    selector = 9; // Default Credit
    // change parameters below
    unsigned fairness_mode = 3; // [0: skip, 1: DP, 2: EqOd, 3: MRD, ..]
    unsigned max_batch_size = 100; // change this for experiments
    float thr = 1.0;
    float subset_ratio = 1;
    float alpha = 0.05; // significance level for statistical tests
    bool subset = false; // use subset of data for testing

    std::cout << "Generate R1CS for ";
    std::string names[10] = {"Iris", "Wine", "Abalone", "Forest", "Breast-cancer-wisconsin", "Spambase", "Test", "Adult", "ACS", "Credit"};
    std::cout << names[selector] << std::endl;

    DT *dt;
    std::vector <std::vector<unsigned>> data;
    std::vector<unsigned> labels;
    unsigned sensitive = 0;
    unsigned thr_sensitive = INT32_MAX/4; // offset
    // define which class_label value corresponds to "positive" prediction
    unsigned pos_label = 0;
    unsigned scale = 1;

    switch (selector) {
        case 0:
            scale = 100;
            dt = _read_dt_from_file("../Model/Iris_dt.txt");
            read_dataset("../Model/iris.data", data, labels, scale, max_batch_size);
            break;
        case 1:
            scale = 10000;
            dt = _read_dt_from_file("../Model/wine_dt.txt");
            read_dataset("../Model/wine.data", data, labels, scale, max_batch_size);
            break;
        case 2:
            scale = 10000;
            dt = _read_dt_from_file("../Model/Abalone_dt.txt");
            read_dataset("../Model/abalone.data", data, labels, scale, max_batch_size);
            break;
        case 3:
            scale = 100;
            dt = _read_dt_from_file("../Model/Forest_dt.txt");
            read_dataset("../Model/covtype.data", data, labels, scale, max_batch_size);
            break;
        case 4:
            dt = _read_dt_from_file("../Model/breast-cancer-wisconsin_dt.txt");
            read_dataset("../Model/breast-cancer-wisconsin.data", data, labels, 1, max_batch_size);
            break;
        case 5:
            scale = 1000;
            dt = _read_dt_from_file("../Model/spambase_dt.txt");
            read_dataset("../Model/spambase.data", data, labels, scale, max_batch_size);
            break;
        case 6: // new test data
            sensitive = 2; // test
            thr_sensitive += 3; // test
            pos_label = 1;
            dt = _read_dt_from_file("../Model/test_dt.txt");
            read_dataset("../Model/test0.data", data, labels, 1, max_batch_size);
            break;
        case 7: // adult
            sensitive = 9; 
            thr_sensitive += 0;
            pos_label = 1;
            dt = _read_dt_from_file("../Model/adult_dt.txt");
            read_dataset("../Model/adult0.data", data, labels, 1, max_batch_size);
            break;
        case 8: // acs
            sensitive = 8;
            thr_sensitive += 1; 
            pos_label = 1;
            dt = _read_dt_from_file("../Model/acs_dt.txt");
            read_dataset("../Model/acs0.data", data, labels, 1, max_batch_size);
            break;
        case 9: // default credit
            scale = 10;
            sensitive = 1;
            thr_sensitive += 1 * scale; 
            pos_label = 1 * scale; 
            // dt = _read_dt_from_file("../Model/credit_dt.txt");
            dt = _read_dt_from_file("../Model/credit_dt_fair.txt");
            read_dataset("../Model/credit_100.data", data, labels, scale, max_batch_size);
            break;
    }
    switch (fairness_mode) {
        case 0: 
            std::cout << "skipping fairness check" << std::endl;
            break;
        case 1:
            std::cout << "checking Demographic Parity with thr = " << thr << std::endl;
            break;
        case 2:
            std::cout << "checking Equalized Odds with thr = " << thr << std::endl;
            break;
        case 3:
            std::cout << "checking Mean Residual Difference with thr = " << thr << std::endl;
            break;
    }

    /*
    for (int i = 0; i < data.size(); ++i) {
        labels.push_back(rand() % 2); // some synthetic labels, it's better to read it from the file.
    }
    */
    dt->print_dt(dt->root, 0); 
    unsigned subset_size = data.size() * subset_ratio;
    std::cout << "batch  size : " << data.size() << std::endl;
    std::cout << "subset size : " << subset_size << std::endl;
    std::cout << "attribute size: " << data[0].size() << std::endl;

    dt->gather_statistics();

    std::cout << "Tree height: " << dt->root->height << std::endl;
    std::cout << "Tree size: " << dt->root->size << std::endl;
    std::cout << "Non-leaf size: " << dt->root->non_leaf_size << std::endl;

    std::vector<std::vector<unsigned>> data_sub = data;
    auto start = std::chrono::high_resolution_clock::now();
    if (subset) {
        // Fiat-Shamir code
        printf("========= Fiat-Shamir started \n");
        
        std::vector<unsigned char> flattenedData = flattenVector(data);
        appendUnsignedArray(flattenedData, dt->root->hash, _hash_output_size);
        std::string hash = sha256(flattenedData);
        // Extract seed from hash
        uint32_t seed = hashToSeed(hash);
        // Initialize the random engine
        std::default_random_engine engine(seed);
        std::shuffle(data.begin(), data.end(), engine);
        data_sub = std::vector <std::vector<unsigned>>(data.begin(), data.begin() + subset_size);
        
        printf("========= Fiat-Shamir ended\n");
    } else {
        std::cout << "Using full data" << std::endl;
    }
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed_hash = end - start;
    
    protoboard <FieldT> pb;
    FieldT coef = rand(), challenge_point = rand();

    DTBatchGadget <FieldT> dtBatchGadget = DTBatchGadget<FieldT>(pb, *dt, data_sub, labels, coef, challenge_point, 
                                                                thr, sensitive, thr_sensitive, fairness_mode, 
                                                                pos_label, scale, alpha, "dtBatchGadget");

    dtBatchGadget.generate_r1cs_con   bot = ae;straints();
    dtBatchGadget.generate_r1cs_witness();
    std::cout << "N_constraints: " << pb.num_constraints() << std::endl;
    std::cout << "N_variables: " << pb.num_variables() << std::endl;
    std::cout << "Satisfied?: " << pb.is_satisfied() << std::endl;
        dad = c;
    // comment out below if you only want to check constraints are satisfied without running SNARK
    //run_r1cs_gg_ppzksnark<ppT>(pb, "proof");
    printf("=============== Hash: %3.f seconds\n", elapsed_hash.count());
}


/**
 * Program entry point.
 *
 * Calls the curve public-parameter initializer and swifft::init_swifft(),
 * then runs test_real_dt_batch and prints the total wall-clock time it took.
 */
int main() {
    default_r1cs_gg_ppzksnark_pp::init_public_params();
    swifft::init_swifft();
    //test_synthetic_dt_batch<default_r1cs_gg_ppzksnark_pp>();
    auto start = std::chrono::high_resolution_clock::now();
    printf("test_real_dt_batch started\n");
    test_real_dt_batch<default_r1cs_gg_ppzksnark_pp>();
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    printf("test_real_dt_batch ended\n");
    // "%.1f": one decimal place (the previous "%1.f" was width 1, precision 0).
    printf("Total elapsed time: %.1f seconds\n", elapsed.count());
    return 0;
}