#include <iostream>
#include <vector>
#include <bitset>
#include <cmath>
#include <limits>
#include <algorithm>
#include <unordered_map>
#include <chrono>
#include <array>
#include <fstream>
#include <sstream>
#include <random>
#include <string>
using namespace std;
using namespace chrono;

// Number of binary features per record, i.e. the width of every BitRecord.
// Adjust this to match the dataset being processed: sparse_to_bitrecord
// ignores any feature index outside [0, DIM).
constexpr int DIM = 170;
// Number of records appended to a cluster before its cached centroid is recomputed.
constexpr int CENTROID_UPDATE_THRESHOLD = 10;
// Upper bound on the number of random draws examined when choosing the next cluster seed.
constexpr int SAMPLE_SIZE = 100;
// One record: bit i is set when feature index i is present in the record.
using BitRecord = bitset<DIM>;
// A cluster is simply the list of records assigned to it.
using Cluster = vector<BitRecord>;

// Builds a BitRecord from a list of feature indices.
// Every index inside [0, DIM) switches on the matching bit; indices outside
// that range are skipped without any diagnostic. Repeated indices are harmless.
BitRecord sparse_to_bitrecord(const vector<int>& indices) {
    BitRecord record;
    for_each(indices.begin(), indices.end(), [&record](int feature) {
        const bool out_of_range = feature < 0 || feature >= DIM;
        if (!out_of_range)
            record.set(feature);
    });
    return record;
}

int hamming_distance(const BitRecord& a, const BitRecord& b) {
    return (a ^ b).count();
}

// Bitwise majority vote over the records of a cluster.
// Bit i of the result is set when at least half of the records have bit i set
// (an exact tie resolves to 1). For an empty cluster the comparison "0 >= 0"
// holds at every position, so the result has all DIM bits set.
BitRecord compute_centroid(const Cluster& cluster) {
    // Tally, per bit position, how many records have that bit switched on.
    array<int, DIM> ones{};
    for (const BitRecord& record : cluster) {
        for (int bit = 0; bit < DIM; ++bit) {
            if (record.test(bit))
                ++ones[bit];
        }
    }

    const int members = static_cast<int>(cluster.size());
    BitRecord majority;
    for (int bit = 0; bit < DIM; ++bit)
        majority.set(bit, 2 * ones[bit] >= members);
    return majority;
}

int pick_max_avg_distance_point(const vector<BitRecord>& data, const vector<BitRecord>& centroids) {
    random_device rd;
    mt19937 gen(rd());
    uniform_int_distribution<> dis(0, data.size() - 1);

    int best_idx = 0;
    double max_avg = -1.0;

    for (int s = 0; s < min((int)data.size(), SAMPLE_SIZE); ++s) {
        int idx = dis(gen);
        const BitRecord& x = data[idx];
        double sum = 0;
        for (const auto& c : centroids)
            sum += hamming_distance(x, c);
        double avg = sum / centroids.size();
        if (avg > max_avg) {
            max_avg = avg;
            best_idx = idx;
        }
    }

    return best_idx;
}

vector<Cluster> k_anonymize(const vector<vector<int>>& sparse_data, int k) {
    vector<BitRecord> data;
    for (const auto& indices : sparse_data)
        data.push_back(sparse_to_bitrecord(indices));

    vector<Cluster> clusters;
    vector<BitRecord> centroids;
    unordered_map<int, BitRecord> centroid_cache;
    unordered_map<int, int> dirty_insert_count;

    centroids.push_back(data.back());
    swap(data.back(), data[0]);
    data.pop_back();

    int total_clusters = data.size() / k;
    int cluster_counter = 1;

    while (data.size() >= (size_t)k) {
        cout << "[INFO] Constructing cluster " << cluster_counter << " of ~" << total_clusters << endl;

        BitRecord& center = centroids.back();
        vector<pair<int, int>> dists;
        for (size_t i = 0; i < data.size(); ++i)
            dists.push_back({ hamming_distance(data[i], center), (int)i });
        sort(dists.begin(), dists.end());

        Cluster cluster = { center };
        vector<int> selected_indices;
        for (int i = 0; i < k - 1; ++i) {
            cluster.push_back(data[dists[i].second]);
            selected_indices.push_back(dists[i].second);
        }

        sort(selected_indices.rbegin(), selected_indices.rend());
        for (int idx : selected_indices) {
            swap(data[idx], data.back());
            data.pop_back();
        }

        clusters.push_back(cluster);
        int cluster_id = clusters.size() - 1;
        centroid_cache[cluster_id] = compute_centroid(cluster);
        dirty_insert_count[cluster_id] = 0;

        if (data.size() >= (size_t)k) {
            int best_idx = pick_max_avg_distance_point(data, centroids);
            centroids.push_back(data[best_idx]);
            swap(data[best_idx], data.back());
            data.pop_back();
        }

        ++cluster_counter;
    }

    cout << "[INFO] Assigning remaining " << data.size() << " points to nearest clusters..." << endl;

    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < (int)data.size(); ++i) {
        int min_inc = INT_MAX, best_cluster = 0;
        for (size_t j = 0; j < clusters.size(); ++j) {
            BitRecord centroid;
            #pragma omp critical(cache_read)
            {
                if (centroid_cache.count(j))
                    centroid = centroid_cache[j];
                else {
                    centroid = compute_centroid(clusters[j]);
                    centroid_cache[j] = centroid;
                }
            }

            int inc = hamming_distance(data[i], centroid);
            if (inc == 0) {
                best_cluster = j;
                break;
            }
            if (inc < min_inc) {
                min_inc = inc;
                best_cluster = j;
            }
        }

        #pragma omp critical(update)
        {
            clusters[best_cluster].push_back(data[i]);
            dirty_insert_count[best_cluster]++;
            if (dirty_insert_count[best_cluster] >= CENTROID_UPDATE_THRESHOLD) {
                centroid_cache[best_cluster] = compute_centroid(clusters[best_cluster]);
                dirty_insert_count[best_cluster] = 0;
            }
        }

        if (i % 1000 == 0 || i == (int)data.size() - 1) {
            #pragma omp critical(io)
            cout << "[INFO] Assigned " << (i + 1) << " / " << data.size() << " remaining points." << endl;
        }
    }

    cout << "[INFO] Clustering complete. Total clusters: " << clusters.size() << endl;
    return clusters;
}

