#include "index_dsg.h"
#include "parameter.h"

#include <omp.h>

#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <vector>

// Reads `query_num` query vectors of `dim` raw floats each (native byte order,
// no header, stored back to back) from the binary file `query_file`.
//
// `query_data` is overwritten: on return it holds exactly `query_num` rows of
// `dim` floats. Exits the process with status 1 if the file cannot be opened
// or holds fewer than `query_num * dim` floats.
void read_query(const std::string &query_file, std::vector<std::vector<float>> &query_data, size_t query_num, size_t dim) {
  std::ifstream in(query_file, std::ios::binary);
  if (!in.is_open()) {
    std::cerr << "Cannot open query file: " << query_file << std::endl;
    std::exit(1);
  }
  // assign (not resize): resize would keep pre-existing rows of a different
  // width, and the fixed-size read below could then overrun them.
  query_data.assign(query_num, std::vector<float>(dim));
  const std::streamsize row_bytes = static_cast<std::streamsize>(dim * sizeof(float));
  for (size_t i = 0; i < query_num; i++) {
    in.read(reinterpret_cast<char *>(query_data[i].data()), row_bytes);
    // A short read means the file is truncated or the dimensions are wrong;
    // continuing would silently search with partially-filled vectors.
    if (in.gcount() != row_bytes) {
      std::cerr << "Query file too short: " << query_file << " (expected " << query_num
                << " vectors of dim " << dim << ", read " << i << " complete vectors)" << std::endl;
      std::exit(1);
    }
  }
}

// Writes one line per query to the text file `result_file`: the ids in
// `kmips[i]` separated by single spaces, terminated by '\n'.
//
// Every query produces exactly one line, even when it has no results (an
// empty line). Line i of the file therefore always belongs to query i.
// Exits the process with status 1 if the file cannot be opened or written.
void save_results(const std::string &result_file, const std::vector<std::vector<size_t>> &kmips) {
  std::ofstream out(result_file);
  if (!out.is_open()) {
    std::cerr << "Cannot open result file: " << result_file << std::endl;
    std::exit(1);
  }
  for (const std::vector<size_t> &row : kmips) {
    for (size_t j = 0; j < row.size(); j++) {
      if (j != 0) {
        out << ' ';
      }
      out << row[j];
    }
    // Always terminate the row so that empty results keep lines aligned.
    out << '\n';
  }
  // Buffered write errors (e.g. disk full) only surface on flush/close.
  out.close();
  if (out.fail()) {
    std::cerr << "Failed writing result file: " << result_file << std::endl;
    std::exit(1);
  }
}

int main(int argc, char **argv) {

  std::string mode = argv[1];
  if (argc < 2) {
    std::cout << "Usage: " << argv[0] << " index data_file init_graph_file index_file num_elements M dim efConstruction init_number threshold" << std::endl;
    std::cout << "Usage: " << argv[0] << " query data_file query_file index_file num_elements dim efs" << std::endl;
    return 1;
  }
  if (mode == "index") {
    std::string data_file(argv[2]);
    std::string init_graph_file(argv[3]);
    std::string index_file(argv[4]);
    size_t num_elements = std::stoi(argv[5]);
    size_t M = std::stoi(argv[6]);
    size_t dim = std::stoi(argv[7]);
    size_t efConstruction = std::stoi(argv[8]);
    size_t init_number = std::stoi(argv[9]);
    size_t threshold = std::stoi(argv[10]);
    size_t angle = std::stoi(argv[11]);
    dsglib::InnerProductSpace space(dim);
    dsglib::L2Space l2_space(dim);
    dsglib::Parameters params;
    params.Set("threshold", threshold);
    params.Set("efConstruction", efConstruction);
    params.Set("A", angle);
    dsglib::DSG *alg_dsg = new dsglib::DSG(&space, &l2_space, num_elements, M, efConstruction, init_number);
    alg_dsg->load_data(data_file.c_str());
    alg_dsg->loadInitialGraph(init_graph_file.c_str());
    auto s_inserts = std::chrono::high_resolution_clock::now();

    omp_set_num_threads(48);
    #pragma omp parallel
    {
      #pragma omp for schedule(dynamic, 64)
      for (size_t i = 0; i < num_elements; i++) {
        alg_dsg->addPoints(i, params);
      }
    }
    auto e_inserts = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff_inserts = e_inserts - s_inserts;
    alg_dsg->saveIndex(index_file.c_str());
    alg_dsg->indexFileSize();
    alg_dsg->indexDegree();
    alg_dsg->checkConnectivity();
    alg_dsg->checkComponent();
    std::cout << "Inserts time: " << diff_inserts.count() << std::endl;
  }
  else{
    std::string data_file(argv[2]);
    std::string query_file(argv[3]);
    std::string index_file(argv[4]);
    size_t num_elements = std::stoi(argv[5]);
    size_t query_num = std::stoi(argv[6]);
    size_t top_k = std::stoi(argv[7]);
    size_t dim = std::stoi(argv[8]);
    size_t efs = std::stoi(argv[9]);
    std::string result_file(argv[10]);
    dsglib::InnerProductSpace space(dim);
    dsglib::DSG *alg_dsg = new dsglib::DSG(&space, index_file.c_str());
    std::cout << "Index loaded" << std::endl;
    alg_dsg->load_data(data_file.c_str());
    std::cout << "Data loaded" << std::endl;
    alg_dsg->setEF(efs);
    std::vector<std::vector<float>> query_data;
    std::cout << "query num: " << query_num << std::endl;
    std::cout << "dim: " << dim << std::endl;
    read_query(query_file, query_data, query_num, dim);
    std::cout << "Query num: " << query_num << std::endl;
    std::cout << query_data[0].size() << std::endl;
    std::cout << query_data.size() << std::endl;
    std::vector<std::vector<size_t>> kmips(query_num);
    auto s_query = std::chrono::high_resolution_clock::now();
      #pragma omp parallel for
      for (size_t i = 0; i < (int)query_num; i++) {
        auto result = alg_dsg->searchMIP(query_data[i].data(), top_k);
        while (!result.empty()) {
          auto top = result.top();
          kmips[i].push_back(top.second);
          result.pop();
        }
        std::reverse(kmips[i].begin(), kmips[i].end());

      }
    
    auto e_query = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff_query = e_query - s_query;
    std::cout << "Average query time: " << std::chrono::duration_cast<std::chrono::milliseconds>(e_query - s_query).count() / (double) query_num << "ms" << std::endl;

    alg_dsg->statistics(query_num);
    save_results(result_file, kmips);
  }

  return 0;
}