#include <omp.h>

#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <exception>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <Eigen/Dense>
#include "cnpy.h"

#include "serialization.cpp"

// using namespace std;
using namespace Eigen;

std::vector<std::vector<int>> build_epsilon_graph_here(std::string npz_path, float eps) {
    if (eps <= 0) {
      std::cerr << "Need ε to build similarity graph!" << std::endl;
      throw;
    }

    std::cout << "Loading " << npz_path << "...\n";
    cnpy::npz_t data = cnpy::npz_load(npz_path);

    cnpy::NpyArray embeddings = data["embeddings"];
    cnpy::NpyArray labels = data["labels"];
    size_t N = embeddings.shape[0];
    size_t D = embeddings.shape[1];
    std::cout << "Loaded embeddings: " << N << " × " << D << std::endl;

    assert(labels.word_size == 4);
    assert(embeddings.word_size == 4);
    std::vector<int> labels_vec = {labels.data<int>(), labels.data<int>() + labels.shape[0]};
    // Map<Matrix<float, Dynamic, Dynamic, RowMajor>> X(embeddings.data<float>(), N, D);
    Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> X(N, D);
    std::memcpy(X.data(), embeddings.data<float>(), N * D * sizeof(float));
    
    // Normalize for cosine similarity
    X.rowwise().normalize();

    int nthreads = omp_get_max_threads();
    constexpr int block_size = 3000;  // tune based on cache and memory
    
    std::cout << "Building ε-graph with ε = " << eps << " using " << nthreads << " threads and block size " << block_size << "..." << std::endl;
    
    std::vector<std::vector<std::vector<int>>> adj_local(nthreads, std::vector<std::vector<int>>(N));
    std::vector<int64_t> contrib_local(nthreads);

    const int num_blocks = (N + block_size - 1) / block_size;
    const int num_pairs = num_blocks * (num_blocks - 1) / 2 + num_blocks;
    std::cout << "Have " << num_blocks << " blocks with " << num_pairs << "pairs!" << std::endl;

    std::atomic_int count = 0;

    const int output_step = num_pairs / 1000;

    std::cout << "!\n!\n";
    auto start = std::chrono::high_resolution_clock::now();

    // === Blocked computation ===
    // #pragma omp parallel
    #pragma omp parallel
    {
        int tid = omp_get_thread_num();
        auto &adj_thread = adj_local[tid];
        
        // Thread-local result buffer
        MatrixXf S_local(block_size, block_size);
        // MatrixXb B_local(block_size, block_size);

        // loop over block pairs
        #pragma omp for collapse(2) schedule(dynamic, 1)
        for (int bi = 0; bi < (int)N; bi += block_size) {
            for (int bj = 0; bj < (int)N; bj += block_size) {
                if (bi > bj) continue;

                int i_end = std::min<int>(bi + block_size, N);
                int j_end = std::min<int>(bj + block_size, N);
                int rows = i_end - bi;
                int cols = j_end - bj;

                // compute similarity block: (i_end-bi) × (j_end-bj)
                auto S = S_local.topLeftCorner(rows, cols);
                S.noalias() = X.middleRows(bi, rows) * X.middleRows(bj, cols).transpose();
                // MatrixXb B = S > eps;

                // scan and store edges above threshold
                for (int j = 0; j < cols; j++) {
                    int v = bj + j;
                    int end_i = (bi == bj) ? j : rows;

                    for (int i = 0; i < end_i; i++) {

                        if (S(i, j) >= eps) {
                            int u = bi + i;
                            adj_thread[u].push_back(v);
                            adj_thread[v].push_back(u);
                        }
                    }

                }
                // for (int i = 0; i < rows; i++) {
                //     int u = bi + i;
                //     int start_j = (bi == bj) ? i + 1 : 0;
                    
                //     for (int j = start_j; j < cols; j++) {
                //         if (S(i, j) >= eps) {
                //             int v = bj + j;
                //             adj_thread[u].push_back(v);
                //             if (u != v) {
                //                 adj_thread[v].push_back(u);
                //             }
                //         }
                //     }
                // }

                if (++count % output_step == 0) {
                    auto end = std::chrono::high_resolution_clock::now();

                    auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
                    double proj = elapsed.count() / double(count) * num_pairs / 60000.0;
                    double rem = elapsed.count() / double(count) * (num_pairs - count) / 60000.0;

                    #pragma omp critical 
                    {
                        std::cout << "\033[A\033[K";
                        std::cout << "\033[A\033[K";
                        std::cout << 100. * count / num_pairs << "% of pairs done\n";
                        std::cout << "Elapsed time: " << elapsed.count() << "ms -> projected total: " << proj << "min, remaining: " << rem << "min\n";
                    }
                }
            }
        }
    }

    std::cout << "Merging adjacency lists..." << std::endl;

    // === Merge adjacency lists across threads ===
    std::vector<std::vector<int>> adj(N);
    #pragma omp parallel for schedule(dynamic, 100)
    for (int i = 0; i < (int)N; i++) {
        auto &dst = adj[i];
        for (int t = 0; t < nthreads; t++) {
            auto &src = adj_local[t][i];
            if (!src.empty()) {
                dst.insert(dst.end(), src.begin(), src.end());
                std::vector<int>().swap(src);  // free thread-local memory
            }
        }
    }
    std::vector<std::vector<std::vector<int>>>().swap(adj_local);  // free thread-local memory

    std::cout << "Sorting and deduplicating adjacency lists..." << std::endl;

    // === Optional: sort & deduplicate adjacency lists ===
    int total_edges = 0;
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < (int)N; i++) {
        auto &nbrs = adj[i];
        sort(nbrs.begin(), nbrs.end());
        int old = nbrs.size();
        nbrs.erase(unique(nbrs.begin(), nbrs.end()), nbrs.end());
        assert(nbrs.size() == old);

        #pragma omp atomic
        total_edges += nbrs.size();
    }
    total_edges /= 2;
    
    std::cout << "Built graph with " << total_edges << " edges" << std::endl;
    for (int i = 0; i < 10; i++)
        std::cout << "adj to " << i << ": " << adj[i].size() << "\n";

    // === Compute E_mis ===
    std::map<int, int> labels_freq;
    for (int i = 0; i < (int)N; i++) labels_freq[labels_vec[i]]++;

    for (auto [l, f] : labels_freq) 
        std::cout << "label " << l << " has freq " << f << std::endl;

    std::atomic_llong false_negatives = 0, false_positives = 0;
    for (auto [_, f] : labels_freq) false_negatives += f * (f - 1LL) / 2;

    #pragma omp parallel for schedule(dynamic)
    for (int u = 0; u < (int)N; u++) {
        int64_t fn = 0, fp = 0;
        for (int v : adj[u])
            if (labels_vec[u] == labels_vec[v])
                fn--;
            else
                fp++;

        false_negatives += fn;
        false_positives += fp;
    }

    int64_t E_mis = false_negatives + false_positives;
    std::cout << "Found E_mis = " << E_mis << " = " << 100. * E_mis / total_edges << "% of edges" << std::endl;
    std::cout << "false negatives = " << false_negatives << " = " << 100. * false_negatives / total_edges << "% of edges" << std::endl;
    std::cout << "false positives = " << false_positives << " = " << 100. * false_positives / total_edges << "% of edges" << std::endl;

    return adj;

    std::vector<std::atomic_llong> tris(N);

    count = 0;
    int step_size = N / 1000;
    #pragma omp parallel for schedule(dynamic, 1)
    for (int u = 0; u < (int)N; u++) {
        std::vector<int> result(adj[u].size());

        for (int v : adj[u]) if (adj[v].size() < adj[u].size() || (adj[v].size() == adj[u].size() && v < u)) {
            auto it = std::set_intersection(adj[u].begin(), adj[u].end(), adj[v].begin(), adj[v].end(), result.begin());
            int k = it - result.begin();

            for (int i = 0; i < k; i++) {
                int w = result[i];

                if (adj[w].size() < adj[v].size() || (adj[w].size() == adj[v].size() && w < v)) {
                    tris[u]++;
                    tris[v]++;
                    tris[w]++;
                }
            }
        }

        if (++count % step_size == 0) {
            std::cout << 100. * count / N << "% done...\r";
            std::cout.flush();
        }
    }

    double clustering_coeff = 0;

    for (int u = 0; u < (int)N; u++) {
        if (adj[u].size() > 1)
            clustering_coeff += 2.0 * tris[u] / (adj[u].size() * (adj[u].size() - 1LL));
    }

    clustering_coeff /= N;

    std::cout << "avg clustering coefficient = " << clustering_coeff << std::endl;

    return adj;
}

int main(int argc, char** argv) {
    if (argc < 3) {
        std::cerr << argc << std::endl;
        std::cerr << "Usage: " << argv[0] << " embeddings.npy eps" << std::endl;
        return 1;
    }

    std::string npz_path = argv[1];
    float eps = std::stof(argv[2]);

    auto adj = build_epsilon_graph_here(npz_path, eps);

    using namespace std;

    stringstream eps_ss;
    eps_ss << setprecision(3) << fixed << eps;
    string out_filename = npz_path.substr(0, npz_path.size() - 4) + "_" + eps_ss.str() + ".bin";
    cout << "serializing to " << out_filename << "..." << endl;
    serialize(adj, out_filename);
    cout << "done!" << endl;
    
    return 0;
}
