#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <libsnark/common/default_types/r1cs_gg_ppzksnark_pp.hpp>
#include <libsnark/relations/constraint_satisfaction_problems/r1cs/examples/r1cs_examples.hpp>
#include <libsnark/zk_proof_systems/ppzksnark/r1cs_gg_ppzksnark/r1cs_gg_ppzksnark.hpp>
#include <openssl/sha.h>

#include "gadgets/dt_batch_gadget.h"
#include "DT/DT.h"
#include "hash.h"
#include "DT/compile.h"
#include "gadgets/swifft.h"
#include "ZKP/groth16.h"

using namespace libsnark;

// Process-wide run options. main() overwrites them from the command line
// before test_real_dt_batch() reads them.

// Set by --verbose / -v. When true the decision tree is printed and the flag
// is forwarded to DTBatchGadget.
bool verbose = false;
// Set by --dry / -d. When true the run stops after the satisfiability check
// and run_r1cs_gg_ppzksnark() is not called.
bool dryrun = false;
// Row limit handed to read_dataset(); overridden by --testsize.
unsigned max_batch_size = 100; // change this for experiments

// Collapse a 2D table into a flat byte stream, visiting rows in order and
// emitting exactly one byte per element.
std::vector<unsigned char> flattenVector(const std::vector<std::vector<unsigned>>& vec) {
    std::vector<unsigned char> bytes;
    for (auto rowIt = vec.begin(); rowIt != vec.end(); ++rowIt) {
        for (std::size_t col = 0; col < rowIt->size(); ++col) {
            // Only the low 8 bits of each element are kept; the rest is dropped.
            const unsigned lowByte = (*rowIt)[col] & 0xFF;
            bytes.push_back(static_cast<unsigned char>(lowByte));
        }
    }
    return bytes;
}

// Append the in-memory byte representation (host byte order) of `length`
// unsigned values starting at `arr` to the end of `vec`.
void appendUnsignedArray(std::vector<unsigned char>& vec, const unsigned* arr, size_t length) {
    const unsigned* const stop = arr + length;
    for (const unsigned* cur = arr; cur != stop; ++cur) {
        // View the current element as its sizeof(unsigned) raw bytes.
        const unsigned char* raw = reinterpret_cast<const unsigned char*>(cur);
        vec.insert(vec.end(), raw, raw + sizeof(unsigned));
    }
}

// Function to compute SHA-256 hash of byte data
std::string sha256(const std::vector<unsigned char>& data) {
    unsigned char hash[SHA256_DIGEST_LENGTH];  
    SHA256_CTX sha256;
    SHA256_Init(&sha256);
    SHA256_Update(&sha256, data.data(), data.size());
    SHA256_Final(hash, &sha256);

    // Convert hash to hex string
    std::stringstream ss;
    for (unsigned char c : hash) {
        ss << std::hex << std::setw(2) << std::setfill('0') << (int)c;
    }
    return ss.str();
}

/**
 * Derive a 32-bit seed from the first four characters of `hash`.
 *
 * The characters are used as raw bytes and packed big-endian: hash[0] becomes
 * the most significant byte. No hex decoding is performed, so when `hash` is a
 * hex string only the character codes of its first four digits contribute.
 *
 * @param hash  string with at least four characters.
 * @return      the packed 32-bit value.
 * @throws std::runtime_error if `hash` has fewer than four characters.
 */
uint32_t hashToSeed(const std::string& hash) {
    if (hash.size() < 4) {
        throw std::runtime_error("Hash is too short to generate a seed");
    }
    uint32_t seed = 0;
    for (std::size_t i = 0; i < 4; ++i) {
        // Widen to uint32_t before shifting so a byte >= 0x80 never moves into
        // the sign bit of a promoted int.
        const uint32_t byte = static_cast<unsigned char>(hash[i]);
        seed = (seed << 8) | byte;
    }
    return seed;
}

