#include <algorithm>
#include <cassert>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

#include <omp.h>
#include <Eigen/Dense>

#include "cnpy.h"

using namespace std;
using namespace Eigen;

std::vector<std::vector<int>> build_epsilon_graph(std::string npz_path, float eps) {
    if (eps <= 0) {
      std::cerr << "Need ε to build similarity graph!" << std::endl;
      throw;
    }

    omp_set_num_threads(96);
    
    std::cout << "Loading " << npz_path << "...\n";
    cnpy::npz_t data = cnpy::npz_load(npz_path);

    cnpy::NpyArray embeddings = data["embeddings"];
    cnpy::NpyArray labels = data["labels"];
    size_t N = embeddings.shape[0];
    size_t D = embeddings.shape[1];
    std::cout << "Loaded embeddings: " << N << " × " << D << std::endl;

    int* labels_arr = labels.data<int>();
    std::vector<int> labels_vec{labels_arr, labels_arr + labels.shape[0]};
    Map<Matrix<float, Dynamic, Dynamic, RowMajor>> X(embeddings.data<float>(), N, D);

    // Normalize for cosine similarity
    X.rowwise().normalize();

    
    int nthreads = omp_get_max_threads();
    const int block_size = 512;  // tune based on cache and memory
    
    std::cout << "Building ε-graph with ε = " << eps << " using " << nthreads << " threads and block size " << block_size << "..." << std::endl;
    
    std::vector<std::vector<std::vector<int>>> adj_local(nthreads, std::vector<std::vector<int>>(N));

    // === Blocked computation ===
    #pragma omp parallel
    {
        int tid = omp_get_thread_num();
        auto &adj_thread = adj_local[tid];

        // loop over block pairs
        #pragma omp for collapse(2) schedule(dynamic, 1)
        for (int bi = 0; bi < (int)N; bi += block_size) {
            for (int bj = 0; bj < (int)N; bj += block_size) {
                if (bi > bj) continue;
                int i_end = std::min<int>(bi + block_size, N);
                auto Xi = X.middleRows(bi, i_end - bi);
                int j_end = std::min<int>(bj + block_size, N);
                auto Xj = X.middleRows(bj, j_end - bj);

                // compute similarity block: (i_end-bi) × (j_end-bj)
                MatrixXf S = Xi * Xj.transpose();

                // scan and store edges above threshold
                for (int i = 0; i < S.rows(); i++) {
                    int u = bi + i;
                    for (int j = (bi == bj ? i + 1 : 0); j < S.cols(); j++) {
                        if (S(i, j) >= eps) {
                            int v = bj + j;
                            adj_thread[u].push_back(v);
                            adj_thread[v].push_back(u);
                        }
                    }
                }
            }
        }
    }

    std::cout << "Merging adjacency lists..." << std::endl;

    // === Merge adjacency lists across threads ===
    std::vector<std::vector<int>> adj(N);
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < (int)N; i++) {
        auto &dst = adj[i];
        for (int t = 0; t < nthreads; t++) {
            auto &src = adj_local[t][i];
            if (!src.empty()) {
                dst.insert(dst.end(), src.begin(), src.end());
                std::vector<int>().swap(src);  // free thread-local memory
            }
        }
    }
    std::vector<std::vector<std::vector<int>>>().swap(adj_local);  // free thread-local memory

    std::cout << "Sorting and deduplicating adjacency lists..." << std::endl;

    // === Optional: sort & deduplicate adjacency lists ===
    int total_edges = 0;
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < (int)N; i++) {
        auto &nbrs = adj[i];
        sort(nbrs.begin(), nbrs.end());
        int old = nbrs.size();
        nbrs.erase(unique(nbrs.begin(), nbrs.end()), nbrs.end());
        assert(nbrs.size() == old);

        #pragma omp atomic
        total_edges += nbrs.size();
    }
    total_edges /= 2;
    
    std::cout << "Built graph with " << total_edges << " edges" << std::endl;
    for (int i = 0; i < 10; i++)
        std::cout << "adj to " << i << ": " << adj[i].size() << "\n";

    return adj;
}

/// Writes an adjacency list to `filename` as raw binary.
///
/// Layout: the number of rows as a `size_t`, then for every row its length
/// as a `size_t` followed by that many `int` values. Everything is copied
/// straight from memory, so the file uses the host's byte order and type
/// widths.
///
/// @param data     Rows to write; empty rows are stored as a zero length.
/// @param filename Destination path; an existing file is overwritten.
/// @throws std::runtime_error if the file cannot be opened or a write fails.
void serialize(const std::vector<std::vector<int>>& data, const std::string& filename) {
  std::ofstream ofs(filename, std::ios::binary);
  if (!ofs) {
    throw std::runtime_error("serialize: cannot open '" + filename + "' for writing");
  }

  const auto write_raw = [&ofs](const void* src, std::size_t bytes) {
    ofs.write(static_cast<const char*>(src), static_cast<std::streamsize>(bytes));
  };

  const std::size_t outer_size = data.size();
  write_raw(&outer_size, sizeof(outer_size));
  for (const auto& row : data) {
    const std::size_t inner_size = row.size();
    write_raw(&inner_size, sizeof(inner_size));
    if (!row.empty()) {
      write_raw(row.data(), inner_size * sizeof(int));
    }
  }

  // Flush before checking so buffered write errors (e.g. disk full) are
  // reported here instead of being lost in the destructor.
  ofs.flush();
  if (!ofs) {
    throw std::runtime_error("serialize: write to '" + filename + "' failed");
  }
}

/// Reads an adjacency list from a raw binary file.
///
/// Expected layout: the number of rows as a `size_t`, then for every row its
/// length as a `size_t` followed by that many `int` values, all in the host's
/// byte order and type widths.
///
/// @param filename Path of the file to read.
/// @return The rows stored in the file, in file order.
/// @throws std::runtime_error if the file cannot be opened or ends before all
///         announced data has been read.
std::vector<std::vector<int>> deserialize_adjacency_list(const std::string& filename) {
  std::ifstream ifs(filename, std::ios::binary);
  if (!ifs) {
    throw std::runtime_error("deserialize_adjacency_list: cannot open '" + filename + "'");
  }

  // Check every read: continuing after a failed read would use whatever was
  // left in the destination (an indeterminate size or stale row contents).
  const auto read_raw = [&ifs, &filename](void* dst, std::size_t bytes) {
    ifs.read(static_cast<char*>(dst), static_cast<std::streamsize>(bytes));
    if (!ifs) {
      throw std::runtime_error("deserialize_adjacency_list: '" + filename +
                               "' is truncated or unreadable");
    }
  };

  std::size_t outer_size = 0;
  read_raw(&outer_size, sizeof(outer_size));

  std::vector<std::vector<int>> data(outer_size);
  for (auto& row : data) {
    std::size_t inner_size = 0;
    read_raw(&inner_size, sizeof(inner_size));
    row.resize(inner_size);
    if (inner_size > 0) {
      read_raw(row.data(), inner_size * sizeof(int));
    }
  }
  return data;
}

// Load a graph from `filename`, choosing the loader by file extension.
// A path ending in ".npz" is handed to build_epsilon_graph together with
// `threshold`; any other path is read with deserialize_adjacency_list and
// `threshold` is not used.
vector<vector<int>> deserialize(const string& filename, float threshold) {
  const bool is_npz_archive = filename.ends_with(".npz");
  if (!is_npz_archive) {
    return deserialize_adjacency_list(filename);
  }
  return build_epsilon_graph(filename, threshold);
}