#pragma once

#include "parameter.h"
#include "visited_list_pool.h"
#include "distance.h"
#include <algorithm>
#include <atomic>
#include <bitset>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <boost/dynamic_bitset.hpp>

namespace dsglib {
// Integer aliases used throughout the index.
using tableint = unsigned int;         // internal (dense) node id
using linklistsizeint = unsigned int;  // header word stored in front of a link list
using labeltype = size_t;              // external label attached to a node

// Pi, used to convert angles given in degrees to radians.
constexpr double kPi = 3.14159265358979323846264;


class DSG {
  public:
    struct Neighbor {
      tableint id;
      float distance;
      bool flag;
      Neighbor(tableint id, float distance, bool flag) : id(id), distance(distance), flag(flag) {}

      bool operator<(const Neighbor &other) const {
        return distance > other.distance;
      }
    };

    struct SimpleNeighbor {
      tableint id;
      float distance;
      SimpleNeighbor() = default;
      SimpleNeighbor(tableint id, float distance) : id(id), distance(distance){}

      bool operator<(const SimpleNeighbor &other) const {
        return distance < other.distance;
      }
    };

    // Number of mutexes in label_op_locks_.
    static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
    
    // Capacity of the index (number of node slots allocated by the constructor).
    size_t max_elements_{0};
    // Number of nodes loaded so far (incremented by loadInitialGraph).
    mutable std::atomic<size_t> cur_element_count_{0};
    // Bytes per node record in data_memory_: link list + vector + label.
    size_t size_data_per_element_{0};
    size_t size_links_per_element_{0};
    // Maximum number of link ids stored per node (capped at 10000 by the constructor).
    size_t M_{0};
    // Maximum number of neighbours kept by SSGraphPrune / SSGraphConnect.
    size_t half_M_{0};
    size_t bi_M_{0};
    // Maximum number of neighbours returned by pruneEdge (M_ - half_M_ - bi_M_).
    size_t maxM_{0};
    // Candidate pool size used by addPoints.
    size_t ef_construction_{0};
    // Upper bound on the size of the result heap in searchBase.
    size_t efs_{0};
    // Initial length of each init_graph_ row.
    size_t init_number_{0};
    // Per-node mutexes guarding cut_graph_ rows in SSGraphConnect.
    std::vector<std::mutex> locks_;
    // Input neighbourhood graph (filled by loadInitialGraph).
    std::vector<std::vector<unsigned int>> init_graph_;
    // Adjacency lists read by loadNSG.
    std::vector<std::vector<unsigned >> final_graph_;
    // Extra edges merged into the link lists by Connect.
    std::vector<std::vector<unsigned>> bi_directional_edges_;
    // Per-node merged neighbour ids written by MergeNeigbors.
    std::vector<std::vector<unsigned>> prune_graph_;
    
    std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};

    mutable std::vector<std::mutex> label_op_locks_;

    std::mutex global;
    std::vector<std::mutex> link_list_locks_;

    // Flags set by pruneEdge.
    std::vector<bool> is_out_dominantor_;
    std::vector<bool> is_self_dominantor_;
    // Node ids from which searchBase starts.
    std::vector<tableint> enterpoint_node_lists_;
    // Euclidean norm of every vector (filled by load_data).
    std::vector<float> norms_;
    size_t enterpoint_node_list_size_{50};
    

    // Bytes of the link-list part of a node record (count header + M_ ids).
    size_t size_links_{0};
    // Byte offsets inside a node record: vector data, (unused here) base offset, label.
    size_t offset_data_{0}, offset_{0}, label_offset_{0};

    // Node records: [link list | vector | label] repeated max_elements_ times.
    char *data_memory_{nullptr};
    char **linksLists_{nullptr};
    // Raw vectors, stored contiguously (filled by load_data).
    char *pure_data_memory_{nullptr};
    // Unit-normalised copies of the raw vectors (filled by load_data).
    char *norm_pure_data_memory_{nullptr};
    // half_M_ SimpleNeighbor slots per node; a distance of -1 ends a row.
    SimpleNeighbor *cut_graph_{nullptr};

    // Size in bytes of one vector.
    size_t data_size_{0};

    // When true, searchBase counts hops in metric_hops.
    bool collect_metrics{true};

    // Distance function of the primary space `s`.
    DISTFUNC<float> distfunc_;
    // Distance function of the secondary space `l2s`.
    DISTFUNC<float> fstdistfunc_;
    // Opaque parameter passed to both distance functions (load_data reads a size_t dimension from it).
    void *dist_func_param_{nullptr};    

    mutable std::mutex label_lookup_lock;  // lock for label_lookup_
    std::unordered_map<labeltype, tableint> label_lookup_;

    std::default_random_engine random_generator_;

    // Search statistics accumulated by searchBase.
    mutable std::atomic<long> metric_distance_computations{0};
    mutable std::atomic<long> metric_hops{0};    

    // Creates an empty index: the space argument is ignored and every member
    // keeps its in-class default.
    DSG(SpaceInterface<float> *s) {
      (void)s;
    }