// Dump a 2D table to stdout: one line per row, every element followed by a
// single space.
void printData(const std::vector<std::vector<unsigned>>& vec) {
    for (auto rowIt = vec.cbegin(); rowIt != vec.cend(); ++rowIt) {
        for (auto cell = rowIt->cbegin(); cell != rowIt->cend(); ++cell) {
            std::cout << *cell << " ";
        }
        // Terminate (and flush) the row.
        std::cout << std::endl;
    }
}

// tree size is exponential in depth
template <typename ppT>
void test_synthetic_dt_batch(int depth = 5, int batch_size = 128) {
    typedef libff::Fr<ppT> FieldT;
    std::cout << "Generate R1CS for synthetic DT: " << std::endl;
    std::cout << "Depth: " << depth << std::endl;
    std::cout << "Batch size: " << batch_size << std::endl;
    std::cout << "Tree size: " << (1 << (depth + 1)) << std::endl;

    unsigned threshold = 100;
    int n_vars = depth + 1;
    DT *dt = full_binary_dt(depth, threshold);
    dt->gather_statistics();

    std::vector<std::vector<unsigned>> data;
    std::vector<unsigned> labels;

    for (int i = 0; i < batch_size; ++i) {
        data.push_back(std::vector<unsigned>());
        for (int j = 0; j < n_vars; ++j) {
            unsigned x = rand() % threshold;
            data[i].push_back(x);
        }
        labels.push_back(rand() % depth);
    }

    protoboard <FieldT> pb;
    FieldT coef = rand(), challenge_point = rand();
    DTBatchGadget <FieldT> dtBatchGadget = DTBatchGadget<FieldT>(pb, *dt, data, labels, coef, challenge_point, "dt_batch_gadget");
    dtBatchGadget.generate_r1cs_constraints();
    dtBatchGadget.generate_r1cs_witness();
    std::cout << "N_constraints:  " << pb.num_constraints() << std::endl;
    std::cout << "N_variables: " << pb.num_variables() << std::endl;
    std::cout << "Is circuit satisfied?: " << pb.is_satisfied() << std::endl;
    if (pb.is_satisfied()) {
        run_r1cs_gg_ppzksnark<ppT>(pb, "proof_2");
    } else {
        std::cout << "Circuit is not satisfied, abort" << std::endl;
    }

    
}

/**
 * Parse one comma-separated record into a feature vector and a label.
 *
 * The first `n_values` numbers are features. Each is multiplied by `multi`,
 * truncated to int and shifted by INT32_MAX/4 so that moderately negative
 * inputs still map to non-negative values. The number after the last feature
 * is the label; it is truncated to unsigned, after multiplying by `multi`
 * when `scale_label` is true.
 *
 * Parsing is best-effort: a field that is missing or not numeric reads as 0.
 * Once one extraction fails the stream stays in the failed state, so every
 * later field of that line, the label included, also reads as 0.
 *
 * @param line         text of one record.
 * @param sample       cleared, then filled with the `n_values` encoded features.
 * @param labels       receives one appended label.
 * @param n_values     number of feature columns preceding the label.
 * @param multi        fixed-point scale factor applied to the features.
 * @param scale_label  also apply `multi` to the label.
 */
void parse_line(const std::string& line, std::vector<unsigned>& sample, std::vector<unsigned>& labels, unsigned n_values, unsigned multi, bool scale_label=false) {
    sample.clear();
    sample.reserve(n_values);

    std::stringstream ss(line);
    for (unsigned i = 0; i < n_values; ++i) {
        // Start every field from 0: a failed extraction leaves the target
        // untouched, and it must not see an indeterminate value or the
        // previous field's value.
        float value = 0.0f;
        ss >> value;
        const int scaled = static_cast<int>(value * multi);
        sample.push_back(static_cast<unsigned>(scaled + INT32_MAX / 4)); // shift to positive values
        if (ss.peek() == ',') {
            ss.ignore();
        }
    }

    float label = 0.0f;
    ss >> label;
    if (scale_label) {
        label = label * multi; // scale label
    }
    labels.push_back(static_cast<unsigned>(label));
}

