#include <random>
#include <iostream>
#include <iomanip>
#include <fstream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <unordered_map>
#include <unordered_set>
#include <map>
#include <cmath>
#include <ctime>
#include <queue>
#include <vector>
#include <omp.h>

#include <limits>
#include <sys/time.h>


#include <set>
#include <algorithm>
#include <ctime>

#include "search_function.h"

using  namespace std;

int main(int argc, char **argv) {

    size_t d, X;
    if (argc == 3) {
        d = atoi(argv[1]);
	X = atoi(argv[2]);
    } else {
        cout << " Need to specify parameters" << endl;
        return 1;
    }

    time_t start, end;
    const size_t n = 1000000;  // number of points in base set
    const size_t n_q = 10000;  // number of points in query set
    const size_t n_tr = 100;
    const size_t kl_size = 15; // KL graph size

    size_t knn_size = 100;
    size_t knn_size_for_beam = 20;

    if (d == 3) {
	if (X == 500000){
            knn_size = 120;
	    knn_size_for_beam = 30;
	} else if (X == 2000){
            knn_size_for_beam = 10;
	} else if (X == 50000){
            knn_size_for_beam = 12;
	} else {
            knn_size = 50;
            knn_size_for_beam = 8;
	}
    } else if (d == 5) {
        knn_size = 80;
        knn_size_for_beam = 10;
    } else if (d == 7){
        knn_size = 180;
        knn_size_for_beam = 16;
    } else if (d == 9) {
        knn_size = 300;
        knn_size_for_beam = 16;
    } else if (d == 17) {
        knn_size = 2000;
    }
 
    Metric *ll2;
    if (X == 0)
        ll2 = new LikeL2Metric();
    else
        ll2 = new HyperbolicMetric();

    cout << "d = " << d << ", kl_size = " << kl_size << ", knn_size = " << knn_size << ", X = " << X << endl;

    std::mt19937 random_gen;
    std::random_device device;
    random_gen.seed(device());

    string exp_name;
    if (X == 0)
	exp_name = "sphere_d" + to_string(d-1);
    else
    	exp_name = "hyperbolic_d" + to_string(d-1) + "_X" +  to_string(X);
    string path_data = "data/" + exp_name + '/';
    string path_models = "models/" + exp_name + '/';
    

    string dir_d = path_data + "base.1M.dvecs";
    std::cout << dir_d << std::endl;
    const char *database_dir = dir_d.c_str();  // path to data
    string dir_q = path_data + "query.10K.dvecs";
    const char *query_dir = dir_q.c_str();  // path to data
    string dir_t = path_data + "gt.10K.ivecs";
    const char *truth_dir = dir_t.c_str();  // path to data


    string dir_knn = path_models + "knn.ivecs";
    const char *edge_knn_dir = dir_knn.c_str();


    string dir_kl = path_models + "kl" + to_string(kl_size) + ".ivecs";
    const char *edge_kl_dir = dir_kl.c_str();


    string output = "results/" + exp_name + ".txt";
    const char *output_txt = output.c_str();

    remove(output_txt);


    bool data_exist = FileExist(dir_d);
    if (data_exist != true) {
        std::cout << "Creating data" << std::endl;
        vector<double> data = create_uniform_data(n_q + n, d, random_gen);
        vector<double> queries;
        for (int i=0; i < n_q*d; ++i) {
            queries.push_back(data[i]);
        }
        vector<double> db;
        for (int i=0; i < n*d; ++i) {
            db.push_back(data[n_q*d + i]);
        }
        vector<uint32_t> truth = get_truth(db, queries, n, d, n_q, ll2);

        std::ofstream data_input_db(database_dir, std::ios::binary);
        writeXvec<double>(data_input_db, db.data(), d, n);

        std::ofstream data_input_q(query_dir, std::ios::binary);
        writeXvec<double>(data_input_q, queries.data(), d, n_q);

        std::ofstream data_input_g(truth_dir, std::ios::binary);
        writeXvec<uint32_t>(data_input_g, truth.data(), n_tr, n_q);
    }


    std::cout << "Loading data from " << database_dir << std::endl;
    std::vector<double> db(n * d);
    {
        std::ifstream data_input(database_dir, std::ios::binary);
	readXvec<double>(data_input, db.data(), d, n);
    }
    
    std::cout << "Loading queries from " << query_dir << std::endl;
    std::vector<double> queries(n_q * d);
    {
        std::ifstream data_input(query_dir, std::ios::binary);
        readXvec<double>(data_input, queries.data(), d, n_q);
    }

    std::cout << "Loading groundtruth from " << truth_dir << std::endl;
    std::vector<uint32_t> truth(n_q * n_tr);
    {
        std::ifstream data_input(truth_dir, std::ios::binary);
        readXvec<uint32_t>(data_input, truth.data(), n_tr, n_q);
    }

//--------------------------------------------------------------------------------------------------------------------------------------------

// BUILD GRAPHS


    bool knn_exist = FileExist(dir_knn);
    if (knn_exist != true) {
	std::cout << "Build kNN graph \n";
        time(&start);
        ExactKNN knn;
        knn.Build(knn_size, db, n, d, ll2);
        cout << knn.matrixNN[0].size() << ' ' << knn.matrixNN[3].size() << endl;
        time(&end);
        cout << difftime(end, start) << endl;
        write_edges(edge_knn_dir, knn.matrixNN);
    }

    vector< vector <uint32_t>> knn(n);
    knn = load_edges(edge_knn_dir, knn);
    cout << "knn " << FindGraphAverageDegree(knn) << endl;
    knn = CutKNNbyK(knn, db, knn_size, n, d, ll2);
    cout << "knn " << FindGraphAverageDegree(knn) << endl;
    vector< vector <uint32_t>> knn_for_beam = CutKNNbyK(knn, db, knn_size_for_beam, n, d, ll2);
    cout << "knn_for_beam " << FindGraphAverageDegree(knn_for_beam) << endl;

    bool kl_exist = FileExist(dir_kl);
    if (kl_exist != true) {
        time(&start);
    	        KLgraph kl_sqrt;
    	        kl_sqrt.BuildByNumberCustom(kl_size, db, n, d, pow(n, 0.5), random_gen, ll2);
		time (&end);
		cout << difftime(end, start) << endl;
        write_edges(edge_kl_dir, kl_sqrt.longmatrixNN);
    }

    vector< vector <uint32_t>> kl(n);
    kl = load_edges(edge_kl_dir, kl);
    cout << "kl " << FindGraphAverageDegree(kl) << endl;


	get_synthetic_tests(n, d, X, n_q, n_tr, random_gen, knn, knn, db, queries, truth, output_txt, ll2, "knn", false, false, false);
	get_synthetic_tests(n, d, X, n_q, n_tr, random_gen, knn, kl, db, queries, truth, output_txt, ll2, "knn_kl", true, true, false);
	if (X == 500000){
	    vector< vector <uint32_t>> knn_for_beam_2 = CutKNNbyK(knn, db, 34, n, d, ll2);
            get_synthetic_tests(n, d, X, n_q, n_tr, random_gen, knn_for_beam_2, knn_for_beam_2, db, queries, truth, output_txt, ll2, "knn_beam", false, false, true);
	} else {	
	    get_synthetic_tests(n, d, X, n_q, n_tr, random_gen, knn_for_beam, knn_for_beam, db, queries, truth, output_txt, ll2, "knn_beam", false, false, true);
	}
	//if (d == 3 && X > 1){
        //    vector< vector <uint32_t>> knn_for_beam_2 = CutKNNbyK(knn, db, 12, n, d, ll2);
        //    get_synthetic_tests(n, d, X, n_q, n_tr, random_gen, knn_for_beam_2, kl, db, queries, truth, output_txt, ll2, "knn_beam_kl", true, true, true);
        //} else {
	    get_synthetic_tests(n, d, X, n_q, n_tr, random_gen, knn_for_beam, kl, db, queries, truth, output_txt, ll2, "knn_beam_kl", true, true, true);
        //}

    return 0;

}