    /**
     * Builds an empty index with room for `max_elements` vectors.
     *
     * @param s               space providing the primary distance function, its
     *                        parameter and the per-vector byte size
     * @param l2s             space providing the secondary distance function
     *                        used on the normalised vectors
     * @param max_elements    number of node slots to allocate (must be > 0)
     * @param M               maximum links per node, capped at 10000
     * @param ef_construction candidate pool size during construction
     *                        (raised to at least M)
     * @param init_number     initial row length of the input graph
     * @param random_seed     seed for the entry-point sampling
     * @throws std::runtime_error if `max_elements` is zero or an allocation fails
     */
    DSG(
        SpaceInterface<float> *s,
        SpaceInterface<float> *l2s,
        size_t max_elements,
        size_t M = 32,
        size_t ef_construction = 200,
        size_t init_number = 70,
        size_t random_seed = 100)
        // Initialisers are listed in declaration order and use the constructor
        // ARGUMENT: the member max_elements_ used to be read here while it was
        // still 0, which left locks_ empty.
        : max_elements_(max_elements),
          locks_(max_elements),
          label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
          link_list_locks_(max_elements) {

        if (max_elements == 0) {
            // Zero would mean zero-byte allocations and a modulo-by-zero below.
            throw std::runtime_error("DSG: max_elements must be greater than zero.");
        }

        data_size_ = s->get_data_size();
        distfunc_ = s->get_dist_func();
        dist_func_param_ = s->get_dist_func_param();

        fstdistfunc_ = l2s->get_dist_func();

        if ( M <= 10000 ) {
            M_ = M;
        } else {
            std::cout << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
            std::cout << "         Cap to 10000 will be applied for the rest of the processing." << std::endl;
            M_ = 10000;
        }
        half_M_ = 20;
        bi_M_ = 0;
        // Guard against unsigned wrap-around when M_ is smaller than the reserved share.
        maxM_ = (M_ > half_M_ + bi_M_) ? (M_ - half_M_ - bi_M_) : 0;

        ef_construction_ = std::max(ef_construction, M_);
        efs_ = 100;
        init_number_ = init_number;
        init_graph_.resize(max_elements_);
        bi_directional_edges_.resize(max_elements_);
        prune_graph_.resize(max_elements_);
        for (size_t i = 0; i < max_elements_; i++) {
            init_graph_[i].resize(init_number_);
        }
        random_generator_.seed(random_seed);

        size_links_ = M_ * sizeof(tableint) + sizeof(linklistsizeint);
        size_links_per_element_ = size_links_;
        size_data_per_element_ = size_links_ + data_size_ + sizeof(labeltype);
        offset_data_ = size_links_;
        label_offset_ = size_links_ + data_size_;
        offset_ = 0;

        norms_.resize(max_elements_);
        is_out_dominantor_.resize(max_elements_, false);
        is_self_dominantor_.resize(max_elements_, false);
        cur_element_count_ = 0;

        enterpoint_node_lists_.resize(enterpoint_node_list_size_);
        for (size_t i = 0; i < enterpoint_node_list_size_; i++) {
            enterpoint_node_lists_[i] = random_generator_() % max_elements_;
        }

        visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));

        // The destructor does not run when a constructor throws, so every raw
        // buffer obtained so far is released before an error is reported.
        auto fail = [this](const char *message) {
            delete[] cut_graph_;
            cut_graph_ = nullptr;
            free(pure_data_memory_);
            pure_data_memory_ = nullptr;
            free(norm_pure_data_memory_);
            norm_pure_data_memory_ = nullptr;
            free(data_memory_);
            data_memory_ = nullptr;
            free(linksLists_);
            linksLists_ = nullptr;
            throw std::runtime_error(message);
        };

        // Assign the MEMBER; this used to declare a shadowing local, which
        // leaked the array and left the member null. Slots are left
        // uninitialised, as before.
        cut_graph_ = new SimpleNeighbor[max_elements_ * (size_t)half_M_];

        pure_data_memory_ = (char *)malloc(max_elements_ * data_size_);
        norm_pure_data_memory_ = (char *)malloc(max_elements_ * data_size_);
        if (pure_data_memory_ == nullptr || norm_pure_data_memory_ == nullptr) {
            fail("Not enough memory: DSG failed to allocate vector data memory.");
        }

        data_memory_ = (char *)malloc(max_elements_ * size_data_per_element_);
        if (data_memory_ == nullptr) {
            fail("Not enough memory: DSG failed to allocate data memory.");
        }

        linksLists_ = (char **)malloc(max_elements_ * sizeof(void *));
        if (linksLists_ == nullptr) {
            fail("Not enough memory: DSG failed to allocate linksLists memory.");
        }
        // clear() frees every entry, so each one must start out as a null pointer.
        for (size_t i = 0; i < max_elements_; i++) {
            linksLists_[i] = nullptr;
        }
    }

    // Restores an index from `location` by delegating to loadIndex.
    DSG(InnerProductSpace *s, const std::string &location) {
      this->loadIndex(location, s);
    }

    void clear() {
      free(data_memory_);
      data_memory_ = nullptr;
      for (auto i = 0; i < max_elements_; i++) {
        free(linksLists_[i]);
      }
      free(linksLists_);
      linksLists_ = nullptr;
      visited_list_pool_.reset(nullptr);
      free(pure_data_memory_);
      pure_data_memory_ = nullptr;
    }

    // Frees all owned buffers via clear().
    ~DSG() {
      this->clear();
    }

    struct CompareByFirst {
      constexpr bool operator()(std::pair<float,tableint> const & a,
                                std::pair<float,tableint> const & b) const noexcept {
        return a.first < b.first;
      }
    };




    void setEF(size_t efs) {
      efs_ = efs;
    }

    /**
     * Returns the mutex that serialises operations on `label`.
     *
     * Labels are spread over all MAX_LABEL_OPERATION_LOCKS mutexes; the
     * previous modulus (count - 1) never selected the last one.
     */
    inline std::mutex& getLabelOpMutex(labeltype label) {
      const size_t lock_id = label % MAX_LABEL_OPERATION_LOCKS;
      return label_op_locks_[lock_id];
    }

    // Reads the external label stored in the record of `internal_id`.
    inline labeltype getExternalLabel(tableint internal_id) {
        const char *label_bytes = data_memory_ + internal_id * size_data_per_element_ + label_offset_;
        labeltype label;
        memcpy(&label, label_bytes, sizeof(label));
        return label;
    }

    // Stores `label` in the record of `internal_id`.
    inline void setExternalLabel(tableint internal_id, labeltype label) const {
        char *label_bytes = data_memory_ + internal_id * size_data_per_element_ + label_offset_;
        memcpy(label_bytes, &label, sizeof(label));
    }

    // Returns a pointer to the label field inside the record of `internal_id`.
    inline labeltype *getExternalLabeLp(tableint internal_id) const {
        char *record = data_memory_ + internal_id * size_data_per_element_;
        return (labeltype *) (record + label_offset_);
    }

    // Returns a pointer to the vector stored in the record of `internal_id`.
    inline char *getDataByInternalId(tableint internal_id) const {
        char *record = data_memory_ + internal_id * size_data_per_element_;
        return record + offset_data_;
    }

    // Returns a pointer to the start of the record of `internal_id`, where the
    // link-list header word lives.
    inline linklistsizeint *getLinkListCount(tableint internal_id) const {
        char *record = data_memory_ + internal_id * size_data_per_element_;
        return (linklistsizeint *) record;
    }

    // Reads the neighbour count, stored as an unsigned short at the start of
    // the header word.
    unsigned short int getListCount(linklistsizeint *ptr) const {
      const unsigned short int *count = (unsigned short int *)ptr;
      return *count;
    }
    
    // Writes the neighbour count as an unsigned short at the start of the
    // header word.
    void setListCount(linklistsizeint *ptr, unsigned short int size) const {
      unsigned short int *count = (unsigned short int *)(ptr);
      *count = size;
    }

    // Returns the link list of `internal_id`: one header word followed by the
    // neighbour ids.
    linklistsizeint *get_linklist(tableint internal_id) const {
      char *record = data_memory_ + internal_id * size_data_per_element_;
      return (linklistsizeint *)record;
    }

    size_t getMaxElements() {
        return max_elements_;
    }

    // Number of elements counted so far (atomic read).
    size_t getCurrentElementCount() {
        return cur_element_count_.load();
    }

    /**
     * Loads max_elements_ raw float vectors from a binary file and derives
     * their norms and unit-normalised copies.
     *
     * The file is read as a flat sequence of rows of `dim` floats, where `dim`
     * is the size_t that dist_func_param_ points to. The process exits with
     * code -1 when the file cannot be opened or holds fewer rows than expected
     * (previously a short file silently left uninitialised rows behind).
     *
     * @param filename path of the binary vector file
     */
    void load_data(const char* filename) {
      std::ifstream in(filename, std::ios::binary);
      if (!in.is_open()) {
        std::cerr << "Open file error" << std::endl;
        exit(-1);
      }
      const auto dim = *((size_t *) dist_func_param_);
      const size_t row_bytes = dim * sizeof(float);
      in.seekg(0, std::ios::beg);
      for (size_t i = 0; i < max_elements_; i++) {
        in.read(pure_data_memory_ + i * row_bytes, row_bytes);
        if (!in) {
          std::cerr << "Read file error: expected " << max_elements_
                    << " vectors of dimension " << dim
                    << ", input ended at vector " << i << std::endl;
          exit(-1);
        }
      }
      in.close();

      for (size_t i = 0; i < max_elements_; i++) {
          const float *raw = reinterpret_cast<const float*>(pure_data_memory_ + i * row_bytes);
          float *unit = reinterpret_cast<float*>(norm_pure_data_memory_ + i * row_bytes);

          float norm = 0.0f;
          for (size_t j = 0; j < dim; j++) {
              norm += raw[j] * raw[j];
          }
          norm = std::sqrt(norm);
          norms_[i] = norm;

          // A zero vector has no direction: store zeros instead of the NaNs
          // that 0/0 would produce.
          for (size_t j = 0; j < dim; j++) {
              unit[j] = (norm > 0.0f) ? raw[j] / norm : 0.0f;
          }
      }
    }

    /**
     * Loads the input neighbourhood graph into init_graph_.
     *
     * The file is read as fixed-size records: a 4-byte header word followed by
     * `k` unsigned neighbour ids, where `k` is the first word of the file. The
     * number of records is derived from the file size. cur_element_count_ is
     * incremented once per record.
     *
     * @param file_name path of the binary graph file
     * @throws std::runtime_error if the file cannot be opened or is truncated
     */
    void loadInitialGraph(const char* file_name) {
      std::ifstream in(file_name, std::ios::binary);
      if (!in.is_open()) {
        throw std::runtime_error("Cannot open file " + std::string(file_name));
      }
      unsigned k = 0;
      in.read((char *)&k, sizeof(unsigned));
      if (!in) {
        throw std::runtime_error("Cannot read header of file " + std::string(file_name));
      }
      std::cout << k << std::endl;

      in.seekg(0, std::ios::end);
      const size_t fsize = (size_t)in.tellg();
      // Widen before adding 1 so that k == UINT_MAX cannot wrap to a zero divisor.
      const size_t record_bytes = (static_cast<size_t>(k) + 1) * sizeof(unsigned);
      const size_t num = fsize / record_bytes;
      in.seekg(0, std::ios::beg);
      std::cout << num << std::endl;

      init_graph_.resize(num);
      for (size_t i = 0; i < num; i++) {
        in.seekg(4, std::ios::cur);  // skip the per-record header word
        init_graph_[i].resize(k);
        cur_element_count_++;
        in.read((char *)init_graph_[i].data(), k * sizeof(unsigned));
        if (!in) {
          throw std::runtime_error("Truncated graph file " + std::string(file_name));
        }
      }
      in.close();
    }

    /**
     * Appends the adjacency lists stored in `file_name` to final_graph_.
     *
     * The file starts with two unsigned header words (read and discarded),
     * followed by one record per node: a neighbour count and that many
     * unsigned ids.
     *
     * @param file_name path of the binary graph file
     * @throws std::runtime_error if the file cannot be opened (this used to
     *         loop forever on an uninitialised count) or a record is truncated
     */
    void loadNSG(const char* file_name) {
      std::ifstream in(file_name, std::ios::binary);
      if (!in.is_open()) {
        throw std::runtime_error("Cannot open file " + std::string(file_name));
      }
      unsigned width = 0;
      unsigned ep = 0;
      in.read((char *)&width, sizeof(unsigned));
      in.read((char *)&ep, sizeof(unsigned));
      if (!in) {
        throw std::runtime_error("Cannot read header of file " + std::string(file_name));
      }

      unsigned k = 0;
      // The loop ends as soon as a full count can no longer be read.
      while (in.read((char *)&k, sizeof(unsigned))) {
        std::vector<unsigned> tmp(k);
        in.read((char *)tmp.data(), k * sizeof(unsigned));
        if (!in) {
          throw std::runtime_error("Truncated adjacency list in file " + std::string(file_name));
        }
        final_graph_.push_back(std::move(tmp));
      }
      in.close();
    }
    
    /**
     * Angle-based pruning of a candidate pool.
     *
     * Sorts `pool` by ascending distance, always keeps the closest candidate
     * (skipping `cur_point` when it is the very first entry), and then accepts
     * further candidates while fewer than half_M_ are kept. A candidate is
     * rejected when it repeats an accepted id or when the cosine computed from
     * its distance, an accepted neighbour's distance and their mutual distance
     * exceeds `threshold`.
     *
     * @param cur_point node being processed
     * @param pool      candidates; sorted in place
     * @param threshold cosine above which a candidate is occluded
     * @return accepted neighbours; empty when the pool has no usable candidate
     */
    std::vector<SimpleNeighbor> SSGraphPrune(tableint cur_point, std::vector<SimpleNeighbor> &pool, float threshold) {
      std::vector<SimpleNeighbor> result;
      if (pool.empty()) {
        return result;
      }
      std::sort(pool.begin(), pool.end());

      size_t start = 0;
      if (pool[start].id == cur_point) start++;
      if (start >= pool.size()) {
        // The pool only contained cur_point itself.
        return result;
      }

      result.push_back(pool[start]);

      while (result.size() < half_M_ && (++start) < pool.size()) {
        auto &p = pool[start];
        bool occlude = false;
        for (size_t t = 0; t < result.size(); t++) {
          if (p.id == result[t].id) {
            occlude = true;
            break;
          }
          float djk = fstdistfunc_(norm_pure_data_memory_ + p.id * data_size_, norm_pure_data_memory_ + result[t].id * data_size_, dist_func_param_);
          float cos_ij = (p.distance + result[t].distance - djk) / 2 /
                        sqrt(p.distance * result[t].distance);
          if (cos_ij > threshold) {
            occlude = true;
            break;
          }
        }
        if (!occlude) result.push_back(p);
      }
      return result;
    }

    // Inserts the reverse edge (des -> cur_point) for every neighbour `des`
    // stored in cur_point's cut_graph_ row. A row holds up to half_M_ entries
    // and is terminated early by an entry whose distance is -1. When the
    // destination row would overflow, it is re-pruned with the same
    // cosine-threshold rule used by SSGraphPrune.
    void SSGraphConnect(tableint cur_point, float cos_threshold) {
      SimpleNeighbor *src_pool = cut_graph_ + (size_t)cur_point * (size_t)half_M_;

      for (size_t i = 0; i < half_M_; i++) {
        if (src_pool[i].distance == -1) break;

        // Reverse edge candidate: cur_point as seen from the destination row.
        SimpleNeighbor sn(cur_point, src_pool[i].distance);
        size_t des = src_pool[i].id;
        SimpleNeighbor *des_pool = cut_graph_ + des * (size_t)half_M_;

        std::vector<SimpleNeighbor> temp_pool;
        int dup = 0;
        {
          // Snapshot the destination row under its lock.
          std::lock_guard<std::mutex> guard(locks_[des]);
          for (size_t j = 0; j < half_M_; j++) {
            if (des_pool[j].distance == -1) break;
            if (cur_point == des_pool[j].id) {
              dup = 1;
              break;
            }
            temp_pool.push_back(des_pool[j]);
          }
        }
        // The reverse edge already exists.
        if (dup) continue;

        temp_pool.push_back(sn);

        if (temp_pool.size() > half_M_) {
          // Row is full: re-prune the snapshot plus the new edge. The lock is
          // not held while pruning, so the row may have changed by the time
          // the result is written back.
          std::vector<SimpleNeighbor> result;
          unsigned start = 0;
          std::sort(temp_pool.begin(), temp_pool.end());
          result.push_back(temp_pool[start]);
          while (result.size() < half_M_ && (++start) < temp_pool.size()) {
            auto &p = temp_pool[start];
            bool occlude = false;
            for (unsigned t = 0; t < result.size(); t++) {
              if (p.id == result[t].id) {
                occlude = true;
                break;
              }
              float djk = fstdistfunc_(norm_pure_data_memory_ + p.id * data_size_, norm_pure_data_memory_ + result[t].id * data_size_, dist_func_param_);
              float cos_ij = (p.distance + result[t].distance - djk) / 2 /
                            sqrt(p.distance * result[t].distance);
              if (cos_ij > cos_threshold) {
                occlude = true;
                break;
              }
            }
            if (!occlude) result.push_back(p);
          }
          {
            // Overwrite the destination row with the pruned result and
            // re-terminate it when it is shorter than half_M_.
            std::lock_guard<std::mutex> guard(locks_[des]);
            for (unsigned t = 0; t < result.size(); t++) {
              des_pool[t] = result[t];
            }
            if (result.size() < half_M_) {
              des_pool[result.size()].distance = -1;
            }
          }
        } else {
          // Row has room: write the edge into the terminator slot and move the
          // terminator one slot further.
          std::lock_guard<std::mutex> guard(locks_[des]);
          for (unsigned t = 0; t < half_M_; t++) {
            if (des_pool[t].distance == -1) {
              des_pool[t] = sn;
              if (t + 1 < half_M_) des_pool[t + 1].distance = -1;
              break;
            }
          }
        }
      }
    }

    /**
     * Collects up to `nn_range` distinct candidates for `cur_point` from its
     * one-hop and then two-hop neighbourhood in init_graph_, scored with
     * fstdistfunc_ on the normalised vectors.
     *
     * The distance is computed only after the duplicate check, so repeated ids
     * no longer cost a distance evaluation.
     *
     * @param cur_point node whose neighbourhood is explored
     * @param pool      output; candidates are appended in discovery order
     * @param nn_range  collection stops once `pool` reaches this size
     */
    void getSimpleNeigbors(tableint cur_point, std::vector<SimpleNeighbor> &pool, size_t nn_range) {
      boost::dynamic_bitset<> flags(max_elements_, 0);
      const char *cur_vec = norm_pure_data_memory_ + cur_point * data_size_;
      const std::vector<unsigned int> &first_hop = init_graph_[cur_point];

      for (const tableint nid : first_hop) {
        if (flags[nid]) {
          continue;
        }
        flags[nid] = true;
        auto dist = fstdistfunc_(cur_vec, norm_pure_data_memory_ + nid * data_size_, dist_func_param_);
        pool.emplace_back(nid, dist);
        if (pool.size() >= nn_range) {
          return;
        }
      }

      for (const tableint nid : first_hop) {
        for (const tableint nid2 : init_graph_[nid]) {
          if (flags[nid2]) {
            continue;
          }
          flags[nid2] = true;
          auto dist = fstdistfunc_(cur_vec, norm_pure_data_memory_ + nid2 * data_size_, dist_func_param_);
          pool.emplace_back(nid2, dist);
          if (pool.size() >= nn_range) {
            return;
          }
        }
      }
    }

    /**
     * Collects up to `range` distinct candidates for `cur_point` from its
     * one-hop and then two-hop neighbourhood in init_graph_. Each candidate is
     * scored with the NEGATED distfunc_ value on the raw vectors and carries
     * flag = true.
     *
     * The score is computed only after the duplicate check, so repeated ids no
     * longer cost a distance evaluation.
     *
     * @param cur_point node whose neighbourhood is explored
     * @param pool      output; candidates are appended in discovery order
     * @param range     collection stops once `pool` reaches this size
     */
    void getNeighbors(tableint cur_point, std::vector<Neighbor> &pool, size_t range) {
      boost::dynamic_bitset<> flags(max_elements_, 0);
      const char *cur_vec = pure_data_memory_ + cur_point * data_size_;
      const std::vector<unsigned int> &first_hop = init_graph_[cur_point];

      for (const tableint nid : first_hop) {
        if (flags[nid]) {
          continue;
        }
        flags[nid] = true;
        auto ip = -distfunc_(cur_vec, pure_data_memory_ + nid * data_size_, dist_func_param_);
        pool.emplace_back(nid, ip, true);
        if (pool.size() >= range) {
          return;
        }
      }

      for (const tableint nid : first_hop) {
        for (const tableint nid2 : init_graph_[nid]) {
          if (flags[nid2]) {
            continue;
          }
          flags[nid2] = true;
          auto ip = -distfunc_(cur_vec, pure_data_memory_ + nid2 * data_size_, dist_func_param_);
          pool.emplace_back(nid2, ip, true);
          if (pool.size() >= range) {
            return;
          }
        }
      }
    }

    /**
     * Prunes a pool of candidates scored by negated distfunc_ value.
     *
     * Steps:
     *  1. Sort `pool` (largest score first, see Neighbor::operator<) and mark
     *     `cur_point` as self-dominator when its self score is at least the
     *     best candidate score.
     *  2. Accept the first `threshold` candidates unconditionally and mark
     *     each as out-dominator.
     *  3. Scan the remaining candidates while fewer than maxM_ are live. A
     *     candidate is occluded when its score against an accepted neighbour
     *     exceeds its own self score; an accepted neighbour is retired when
     *     the symmetric condition holds and more than `threshold` are live.
     *  4. Return up to maxM_ accepted, non-retired neighbours.
     *
     * Step 3 now starts at the first unprocessed candidate; it used to
     * pre-increment the cursor and skip one. An empty pool returns an empty
     * result instead of indexing pool[0].
     *
     * @param cur_point node being processed
     * @param pool      candidates; sorted in place
     * @param threshold number of top candidates accepted unconditionally
     * @return pruned neighbour list
     */
    std::vector<Neighbor> pruneEdge(tableint cur_point, std::vector<Neighbor> &pool, const unsigned threshold) {
      std::vector<Neighbor> prune_result;
      if (pool.empty()) {
        return prune_result;
      }

      // Score of a vector against itself.
      auto self_score = [this](tableint id) {
        const char *vec = pure_data_memory_ + id * data_size_;
        return -distfunc_(vec, vec, dist_func_param_);
      };

      size_t start = 0;
      boost::dynamic_bitset<> flags(max_elements_, 0);  // retired neighbours
      std::sort(pool.begin(), pool.end());
      std::vector<Neighbor> result;
      std::unordered_map<tableint,float> self_dist_map;
      tableint real_m = 0;  // accepted and not retired

      if (self_score(cur_point) >= pool[start].distance) {
        is_self_dominantor_[cur_point] = true;
      }

      while (start < pool.size() && real_m < threshold) {
        if (pool[start].id == cur_point) {
          start++;
          continue;
        }
        is_out_dominantor_[pool[start].id] = true; // just test function
        result.push_back(pool[start]);
        real_m++;
        self_dist_map[pool[start].id] = self_score(pool[start].id);
        start++;
      }

      for (; real_m < maxM_ && start < pool.size(); ++start) {
        if (pool[start].id == cur_point) {
          continue;
        }
        bool occlude = false;
        auto &p = pool[start];
        auto ip_self = self_score(p.id);
        self_dist_map[p.id] = ip_self;
        for (size_t i = 0; i < result.size(); i++) {
          auto nid = result[i].id;
          if (flags[nid]) {
            continue;
          }
          auto ip = -distfunc_(pure_data_memory_ + p.id * data_size_, pure_data_memory_ + nid * data_size_, dist_func_param_);

          if (ip_self < ip) {
            occlude = true;
            break;
          }
          if (self_dist_map[nid] < ip && real_m > threshold) {
            flags[nid] = true;
            real_m--;
          }
        }
        if (!occlude) {
          result.push_back(p);
          real_m++;
        }
      }

      for (size_t i = 0; i < result.size(); i++) {
        if (prune_result.size() >= maxM_) {
          break;
        }
        if (flags[result[i].id]) {
          continue;
        }
        prune_result.push_back(result[i]);
      }
      return prune_result;
    }


    void checkComponent() {
      std::vector<bool> visited(max_elements_, false);
      std::vector<int> component_size;
      size_t component_num = 0;
      for (auto i = 0; i < enterpoint_node_list_size_; i++) {
        if (visited[enterpoint_node_lists_[i]]) {
          continue;
        }
        tableint curr_id = enterpoint_node_lists_[i];
        if (curr_id >= max_elements_) {
          continue;
        }
        std::queue<tableint> q;
        q.push(curr_id);
        visited[curr_id] = true;
        int cur_size = 0;
        while (!q.empty()) {
          auto cur = q.front();
          q.pop();
          cur_size++;
          auto ll_cur = get_linklist(cur);
          auto data = (tableint *)(ll_cur + 1);
          for (auto i = 0; i < getListCount(ll_cur); i++) {
            if (!visited[data[i]]) {
              visited[data[i]] = true;
              q.push(data[i]);
            }
          }
        }
        component_size.push_back(cur_size);
      }
      for (auto i = 0; i < max_elements_; i++) {
        if (visited[i]) {
          continue;
        }
        int cur_size = 0;
        component_num++;
        std::queue<tableint> q;
        q.push(i);
        visited[i] = true;
        while (!q.empty()) {
          auto cur = q.front();
          q.pop();
          cur_size++;
          auto ll_cur = get_linklist(cur);
          auto data = (tableint *)(ll_cur + 1);
          for (auto i = 0; i < getListCount(ll_cur); i++) {
            if (!visited[data[i]]) {
              visited[data[i]] = true;
              q.push(data[i]);
            }
          }
        }
        component_size.push_back(cur_size);
      }
      int max_component_size = 0;
      int min_component_size = max_elements_;
      for (auto i = 0; i < component_size.size(); i++) {
        max_component_size = std::max(max_component_size, component_size[i]);
        min_component_size = std::min(min_component_size, component_size[i]);
      }
      std::cout << "Component number: " << component_size.size() << std::endl;
      std::cout << "Avg. Component size: " << max_elements_ / component_size.size() << std::endl;
      std::cout << "Max. Component size: " << max_component_size << std::endl;
      std::cout << "Min. Component size: " << min_component_size << std::endl;
    }

    void spanningTree() {
      boost::dynamic_bitset<> flags(max_elements_, 0);
      auto root_id = enterpoint_node_lists_[0];
      std::vector<tableint> uncheck_set(1);
      std::queue<tableint> spanning_tree;
      spanning_tree.push(root_id);
      while(uncheck_set.size() > 0){
        while(!spanning_tree.empty()){
          unsigned q_front=spanning_tree.front();
          spanning_tree.pop();
          auto ll_cur = get_linklist(q_front);
          auto data = (tableint *)(ll_cur + 1);
          auto neighbor_szie = getListCount(ll_cur);
          for (auto i = 0; i < neighbor_szie; i++) {
            if (!flags[data[i]]) {
              flags[data[i]] = true;
              spanning_tree.push(data[i]);
            }
          }
        }
        uncheck_set.clear();
        for (auto i = 0; i < max_elements_; i++) {
          if (!flags[i]) {
            uncheck_set.push_back(i);
          }
        }
        std::cout << "uncheck_set size:" << uncheck_set.size() << std::endl;
        if (uncheck_set.size() > 0) {
          auto span_id = uncheck_set[0]; 
          auto ll_cur = get_linklist(span_id);
          auto data = (tableint *)(ll_cur + 1);
          auto neighbor_szie = getListCount(ll_cur);
          bool add_flag = false;
          for (auto i = 0; i < neighbor_szie; i++) {
            if (flags[data[i]]) {
              auto span_ll_cur = get_linklist(data[i]);
              auto cur_list_count = getListCount(span_ll_cur);
              if (cur_list_count < M_) {
                auto data = (tableint *)(span_ll_cur + 1);
                data[cur_list_count] = span_id;
                setListCount(span_ll_cur, cur_list_count + 1);
                add_flag = true;
              }
              break;
            }
          }
          if (!add_flag) {
            for (auto i = 0; i < max_elements_; i++) {
              if (flags[i]) {
                auto span_ll_cur = get_linklist(i);
                auto cur_list_count = getListCount(span_ll_cur);
                if (cur_list_count < M_) {
                  auto data = (tableint *)(span_ll_cur + 1);
                  data[cur_list_count] = span_id;
                  setListCount(span_ll_cur, cur_list_count + 1);
                  break;
                }
              }
            }
          }
          spanning_tree.push(span_id);
          flags[span_id] = true;
        }
      }      
    }

    void checkConnectivity() {
      std::vector<bool> visited(max_elements_, false);
      std::queue<tableint> q;
      q.push(0);
      visited[0] = true;
      while (!q.empty()) {
        auto cur = q.front();
        q.pop();
        auto ll_cur = get_linklist(cur);
        auto data = (tableint *)(ll_cur + 1);
        for (auto i = 0; i < getListCount(ll_cur); i++) {
          if (!visited[data[i]]) {
            visited[data[i]] = true;
            q.push(data[i]);
          }
        }
      }
      for (auto i = 0; i < max_elements_; i++) {
        if (!visited[i]) {
          std::cout << "Not connected" << std::endl;
          return;
        }
      }
      std::cout << "Connected" << std::endl;
    }


    void MergeNeigbors(tableint cur_point, std::vector<SimpleNeighbor> nn_neighbors, std::vector<Neighbor> ip_neigbors) {
      boost::dynamic_bitset<> flags(max_elements_, 0);
      std::vector<size_t> merge_result;
      
      for (auto i = 0; i < nn_neighbors.size(); i++) {
        merge_result.push_back(nn_neighbors[i].id);
        flags[nn_neighbors[i].id] = true;
      }

      for (auto i = 0; i < ip_neigbors.size(); i++) {
        if (flags[ip_neigbors[i].id]) {
          continue;
        }
        merge_result.push_back(ip_neigbors[i].id);
      }
      // random_shuffle(merge_result.begin(), merge_result.end());
      prune_graph_[cur_point].resize(merge_result.size());
      for (auto i = 0; i < merge_result.size(); i++) {
        prune_graph_[cur_point][i] = merge_result[i];
      }
      auto ll_cur = get_linklist(cur_point);
      setListCount(ll_cur, merge_result.size());
      tableint *data = (tableint *)(ll_cur + 1);
      for (auto i = 0; i < merge_result.size(); i++) {
        data[i] = merge_result[i];
      }
      memcpy(getExternalLabeLp(cur_point), &cur_point, sizeof(labeltype));
      memcpy(getDataByInternalId(cur_point), pure_data_memory_ + cur_point * data_size_, data_size_);  
    }
    
    /**
     * Rebuilds the node record of `cur_point` from prune_graph_ plus as many
     * bi_directional_edges_ as fit.
     *
     * Pruned neighbours come first; bi-directional edges that are not already
     * among them are appended until M_ links are reached. The list is never
     * longer than M_, the capacity of the record's link-list slot.
     *
     * The label is now written through setExternalLabel; the previous memcpy
     * read sizeof(labeltype) bytes from a tableint and so read past the
     * variable.
     *
     * @param cur_point node being written; also used as its external label
     */
    void Connect(tableint cur_point) {
      const std::vector<unsigned> &pruned = prune_graph_[cur_point];
      std::vector<tableint> merge_result(pruned.begin(), pruned.end());
      if (merge_result.size() > M_) {
        merge_result.resize(M_);
      }

      // Hash set of the pruned ids instead of a max_elements_-sized bitset per call.
      const std::unordered_set<tableint> already_linked(pruned.begin(), pruned.end());
      for (const tableint nid : bi_directional_edges_[cur_point]) {
        if (already_linked.count(nid) != 0) {
          continue;
        }
        if (merge_result.size() >= M_) {
          break;
        }
        merge_result.push_back(nid);
      }

      auto ll_cur = get_linklist(cur_point);
      setListCount(ll_cur, merge_result.size());
      tableint *data = (tableint *)(ll_cur + 1);
      std::copy(merge_result.begin(), merge_result.end(), data);

      setExternalLabel(cur_point, cur_point);
      memcpy(getDataByInternalId(cur_point), pure_data_memory_ + cur_point * data_size_, data_size_);
    }


    // Builds the link list of one node: gathers an angle-pruned candidate set
    // and a pruneEdge candidate set from the input graph and merges them.
    // Reads "threshold" (count passed to pruneEdge) and "A" (angle in degrees)
    // from `parameters`.
    void addPoints(tableint cur_point, const Parameters &parameters) {
      const auto ip_threshold = parameters.Get<unsigned>("threshold");
      const float angle_deg = parameters.Get<float>("A");
      const float cos_limit = std::cos(angle_deg / 180 * kPi);

      std::vector<Neighbor> ip_pool;
      std::vector<SimpleNeighbor> nn_pool;

      getSimpleNeigbors(cur_point, nn_pool, ef_construction_);
      auto nn_selected = SSGraphPrune(cur_point, nn_pool, cos_limit);

      getNeighbors(cur_point, ip_pool, ef_construction_);
      auto ip_selected = pruneEdge(cur_point, ip_pool, ip_threshold);

      MergeNeigbors(cur_point, nn_selected, ip_selected);
    }

    std::priority_queue<std::pair<float,tableint>, std::vector<std::pair<float,tableint>>, CompareByFirst>
    searchBase(const void *data_point) {
      VisitedList *vl = visited_list_pool_->getFreeVisitedList();
      vl_type *visited_array = vl->mass;
      vl_type visited_array_tag = vl->curV;
      boost::dynamic_bitset<> flags(max_elements_, 0);
      std::priority_queue<std::pair<float, tableint>, std::vector<std::pair<float, tableint>>, CompareByFirst> top_candidates;
      std::priority_queue<std::pair<float, tableint>, std::vector<std::pair<float, tableint>>, CompareByFirst> candidateSet;      

      float lowerBound = std::numeric_limits<float>::min();
      for (auto i = 0; i < enterpoint_node_list_size_; i++) {
        tableint curr_id = enterpoint_node_lists_[i];
        if (curr_id >= max_elements_) {
          continue;
        }
        float dist = distfunc_(data_point, getDataByInternalId(curr_id), dist_func_param_);
        lowerBound = std::max(lowerBound, dist);
        top_candidates.emplace(dist, curr_id);
        candidateSet.emplace(-dist, curr_id);
        visited_array[curr_id] = visited_array_tag;
      }
      
      tableint hops = 0;
      while (!candidateSet.empty()) {
        std::pair<float, tableint> curr_node_pair = candidateSet.top();
        auto candidate_dist = -curr_node_pair.first;
        bool flag_stop_search;
        flag_stop_search = candidate_dist > lowerBound;
        
        if (flag_stop_search) {
          break;
        }

        candidateSet.pop();

        if (flags[curr_node_pair.second]) {
          continue;
        }
        flags[curr_node_pair.second] = true;

        auto curr_node_id = curr_node_pair.second;
        int *data = (int *)get_linklist(curr_node_id);
        size_t neighbor_size = getListCount((linklistsizeint *)data);

        if (collect_metrics) {
          metric_hops++;
        }

#ifdef USE_SSE
      _mm_prefetch((char *)(visited_array + *(data + 1)), _MM_HINT_T0);
      _mm_prefetch((char *)(visited_array + *(data + 1) + 64), _MM_HINT_T0);
      _mm_prefetch(data_memory_ +
                       (*(data + 1)) * size_data_per_element_ + offset_data_,
                   _MM_HINT_T0);
      _mm_prefetch((char *)(data + 2), _MM_HINT_T0);
#endif
      for (size_t j = 1; j <= neighbor_size; j++) {
        int candidate_id = *(data + j);
//                    if (candidate_id == 0) continue;
#ifdef USE_SSE
        _mm_prefetch((char *)(visited_array + *(data + j + 1)), _MM_HINT_T0);
        _mm_prefetch(data_memory_ +
                         (*(data + j + 1)) * size_data_per_element_ +
                         offset_data_,
                     _MM_HINT_T0);  ////////////
#endif        
        if (visited_array[candidate_id] == visited_array_tag) {
          continue;
        }
        if (flags[candidate_id]) {
          continue;
        }
        metric_distance_computations++;
        visited_array[candidate_id] = visited_array_tag;
        float dist = distfunc_(data_point, getDataByInternalId(candidate_id), dist_func_param_);
        candidateSet.emplace(-dist, candidate_id);
          
#ifdef USE_SSE
        _mm_prefetch(
            data_memory_ +
                candidateSet.top().second * size_data_per_element_ +
                offset_,  ///////////
            _MM_HINT_T0);       ////////////////////////
#endif
        if (top_candidates.top().first > dist || top_candidates.size() < efs_) {
          top_candidates.emplace(dist, candidate_id);
          if (top_candidates.size() > efs_) {
            top_candidates.pop();
          }
          
          lowerBound = top_candidates.top().first;
        }
      }
    }
      visited_list_pool_->releaseVisitedList(vl);
      return top_candidates;
    }   

    std::priority_queue<std::pair<float, tableint>>
    searchMIP(const void *query_data, size_t k){
      std::priority_queue<std::pair<float, tableint>> result;
      std::priority_queue<std::pair<float,tableint>, std::vector<std::pair<float,tableint>>, CompareByFirst> top_candidates;
      top_candidates = searchBase(query_data);
      while (top_candidates.size() > k) {
        top_candidates.pop();
      }
      while (top_candidates.size() > 0) {
        auto res = top_candidates.top();
        result.push(res);
        top_candidates.pop();
      }
      return result;
    }

    // Placeholder with no implementation yet: calling it currently does nothing.
    void getNeighborByHeuristic() {
      // TODO: not implemented; the body is currently empty.
    }

    // Placeholder with no implementation yet: calling it currently does nothing.
    void MutuallyConnect() {
      // TODO: not implemented; the body is currently empty.
    }

    // Prints an index size estimate, in megabytes, to std::cout.
    //
    // The estimate is the sum of:
    //   * the sizes of a fixed set of bookkeeping members,
    //   * one tableint per entry point,
    //   * for each of the max_elements_ slots: one tableint per stored
    //     neighbor plus one link-count field and one label.
    void indexFileSize() const {
      const size_t header_bytes =
          sizeof(offset_) + sizeof(max_elements_) + sizeof(size_data_per_element_) +
          sizeof(label_offset_) + sizeof(offset_data_) + sizeof(maxM_) + sizeof(M_) +
          sizeof(ef_construction_);
      const size_t entry_point_bytes = enterpoint_node_list_size_ * sizeof(tableint);
      const size_t per_node_overhead = sizeof(linklistsizeint) + sizeof(labeltype);

      size_t total_bytes = header_bytes + entry_point_bytes;

      size_t node = 0;
      while (node < max_elements_) {
        auto *links = (linklistsizeint *)get_linklist(node);
        const size_t degree = getListCount(links);
        total_bytes += degree * sizeof(tableint) + per_node_overhead;
        ++node;
      }

      const double megabytes = static_cast<double>(total_bytes) / (1024.0 * 1024.0);
      std::cout << "Index Size: " << megabytes << "MB" << std::endl;
    }

    // Prints the maximum, minimum and (integer-truncated) average neighbor
    // count over all max_elements_ slots to std::cout.
    //
    // An index with zero slots reports 0 for all three values instead of
    // dividing by zero, and the minimum is seeded from the first slot rather
    // than from an arbitrary sentinel, so it is correct for any degree range.
    void indexDegree() {
      size_t max_degree = 0;
      size_t min_degree = 0;
      size_t sum_degree = 0;
      for (size_t i = 0; i < max_elements_; i++) {
        const size_t degree = getListCount((linklistsizeint *)get_linklist(i));
        max_degree = std::max(max_degree, degree);
        // The first slot defines the initial minimum; later slots can only lower it.
        min_degree = (i == 0) ? degree : std::min(min_degree, degree);
        sum_degree += degree;
      }
      // Integer division is kept so the output format matches previous runs.
      const size_t avg_degree = (max_elements_ == 0) ? 0 : sum_degree / max_elements_;

      std::cout << "Max degree: " << max_degree << std::endl;
      std::cout << "Min degree: " << min_degree << std::endl;
      std::cout << "Average degree: " << avg_degree << std::endl;
    }

    // Prints the per-query averages of the accumulated search counters
    // (metric_distance_computations and metric_hops) to std::cout.
    //
    // query_num: number of queries the counters were accumulated over.
    //            When it is 0 both averages are reported as 0, because
    //            dividing the counters by zero would be undefined.
    void statistics(size_t query_num) {
      if (query_num == 0) {
        std::cout << "Metric distance computations: " << 0 << std::endl;
        std::cout << "Metric hops: " << 0 << std::endl;
        return;
      }
      std::cout << "Metric distance computations: " << metric_distance_computations / query_num << std::endl;
      std::cout << "Metric hops: " << metric_hops / query_num << std::endl;
    }
    
    // Writes the raw object representation of podRef (exactly sizeof(T) bytes)
    // to out. No byte-order conversion is applied and the stream state is not
    // inspected here.
    template <typename T>
    static void writeBinaryPOD(std::ostream &out, const T &podRef) {
      const char *raw_bytes = reinterpret_cast<const char *>(&podRef);
      const std::streamsize byte_count = sizeof(T);
      out.write(raw_bytes, byte_count);
    }

    // Reads exactly sizeof(T) bytes from in and stores them directly over the
    // object representation of podRef. No byte-order conversion is applied and
    // the stream state is not inspected here.
    template <typename T>
    static void readBinaryPOD(std::istream &in, T &podRef) {
      char *raw_bytes = reinterpret_cast<char *>(&podRef);
      const std::streamsize byte_count = sizeof(T);
      in.read(raw_bytes, byte_count);
    }

    // Serializes the index to a binary file at `location`.
    //
    // Layout: eight header fields written with writeBinaryPOD (offset_,
    // max_elements_, size_data_per_element_, label_offset_, offset_data_,
    // enterpoint_node_list_size_, maxM_, M_) followed by the raw
    // data_memory_ block of max_elements_ * size_data_per_element_ bytes.
    // The entry point list itself is not written; only its size is.
    //
    // Throws std::runtime_error if the file cannot be opened or if any write
    // (including the final close) fails, so a failed save is never silent.
    void saveIndex(const std::string &location) {
      std::ofstream output(location, std::ios::binary);
      if (!output.is_open()) {
        throw std::runtime_error("Cannot open file " + location);
      }

      writeBinaryPOD(output, offset_);
      writeBinaryPOD(output, max_elements_);
      writeBinaryPOD(output, size_data_per_element_);
      writeBinaryPOD(output, label_offset_);
      writeBinaryPOD(output, offset_data_);
      writeBinaryPOD(output, enterpoint_node_list_size_);
      writeBinaryPOD(output, maxM_);
      writeBinaryPOD(output, M_);

      // output.write((char*)enterpoint_node_lists_.data(), enterpoint_node_list_size_ * sizeof(tableint));
      output.write(data_memory_,
                   max_elements_ * size_data_per_element_);

      // close() flushes buffered data; a failure in any earlier write or in
      // the flush itself leaves the stream in a failed state.
      output.close();
      if (output.fail()) {
        throw std::runtime_error("Failed to write index to " + location);
      }
    }

    // Loads an index from the binary file at `location`, replacing the
    // current contents (clear() is called once the file has been opened).
    //
    // location: path of a file with the layout produced by saveIndex(): eight
    //           header fields followed by max_elements_ * size_data_per_element_
    //           bytes of element data.
    // s:        space supplying the data size, distance function and its
    //           parameter; must not be null.
    //
    // Entry points are not stored in the file: they are regenerated here from
    // random_generator_ seeded with the fixed value 100, so they are
    // reproducible across loads. pure_data_memory_ and norm_pure_data_memory_
    // are allocated but not populated by this function.
    //
    // Throws std::runtime_error when `s` is null, the file cannot be opened,
    // the header or payload is truncated/inconsistent with the file size, the
    // header declares entry points but zero elements, or memory allocation
    // fails.
    void loadIndex(const std::string &location, InnerProductSpace *s) {
      if (s == nullptr) {
        throw std::runtime_error("loadIndex requires a non-null space");
      }
      std::ifstream input(location, std::ios::binary);
      if (!input.is_open()) {
        throw std::runtime_error("Cannot open file " + location);
      }
      clear();

      // Measure the file up front so the header can be checked against it
      // before any header-sized allocation is attempted.
      input.seekg(0, std::ios::end);
      const std::streamoff file_size = input.tellg();
      input.seekg(0, std::ios::beg);

      readBinaryPOD(input, offset_);
      readBinaryPOD(input, max_elements_);
      readBinaryPOD(input, size_data_per_element_);
      readBinaryPOD(input, label_offset_);
      readBinaryPOD(input, offset_data_);
      readBinaryPOD(input, enterpoint_node_list_size_);
      readBinaryPOD(input, maxM_);
      readBinaryPOD(input, M_);
      if (!input) {
        throw std::runtime_error("Index file header is truncated: " + location);
      }

      std::cout << "offset: " << offset_ << std::endl;
      std::cout << "max_elements: " << max_elements_ << std::endl;
      std::cout << "size_data_per_element: " << size_data_per_element_ << std::endl;
      std::cout << "label_offset: " << label_offset_ << std::endl;
      std::cout << "offset_data: " << offset_data_ << std::endl;
      std::cout << "enterpoint_node_list_size: " << enterpoint_node_list_size_ << std::endl;
      std::cout << "maxM: " << maxM_ << std::endl;
      std::cout << "M: " << M_ << std::endl;

      // Reject headers that promise more element data than the file holds.
      // The comparison is done by division so a corrupt header cannot
      // overflow the product max_elements_ * size_data_per_element_.
      const std::streamoff header_size = input.tellg();
      if (file_size < 0 || header_size < 0 || file_size < header_size) {
        throw std::runtime_error("Cannot determine size of index file " + location);
      }
      const size_t payload_available = static_cast<size_t>(file_size - header_size);
      if (size_data_per_element_ != 0 &&
          max_elements_ > payload_available / size_data_per_element_) {
        throw std::runtime_error("Index file is truncated or corrupt: " + location);
      }
      // Entry points are drawn modulo max_elements_, which must be non-zero.
      if (max_elements_ == 0 && enterpoint_node_list_size_ > 0) {
        throw std::runtime_error("Index file declares entry points but no elements: " + location);
      }

      enterpoint_node_lists_.resize(enterpoint_node_list_size_);
      norms_.resize(max_elements_);
      // Fixed seed: the regenerated entry points are identical on every load.
      random_generator_.seed(100);
      for (size_t i = 0; i < enterpoint_node_list_size_; i++) {
        enterpoint_node_lists_[i] = random_generator_() % max_elements_;
      }

      data_size_ = s->get_data_size();
      distfunc_ = s->get_dist_func();
      dist_func_param_ = s->get_dist_func_param();

      pure_data_memory_ = (char *)malloc(max_elements_ * data_size_);
      norm_pure_data_memory_ = (char *)malloc(max_elements_ * data_size_);
      data_memory_ = (char *)malloc(max_elements_ * size_data_per_element_);
      // All three buffers are required; a null in any of them is fatal.
      if (pure_data_memory_ == nullptr || norm_pure_data_memory_ == nullptr ||
          data_memory_ == nullptr) {
        throw std::runtime_error("Not enough memory: DSG failed to allocate data memory.");
      }
      input.read(data_memory_, max_elements_ * size_data_per_element_);
      if (!input) {
        throw std::runtime_error("Failed to read index data from " + location);
      }

      visited_list_pool_.reset(new VisitedListPool(1, max_elements_));
      std::vector<std::mutex>(max_elements_).swap(link_list_locks_);
      std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);

      // Rebuild the external-label -> internal-id map from the loaded data.
      for (size_t i = 0; i < max_elements_; i++) {
        label_lookup_[getExternalLabel(i)] = i;
      }
      input.close();
    }


};

};