/**
 * Load samples from a comma-separated text file.
 *
 * The first `offset` lines are skipped (e.g. a header row). The next line is
 * always parsed and fixes the number of feature columns, which is the number
 * of commas it contains. Later lines shorter than 5 characters are skipped.
 * Every accepted line is handed to parse_line(), which appends one row to
 * `data` and one entry to `labels`.
 *
 * @param filename        path of the data file.
 * @param data            receives one appended feature row per sample.
 * @param labels          receives one appended label per sample.
 * @param multi           scale factor forwarded to parse_line().
 * @param max_batch_size  maximum number of samples to read; 0 means no limit
 *                        (kept from the previous behaviour).
 * @param scale_label     forwarded to parse_line().
 * @param offset          number of leading lines to skip.
 * @throws std::runtime_error if the file cannot be opened, has fewer than
 *         `offset` lines, or has no line after the skipped ones.
 */
void read_dataset(const std::string& filename, std::vector<std::vector<unsigned>>& data, std::vector<unsigned>& labels, unsigned multi, unsigned max_batch_size, bool scale_label=false, unsigned offset=0) {
    std::ifstream f (filename);
    if (!f.is_open()) {
        throw std::runtime_error("Can not open data file");
    }

    std::string line;
    for (unsigned i = 0; i < offset; ++i) {
        if (!std::getline(f, line)) {
            throw std::runtime_error("Not enough lines in data file to skip offset");
        }
    }

    // The first data line defines the column layout, so it has to exist.
    if (!std::getline(f, line)) {
        throw std::runtime_error("Data file contains no samples");
    }
    // how many values do we have?
    const unsigned n_values = static_cast<unsigned>(std::count(line.begin(), line.end(), ','));

    // Counts the rows added by this call only. Rows are written through
    // data.back() so a non-empty `data` on entry is appended to correctly.
    unsigned n_lines = 0;
    data.emplace_back();
    parse_line(line, data.back(), labels, n_values, multi, scale_label);
    ++n_lines;

    // Test the limit before reading the next row, so a limit of 1 stops after
    // the first sample instead of running to the end of the file.
    while ((max_batch_size == 0 || n_lines < max_batch_size) && std::getline(f, line)) {
        if (line.size() < 5) {
            continue;
        }
        data.emplace_back();
        parse_line(line, data.back(), labels, n_values, multi, scale_label);
        ++n_lines;
    }
}