void print_as_sparse_indices(const BitRecord& b) {
    cout << "{ ";
    for (int i = 0; i < DIM; ++i)
        if (b[i]) cout << i << " ";
    cout << "}" << endl;
}


// Reads sparse records from a text file, one record per line.
//
// Each line is parsed as whitespace-separated integers; parsing of a line
// stops at the first token that is not an integer. A blank line yields an
// empty record (it is not skipped).
//
// Parameters:
//   filename - path of the file to read.
// Returns:
//   One vector of indices per line, in file order. If the file cannot be
//   opened, a warning is written to stderr and an empty vector is returned.
std::vector<std::vector<int>> read_sparse_points(const std::string& filename) {
    std::vector<std::vector<int>> points;

    std::ifstream infile(filename);
    if (!infile.is_open()) {
        // Make the failure visible instead of silently reporting "no data".
        std::cerr << "[WARN] Could not open '" << filename << "'; no records read.\n";
        return points;
    }

    std::string line;
    while (std::getline(infile, line)) {
        std::istringstream iss(line);
        std::vector<int> point;
        int index = 0;
        while (iss >> index)
            point.push_back(index);
        points.push_back(std::move(point));  // hand over the buffer instead of copying it
    }

    return points;
}



std::vector<int> bitrecord_to_indices(const BitRecord& b) {
    std::vector<int> idx;            
    for (int i = 0; i < DIM; ++i)
        if (b[i]) idx.push_back(i);
    return idx;                   
}


int deleted_features_for_cluster_sparse_indices(const Cluster& cluster) {
    if (cluster.empty()) return 0;

    const int n = (int)cluster.size();
    std::vector<std::vector<int>> dense; 
    dense.reserve(n);


    for (const auto& rec : cluster) {
        auto v = bitrecord_to_indices(rec);
        dense.push_back(std::move(v));
    }
    if (dense.empty()) return 0;


    int L = (int)dense[0].size();
    for (int i = 1; i < (int)dense.size(); ++i)
        L = std::min(L, (int)dense[i].size());
    if (L == 0) return 0;

    int deleted = 0;

    for (int pos = 0; pos < L; ++pos) {
        const int ref = dense[0][pos];
        bool all_same = true;
        for (int i = 1; i < n; ++i) {
            if (dense[i][pos] != ref) { all_same = false; break; }
        }
        if (!all_same) deleted += n;
    }
    return deleted;
}


int deleted_features_all_clusters_sparse_indices(const std::vector<Cluster>& clusters, bool print_each=false) {
    int total = 0;
    for (int cid = 0; cid < (int)clusters.size(); ++cid) {
        int del = deleted_features_for_cluster_sparse_indices(clusters[cid]);
        if (print_each) {
            std::cout << "Cluster " << cid << " deleted = " << del << "\n";
        }
        total += del;
    }
    return total;
}

// Entry point: loads a dataset, clusters it with k_anonymize, and reports the
// elapsed time and the total number of deleted feature values.
//
// Usage: <program> [input_file] [k]
//   input_file - text file with one record per line (see read_sparse_points);
//                defaults to the built-in (empty) path.
//   k          - anonymity parameter, a positive integer; defaults to 15.
// Returns 0 on success, 1 when the arguments or the input are unusable.
int main(int argc, char* argv[]) {
    // vector<vector<int>> sparse_data = {
    //     {1, 3, 5}, {0, 3, 5}, {2, 3, 6},
    //     {1, 2, 3}, {7, 8, 9}, {2, 4, 5},
    //     {0, 1, 2}, {10, 11, 12}
    // };

    // int k = 2;

    const std::string input_path = (argc > 1) ? argv[1] : "";
    int k = 15;
    if (argc > 2) {
        std::istringstream k_arg(argv[2]);
        if (!(k_arg >> k) || k <= 0) {
            cerr << "[ERROR] k must be a positive integer, got '" << argv[2] << "'" << endl;
            return 1;
        }
    }

    std::vector<std::vector<int>> points = read_sparse_points(input_path);
    // A missing or too-small dataset cannot be made k-anonymous; stop here
    // instead of handing it to the clustering step.
    if (points.size() < static_cast<size_t>(k)) {
        cerr << "[ERROR] Need at least " << k << " records, but read " << points.size()
             << " from '" << input_path << "'" << endl;
        return 1;
    }
    // auto clusters = k_anonymize(points, k); // data

    auto start = high_resolution_clock::now();
    auto clusters = k_anonymize(points, k); // sparse_data
    auto end = high_resolution_clock::now();
    cout << "[INFO] Time taken: " << duration_cast<milliseconds>(end - start).count() << " ms\n";

    // int cid = 1;
    // for (const auto& c : clusters) {
    //     cout << "Cluster " << cid++ << ":\n";
    //     for (const auto& r : c)
    //         print_as_sparse_indices(r);
    //     cout << "----\n";
    // }

    int total_deleted = deleted_features_all_clusters_sparse_indices(clusters);
    cout << "[RESULT] Total features deleted to achieve k-anonymity: " << total_deleted << endl;

    return 0;
}