template <typename ppT>
void test_real_dt_batch(unsigned selector, unsigned fairness_mode){
    typedef libff::Fr<ppT> FieldT;
    // change parameters below
    float thr = 0.8;
    float subset_ratio = 1; // deprecated
    float alpha = 0.05; // significance level for statistical tests
    bool subset = false; // use subset of data for testing

    std::cout << "Using test dataset: ";
    std::string names[11] = {"Iris", "Wine", "Abalone", "Forest", "Breast-cancer-wisconsin", "Spambase", "Test", "Adult", "ACS", "Credit", "Credit (OFL)"};
    std::cout << names[selector] << std::endl;

    DT *dt;
    std::vector <std::vector<unsigned>> data;
    std::vector<unsigned> labels;
    unsigned sensitive = 0;
    unsigned thr_sensitive = INT32_MAX/4; // offset
    // define which class_label value corresponds to "positive" prediction
    unsigned pos_label = 0;
    unsigned scale = 1;

    switch (selector) {
        case 0:
            scale = 100;
            dt = _read_dt_from_file("../Model/Iris_dt.txt");
            read_dataset("../Model/iris.data", data, labels, scale, max_batch_size);
            break;
        case 1:
            scale = 10000;
            dt = _read_dt_from_file("../Model/wine_dt.txt");
            read_dataset("../Model/wine.data", data, labels, scale, max_batch_size);
            break;
        case 2:
            scale = 10000;
            dt = _read_dt_from_file("../Model/Abalone_dt.txt");
            read_dataset("../Model/abalone.data", data, labels, scale, max_batch_size);
            break;
        case 3:
            scale = 100;
            dt = _read_dt_from_file("../Model/Forest_dt.txt");
            read_dataset("../Model/covtype.data", data, labels, scale, max_batch_size);
            break;
        case 4:
            dt = _read_dt_from_file("../Model/breast-cancer-wisconsin_dt.txt");
            read_dataset("../Model/breast-cancer-wisconsin.data", data, labels, 1, max_batch_size);
            break;
        case 5:
            scale = 1000;
            dt = _read_dt_from_file("../Model/spambase_dt.txt");
            read_dataset("../Model/spambase.data", data, labels, scale, max_batch_size);
            break;
        case 6: // new test data
            sensitive = 2; // test
            thr_sensitive += 3; // test
            pos_label = 1;
            dt = _read_dt_from_file("../Model/test_dt.txt");
            read_dataset("../Model/test0.data", data, labels, 1, max_batch_size);
            break;
        case 7: // adult
            scale = 10;
            sensitive = 9; 
            thr_sensitive += 1 * scale;
            pos_label = 1;
            dt = _read_dt_from_file("../Model/adult_dt_fair.txt");
            read_dataset("../datasets/adult_syn_CTGAN_13500_13500_epoch300_class.csv", data, labels, scale, max_batch_size, false, 1);
            break;
        case 8: // acs  
            scale = 10;
            sensitive = 8;
            thr_sensitive += 1 * scale; 
            pos_label = 1;
            dt = _read_dt_from_file("../Model/acs_dt_fair.txt");
            read_dataset("../datasets/acs_syn_CTGAN_15000_15000_epoch300_class.csv", data, labels, scale, max_batch_size, false, 1);
            break;
        case 9: // default credit
            scale = 10;
            sensitive = 1;
            thr_sensitive += 1 * scale; 
            pos_label = 1; 
            dt = _read_dt_from_file("../Model/credit_dt_fair.txt");
            read_dataset("../datasets/credit_syn_CTGAN_9000_9000_epoch300.csv", data, labels, scale, max_batch_size, false, 1);
            break;
        case 10: // default credit (t-test)
            scale = 10;
            sensitive = 1;
            thr_sensitive += 1 * scale; 
            pos_label = 1 * scale;
            dt = _read_dt_from_file("../Model/credit_dt_pr_fair.txt");
            read_dataset("../datasets/default_credit.csv", data, labels, scale, max_batch_size, true, 1);
            break;
    }
    switch (fairness_mode) {
        case 0: 
            std::cout << "skipping fairness check" << std::endl;
            break;
        case 1:
            std::cout << "checking Demographic Parity with thr = " << thr << std::endl;
            break;
        case 2:
            std::cout << "checking Equalized Odds with thr = " << thr << std::endl;
            break;
        case 3:
            std::cout << "checking Mean Residual Difference with thr = " << thr << std::endl;
            break;
        case 4:
            std::cout << "checking Mean Residual Difference with significance level alpha = " << alpha << std::endl;
            break;
    }

    /*
    for (int i = 0; i < data.size(); ++i) {
        labels.push_back(rand() % 2); // some synthetic labels, it's better to read it from the file.
    }
    */
    if(verbose){
        dt->print_dt(dt->root, 0); 
    }
    unsigned subset_size = data.size() * subset_ratio;
    std::cout << "test data size: " << data.size() << std::endl;
    // std::cout << "subset size : " << subset_size << std::endl;
    std::cout << "attribute size: " << data[0].size() << std::endl;

    dt->gather_statistics();

    std::cout << "Tree height: " << dt->root->height - 1 << std::endl;
    // std::cout << "Tree size: " << dt->root->size << std::endl;
    // std::cout << "Non-leaf size: " << dt->root->non_leaf_size << std::endl;
    
    protoboard <FieldT> pb;
    FieldT coef = rand(), challenge_point = rand();

    std::cout << "========= Prover precomputation" << std::endl;
    DTBatchGadget <FieldT> dtBatchGadget = DTBatchGadget<FieldT>(pb, *dt, data, labels, coef, challenge_point, 
                                                                thr, sensitive, thr_sensitive, fairness_mode, 
                                                                pos_label, scale, alpha, verbose, "dtBatchGadget");

    dtBatchGadget.generate_r1cs_constraints();
    dtBatchGadget.generate_r1cs_witness();
    std::cout << "N_constraints: " << pb.num_constraints() << std::endl;
    std::cout << "N_variables: " << pb.num_variables() << std::endl;
    
    std::cout << "========= Done" << std::endl;
    // only check constraints are satisfied without running SNARK
    if (pb.is_satisfied()) {
        std::cout << "circuit is satisfied: " << pb.is_satisfied() << ": generate proof" << std::endl;
        if (!dryrun) {
            run_r1cs_gg_ppzksnark<ppT>(pb, "proof");
        }
    } else {
        std::cout << "circuit is NOT satisfied: " << pb.is_satisfied() << ": abort" << std::endl;
    }
    
    // printf("=============== Hash: %3.f seconds\n", elapsed_hash.count());
}


/**
 * Command-line entry point.
 *
 *   --verbose, -v       set the global `verbose`
 *   --dry, -d           set the global `dryrun` (no proof is generated)
 *   --testsize N        set the global `max_batch_size`
 *   --data NAME         test | adult | acs | credit | credit-ofl
 *   --metric NAME       skip | dp | eqod | mrd | t-test
 *
 * Returns 1 when an option value is missing or invalid, or when the run throws
 * a std::exception. Unrecognised arguments produce a warning on stderr and are
 * otherwise ignored.
 */
int main(int argc, char* argv[]) {
    std::ios::sync_with_stdio(true);
    unsigned selector = 6; // default dataset
    unsigned fairness_mode = 0; // default fairness mode: skip

    for (int i = 1; i < argc; ++i) {
        const std::string arg = argv[i];

        if (arg == "--verbose" || arg == "-v") {
            verbose = true;
        } else if (arg == "--dry" || arg == "-d") {
            dryrun = true;
        } else if (arg == "--testsize" || arg == "--data" || arg == "--metric") {
            // These options consume the following token as their value.
            if (i + 1 >= argc) {
                std::cerr << "Missing value for " << arg << std::endl;
                return 1;
            }
            const std::string value = argv[++i];

            if (arg == "--testsize") {
                try {
                    // A negative number wraps to a very large unsigned limit,
                    // as it did before.
                    max_batch_size = static_cast<unsigned>(std::stoi(value));
                } catch (const std::exception&) {
                    // std::stoi throws on non-numeric or out-of-range input.
                    std::cerr << "Invalid --testsize argument: " << value << std::endl;
                    return 1;
                }
            } else if (arg == "--data") {
                if (value == "test") {
                    selector = 6;
                } else if (value == "adult") {
                    selector = 7;
                } else if (value == "acs") {
                    selector = 8;
                } else if (value == "credit") {
                    selector = 9;
                } else if (value == "credit-ofl") {
                    selector = 10;
                } else {
                    std::cerr << "Unknown --data argument: " << value << std::endl;
                    return 1;
                }
            } else { // --metric
                if (value == "skip") {
                    fairness_mode = 0;
                } else if (value == "dp") {
                    fairness_mode = 1;
                } else if (value == "eqod") {
                    fairness_mode = 2;
                } else if (value == "mrd") {
                    fairness_mode = 3;
                } else if (value == "t-test") {
                    fairness_mode = 4;
                } else {
                    std::cerr << "Unknown --metric argument: " << value << std::endl;
                    return 1;
                }
            }
        } else {
            // Keep running, but make typos such as "--metrics" visible.
            std::cerr << "Ignoring unknown argument: " << arg << std::endl;
        }
    }

    default_r1cs_gg_ppzksnark_pp::init_public_params();
    swifft::init_swifft();
    //test_synthetic_dt_batch<default_r1cs_gg_ppzksnark_pp>();
    auto start = std::chrono::high_resolution_clock::now();
    printf("test_real_dt_batch started\n");
    try {
        test_real_dt_batch<default_r1cs_gg_ppzksnark_pp>(selector, fairness_mode);
    } catch (const std::exception& e) {
        // E.g. a missing model/data file: report it instead of terminating.
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    printf("test_real_dt_batch ended\n");
    // "%1.f" has precision 0, so the time is printed in whole seconds.
    printf("Total elapsed time: %1.f seconds\n", elapsed.count());
    return 0;
}