#include "index_ra.h"

#include <omp.h>
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>

#include <bitset>
#include <boost/dynamic_bitset.hpp>
#include <chrono>
#include <ctime>
#include <random>
#include <atomic>
#include <utility>

#include "efanna2e/exceptions.h"
#include "efanna2e/parameters.h"
#define PROJECTION_SLACK 1.5

namespace efanna2e {

// Constructs the index on top of the generic Index base.
//
//   dimension   - vector dimensionality, forwarded to the base class.
//   n           - number of points, forwarded to the base class and kept in total_pts_const_.
//   m           - distance metric; COSINE additionally turns on need_normalize.
//   initializer - stored in initializer_ as given.
IndexRetrAtten::IndexRetrAtten(const size_t dimension, const size_t n, Metric m, Index *initializer)
    : Index(dimension, n, m), initializer_{initializer}, total_pts_const_(n) {
    bipartite_ = true;
    l2_distance_ = new DistanceL2();
    width_ = 1;

    const bool cosine_metric = (m == efanna2e::COSINE);
    if (!cosine_metric) {
        return;
    }

    // Cosine similarity is handled as inner product on normalized vectors.
    std::cout << "Inside using IP distance after normalization." << std::endl;
    need_normalize = true;
}

// Frees the two helper objects deleted by this class: the L2 distance object created in the
// constructor and the visited-list pool.
IndexRetrAtten::~IndexRetrAtten() {
    delete l2_distance_;
    delete visited_list_pool_;
}

// Builds the projection graph over the base points and marks the index as built.
//
//   n_sq, sq_data - element count and storage of the second ("sq") point set; the pointer is
//                   stored, the data is not copied.
//   n_bp, bp_data - element count and storage of the base point set; the pointer is stored, the
//                   data is not copied.
//   parameters    - build parameters, forwarded to the helper routines.
//
// NOTE: when need_normalize is set, the caller's base buffer is normalized IN PLACE through a
// const_cast.
void IndexRetrAtten::BuildRAIndex(size_t n_sq, const float *sq_data, size_t n_bp, const float *bp_data,
                                    const Parameters &parameters) {
    const auto build_start = std::chrono::high_resolution_clock::now();

    // The values are not used in this function; the lookups are kept from the original so that
    // whatever Get() does for an absent key still happens before any member is modified.
    (void)parameters.Get<uint32_t>("M_sq");
    (void)parameters.Get<uint32_t>("M_pjbp");

    data_bp_ = bp_data;
    data_sq_ = sq_data;
    nd_ = n_bp;
    nd_sq_ = n_sq;
    total_pts_ = nd_ + nd_sq_;
    u32_nd_ = static_cast<uint32_t>(nd_);
    u32_nd_sq_ = static_cast<uint32_t>(nd_sq_);
    u32_total_pts_ = static_cast<uint32_t>(total_pts_);
    locks_ = std::vector<std::mutex>(total_pts_);
    SetRetrAttenParameters(parameters);

    if (need_normalize) {
        std::cout << "normalizing base data" << std::endl;
        float *mutable_base = const_cast<float *>(data_bp_);
        for (size_t i = 0; i < nd_; ++i) {
            normalize(mutable_base + i * dimension_, dimension_);
        }
    }

    supply_nbrs_.resize(nd_);
    RetrAttenProjectionReserveSpace(parameters);

    CalculateProjectionep();
    assert(projection_ep_ < nd_);

    std::cout << "begin link projection" << std::endl;
    LinkProjection(parameters);

    // Report fractional seconds independent of the clock's tick period. The previous integer
    // division by 1e9 truncated to whole seconds and assumed nanosecond ticks.
    const std::chrono::duration<double> elapsed = std::chrono::high_resolution_clock::now() - build_start;
    std::cout << "Build projection graph time: " << elapsed.count() << std::endl;

    has_built = true;
}

void IndexRetrAtten::SetSQforRAIndex(size_t n_sq, const float *sq_data, size_t n_bp, faiss::IndexScalarQuantizer *sq,
                                const Parameters &parameters) {
    
    auto s = std::chrono::high_resolution_clock::now();
    uint32_t M_sq = parameters.Get<uint32_t>("M_sq");
    
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    scalar_quant_ = sq;
    sqdc_ = scalar_quant_->get_FlatCodesDistanceComputer();
    data_bp_ = nullptr;
    data_sq_ = sq_data;
    has_sq = true;
}
// Builds the projection graph from a PRIVATE COPY of the base vectors, then releases the
// build-time scratch structures (supply_nbrs_, locks_) and trims the heap.
//
//   n_sq, sq_data - element count and storage of the second ("sq") point set; pointer stored.
//   n_bp, bp_data - element count and storage of the base point set; copied into a new buffer.
//   parameters    - build parameters, forwarded to the helper routines.
//
// NOTE(review): the copy allocated here is stored in data_bp_ and is not freed anywhere in the
// visible part of this file — confirm who owns it.
// NOTE(review): an exception from LinkProjection is logged and swallowed, and has_built is still
// set afterwards — confirm that this best-effort behaviour is intended.
void IndexRetrAtten::BuildRAIndexwithData(size_t n_sq, const float *sq_data, size_t n_bp, const float *bp_data,
                                    const Parameters &parameters) {
    // The values are not used in this function; the lookups are kept from the original so that
    // whatever Get() does for an absent key still happens before anything is allocated.
    (void)parameters.Get<uint32_t>("M_sq");
    (void)parameters.Get<uint32_t>("M_pjbp");

    float *bp_data_copy = new float[n_bp * dimension_];
    memcpy(bp_data_copy, bp_data, n_bp * dimension_ * sizeof(float));
    data_bp_ = bp_data_copy;
    data_sq_ = sq_data;
    nd_ = n_bp;
    nd_sq_ = n_sq;
    total_pts_ = nd_ + nd_sq_;
    u32_nd_ = static_cast<uint32_t>(nd_);
    u32_nd_sq_ = static_cast<uint32_t>(nd_sq_);
    u32_total_pts_ = static_cast<uint32_t>(total_pts_);

    locks_ = std::vector<std::mutex>(u32_nd_);
    SetRetrAttenParameters(parameters);

    supply_nbrs_.resize(nd_);
    RetrAttenProjectionReserveSpace(parameters);

    CalculateProjectionep();
    assert(projection_ep_ < nd_);

    try {
        LinkProjection(parameters);
    } catch (const std::exception &e) {
        std::cerr << "Caught exception: " << e.what() << std::endl;
    }

    // Release the scratch neighbour lists. Swapping with an empty temporary frees the outer
    // vector together with every inner one.
    std::vector<std::vector<uint32_t>>().swap(supply_nbrs_);

    // Shrink every adjacency list of the projection graph. The graph is iterated by its own
    // size: it was previously indexed by supply_nbrs_.size(), which is sized from nd_ while the
    // graph is sized from u32_nd_.
    for (auto &nbrs : projection_graph_) {
        nbrs.shrink_to_fit();
    }

    locks_.clear();
    locks_.shrink_to_fit();
    malloc_trim(0);
    has_built = true;
}

// Builds the projection graph through LinkProjectionNoConn, then releases the build-time scratch
// structures (supply_nbrs_, locks_) and trims the heap.
//
//   n_sq, sq_data - element count and storage of the second ("sq") point set; pointer stored.
//   n_bp, bp_data - element count and storage of the base point set; pointer stored, not copied.
//   parameters    - build parameters, forwarded to the helper routines.
//
// NOTE: when need_normalize is set, the caller's base buffer is normalized IN PLACE through a
// const_cast.
void IndexRetrAtten::BuildRAIndexwithDatanoConn(size_t n_sq, const float *sq_data, size_t n_bp, const float *bp_data,
                                    const Parameters &parameters) {
    // The values are not used in this function; the lookups are kept from the original so that
    // whatever Get() does for an absent key still happens before any member is modified.
    (void)parameters.Get<uint32_t>("M_sq");
    (void)parameters.Get<uint32_t>("M_pjbp");

    data_bp_ = bp_data;
    data_sq_ = sq_data;
    nd_ = n_bp;
    nd_sq_ = n_sq;
    total_pts_ = nd_ + nd_sq_;
    u32_nd_ = static_cast<uint32_t>(nd_);
    u32_nd_sq_ = static_cast<uint32_t>(nd_sq_);
    u32_total_pts_ = static_cast<uint32_t>(total_pts_);

    locks_ = std::vector<std::mutex>(u32_nd_);
    SetRetrAttenParameters(parameters);

    if (need_normalize) {
        std::cout << "normalizing base data" << std::endl;
        float *mutable_base = const_cast<float *>(data_bp_);
        for (size_t i = 0; i < nd_; ++i) {
            normalize(mutable_base + i * dimension_, dimension_);
        }
    }

    RetrAttenProjectionReserveSpace(parameters);

    CalculateProjectionep();
    assert(projection_ep_ < nd_);

    LinkProjectionNoConn(parameters);

    // Shrink every adjacency list of the projection graph. The loop is bounded by the graph's own
    // size: this function never resizes supply_nbrs_, so bounding it by supply_nbrs_.size() (as
    // before) could skip the graph entirely or index past its end.
    #pragma omp parallel for
    for (size_t i = 0; i < projection_graph_.size(); ++i) {
        projection_graph_[i].shrink_to_fit();
    }

    // Dropping the outer vector releases every inner scratch list as well.
    supply_nbrs_.clear();
    supply_nbrs_.shrink_to_fit();
    locks_.clear();
    locks_.shrink_to_fit();
    malloc_trim(0);
    has_built = true;
}

// Fills bipartite_graph_ directly from the precomputed lists in learn_base_knn_.
//
// For every sq point, its list is truncated to at most M_pjbp entries. The first entry becomes
// the point's "target": the target's adjacency list receives the sq point (id offset by
// u32_nd_), and the sq point's adjacency list receives every list entry that differs from the
// target. Runs in parallel with the thread count taken from "num_threads".
void IndexRetrAtten::qbaseNNbipartite(const Parameters &parameters) {
    const uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    // Value unused here; lookup retained from the original.
    (void)parameters.Get<uint32_t>("L_pjpq");

    omp_set_num_threads(parameters.Get<uint32_t>("num_threads"));

    bipartite_graph_.resize(u32_total_pts_);

#pragma omp parallel for schedule(static, 100)
    for (uint32_t sq = 0; sq < u32_nd_sq_; ++sq) {
        auto &nn_base = learn_base_knn_[sq];
        if (nn_base.size() > M_pjbp) {
            nn_base.resize(M_pjbp);
            nn_base.shrink_to_fit();
        }

        // An empty list has no target; reading nn_base[0] would be undefined behaviour.
        if (!nn_base.empty()) {
            const uint32_t cur_tgt = nn_base[0];

            // Only this iteration writes to the sq point's own list, so no lock is taken.
            auto &sq_adjacency = bipartite_graph_[sq + u32_nd_];
            for (const uint32_t base_id : nn_base) {
                if (base_id != cur_tgt) {
                    sq_adjacency.push_back(base_id);
                }
            }

            // Several sq points may share a target, so its list is guarded.
            {
                LockGuard guard(locks_[cur_tgt]);
                bipartite_graph_[cur_tgt].push_back(sq + u32_nd_);
            }
        }

        if (sq % 1000 == 0) {
            std::cout << "\r" << (100.0 * sq) / u32_nd_sq_ << "% of save bipartite graph finish"
                      << std::flush;
        }
    }

}

std::pair<uint32_t, uint32_t> IndexRetrAtten::SearchRetrAttenGraph(const float *query, size_t k, size_t &qid, const Parameters &parameters,
                                              unsigned *indices, std::vector<float>& dists) {
    uint32_t L_pq = parameters.Get<uint32_t>("L_pq");
    NeighborPriorityQueue search_queue(L_pq);
    search_queue.reserve(L_pq);
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<uint32_t> dis(0, u32_nd_ - 1);
    std::vector<uint32_t> init_ids;
    for (uint32_t i = 0; i < 10; ++i) {
        uint32_t start = dis(gen);  
        init_ids.push_back(start);
    }
    
    VisitedList *vl = visited_list_pool_->getFreeVisitedList();
    vl_type *visited_array = vl->mass;
    vl_type visited_array_tag = vl->curV;
    for (auto &id : init_ids) {
        float distance;
        
        distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
        Neighbor nn = Neighbor(id, distance, false);
        search_queue.insert(nn);
        visited_array[id] = visited_array_tag;
        
    }

    uint32_t cmps = 0;
    uint32_t hops = 0;
    while (search_queue.has_unexpanded_node()) {
        
        auto cur_check_node = search_queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;
        hops++;
        if (bipartite_graph_[cur_id].size() == 0) {
            continue;
        }
        for (auto nbr : bipartite_graph_[cur_id]) {  
            
            for (auto ns_nbr : bipartite_graph_[nbr]) {  
                if (visited_array[ns_nbr] == visited_array_tag) {
                    continue;
                }

                visited_array[ns_nbr] = visited_array_tag;

                float distance;

                distance = distance_->compare(data_bp_ + ns_nbr * dimension_, query, (unsigned)dimension_);

                ++cmps;
                Neighbor nn = Neighbor(ns_nbr, distance, false);
                search_queue.insert(nn);
            }
        }
    }

    visited_list_pool_->releaseVisitedList(vl);
    if (search_queue.size() < k) {
        std::stringstream ss;
        ss << "not enough results: " << search_queue.size() << ", expected: " << k;
        throw std::runtime_error(ss.str());
    }

    for (size_t i = 0; i < k; ++i) {
        indices[i] = search_queue[i].id;
        dists[i] = search_queue[i].distance;
    }
    return std::make_pair(cmps, hops);
}

// Connects a single node of the bipartite graph.
//
// Steps: search from the node (SearchRetrAttenbyBase or SearchRetrAttenbyQuery), prune the
// collected candidates, store the result as the node's adjacency list, record entry-point
// candidates in sq_en_set_ / bp_en_set_ while those sets are small, then add reverse edges.
//
//   parameters   - must provide "M_sq", "M_bp" and "L_pq".
//   nid          - id local to the node's own point set.
//   simple_graph - forwarded to the search routines.
//   is_base      - true when nid refers to a base point, false for an sq point.
//   visited      - scratch bitset, forwarded to the helpers.
//
// Throws std::runtime_error when the search queue is empty after the search.
void IndexRetrAtten::LinkOneNode(const Parameters &parameters, uint32_t nid, SimpleNeighbor *simple_graph, bool is_base,
                                 boost::dynamic_bitset<> &visited) {
    const uint32_t M_sq = parameters.Get<uint32_t>("M_sq");
    const uint32_t M_bp = parameters.Get<uint32_t>("M_bp");
    const uint32_t L_pq = parameters.Get<uint32_t>("L_pq");

    const float *cur_data = is_base ? data_bp_ : data_sq_;
    const uint32_t global_id = is_base ? nid : u32_nd_ + nid;

    NeighborPriorityQueue search_queue(L_pq);
    search_queue.reserve(L_pq);

    std::vector<Neighbor> full_retset;
    full_retset.reserve(L_pq * 2);

    if (is_base) {
        SearchRetrAttenbyBase(cur_data + nid * dimension_, global_id, parameters, simple_graph, search_queue, visited,
                              full_retset);
    } else {
        SearchRetrAttenbyQuery(cur_data + nid * dimension_, global_id, parameters, simple_graph, search_queue, visited,
                               full_retset);
    }

    std::vector<uint32_t> pruned_list;
    PruneCandidates(full_retset, global_id, parameters, pruned_list, visited);
    if (search_queue.size() == 0) {
        throw std::runtime_error("search queue is empty");
    }

    {
        LockGuard guard(locks_[global_id]);
        bipartite_graph_[global_id].reserve((is_base ? M_sq : M_bp) * 1.5);
        bipartite_graph_[global_id] = pruned_list;
    }

    // The entry sets are shared between threads. The size test, the membership test and the
    // insertion all happen under the set's mutex: previously size()/find() ran unlocked while
    // other threads inserted, which is a data race.
    if (is_base) {
        {
            LockGuard guard(sq_set_mutex_);
            if (sq_en_set_.size() < 200) {
                for (const uint32_t nbr : pruned_list) {
                    if (sq_en_set_.find(nbr) == sq_en_set_.end()) {
                        sq_en_set_.insert(nbr);
                    }
                }
            }
        }
        {
            LockGuard guard(bp_set_mutex_);
            if (bp_en_set_.size() < 200 && bp_en_set_.find(global_id) == bp_en_set_.end()) {
                bp_en_set_.insert(global_id);
            }
        }
    } else {
        {
            LockGuard guard(bp_set_mutex_);
            if (bp_en_set_.size() < 100) {
                for (const uint32_t nbr : pruned_list) {
                    if (bp_en_set_.find(nbr) == bp_en_set_.end()) {
                        bp_en_set_.insert(nbr);
                    }
                }
            }
        }
        {
            LockGuard guard(sq_set_mutex_);
            if (sq_en_set_.size() < 100 && sq_en_set_.find(global_id) == sq_en_set_.end()) {
                sq_en_set_.insert(global_id);
            }
        }
    }

    AddReverse(search_queue, global_id, pruned_list, parameters, visited);
}

void IndexRetrAtten::LinkRetrAtten(const Parameters &parameters, SimpleNeighbor *simple_graph) {
    
    uint32_t M_bp = parameters.Get<uint32_t>("M_bp");
    omp_set_num_threads(static_cast<int>(parameters.Get<uint32_t>("num_threads")));

    std::vector<uint32_t> running_order;
    std::vector<uint32_t> indicate_idx;
    size_t prepare_i = 0;
    for (; prepare_i < total_pts_; prepare_i += 2) {
        indicate_idx.push_back(static_cast<uint32_t>(prepare_i));
    }

    uint32_t i_bp = 0, j_sq = 0;
    while (i_bp + j_sq < u32_total_pts_) {
        if (i_bp < u32_nd_) {
            running_order.push_back(i_bp);
            ++i_bp;
        }
        if (j_sq < u32_nd_sq_) {
            running_order.push_back(u32_nd_ + j_sq);
            ++j_sq;
        }
    }
    int for_n = (int)indicate_idx.size();

#pragma omp parallel for schedule(dynamic, 100)
    for (int iter_loc = 0; iter_loc < for_n; ++iter_loc) {
        boost::dynamic_bitset<> visited(u32_total_pts_);
        visited.reset();
        visited.reserve(u32_total_pts_);
        uint32_t val = running_order[indicate_idx[iter_loc]];
        if (val < u32_nd_) {
            
            LinkOneNode(parameters, val, simple_graph, true, visited);
        } else {
            LinkOneNode(parameters, val - u32_nd_, simple_graph, false, visited);
        }
        if (indicate_idx[iter_loc] + 1 < running_order.size()) {
            val = running_order[indicate_idx[iter_loc] + 1];
        } else {
            std::cout << "haha" << std::endl;
            continue;
        }
        if (val >= u32_nd_) {
            
            LinkOneNode(parameters, val - u32_nd_, simple_graph, false, visited);
        } else {
            LinkOneNode(parameters, val, simple_graph, true, visited);
        }
        if (running_order[indicate_idx[iter_loc]] % 1000 == 0) {
            std::cout << "\r" << (100.0 * (val - u32_nd_)) / (u32_nd_sq_) << "% of index build completed."
                      << "val: " << val << std::flush;
        }
    }
    std::cout << std::endl;
#pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t i = 0; i < u32_total_pts_; ++i) {
        boost::dynamic_bitset<> visited(u32_total_pts_);
        NeighborPriorityQueue useless_queue;
        AddReverse(useless_queue, i, bipartite_graph_[i], parameters, visited);
    }

#pragma omp parallel for schedule(dynamic, 100)
    for (size_t i = 0; i < bipartite_graph_.size(); ++i) {
        if (bipartite_graph_[i].size() < M_bp) {
            
            boost::dynamic_bitset<> visited(u32_total_pts_);
            if (i < nd_) {
                LinkOneNode(parameters, i, simple_graph, true, visited);
            } else {
                LinkOneNode(parameters, i - nd_, simple_graph, false, visited);
            }
        }
    }
}

void IndexRetrAtten::PruneCandidates(std::vector<Neighbor> &search_pool, uint32_t tgt_id, const Parameters &parameters,
                                     std::vector<uint32_t> &pruned_list, boost::dynamic_bitset<> &visited) {
    
    uint32_t M_sq = parameters.Get<uint32_t>("M_sq");
    uint32_t M_bp = parameters.Get<uint32_t>("M_bp");

    std::sort(search_pool.begin(), search_pool.end());
    uint32_t degree_bound = tgt_id < u32_nd_ ? M_bp : M_sq;
    pruned_list.reserve(degree_bound * 1.5);
    auto reachable_flags = visited;
    reachable_flags.reset();
    for (size_t i = 0; i < search_pool.size(); ++i) {
        if (pruned_list.size() >= (size_t)degree_bound) {
            break;
        }
        auto &node = search_pool[i];
        auto id = node.id;
        if (!reachable_flags.test(id)) {
            reachable_flags.set(id);
            pruned_list.push_back(id);
            
            for (auto nbr : bipartite_graph_[id]) {
                for (auto nnbr : bipartite_graph_[nbr]) {
                    reachable_flags.set(nnbr);
                }
            }
        }
    }
    
    if (pruned_list.size() < (size_t)degree_bound) {
        
        for (size_t i = 0; i < search_pool.size(); ++i) {
            if (pruned_list.size() >= (size_t)degree_bound) {
                break;
            }
            if (std::find(pruned_list.begin(), pruned_list.end(), search_pool[i].id) == pruned_list.end()) {
                pruned_list.push_back(search_pool[i].id);
            }
        }
    }
}

// Adds the reverse edge cur_node -> src_node for every cur_node in `pruned_list`.
//
// If cur_node does not yet list src_node and has room (degree below the bound), src_node is
// appended. If the list is full, cur_node's current neighbours are re-scored against cur_node
// and pruned with PruneCandidates, and the result replaces the list.
//
//   search_pool - not used by this function.
//   src_node    - global id of the node whose forward edges are given.
//   pruned_list - forward neighbours of src_node.
//   parameters  - must provide "M_sq" and "M_bp".
//   visited     - forwarded to PruneCandidates.
//
// NOTE(review): src_node itself is not among the candidates handed to PruneCandidates, so a full
// list never gains the reverse edge — confirm whether that is intended.
void IndexRetrAtten::AddReverse(NeighborPriorityQueue &search_pool, uint32_t src_node,
                                std::vector<uint32_t> &pruned_list, const Parameters &parameters,
                                boost::dynamic_bitset<> &visited) {
    const uint32_t M_sq = parameters.Get<uint32_t>("M_sq");
    const uint32_t M_bp = parameters.Get<uint32_t>("M_bp");
    const bool is_base = src_node < u32_nd_;
    const uint32_t degree_bound = is_base ? M_sq : M_bp;

    // Neighbours of cur_node live in src_node's point set; cur_node lives in the other one.
    const float *cur_data = is_base ? data_bp_ : data_sq_;
    const float *opposite_data = is_base ? data_sq_ : data_bp_;

    for (size_t i = 0; i < pruned_list.size(); ++i) {
        const uint32_t cur_node = pruned_list[i];

        // Decided afresh for every neighbour. The flag used to live outside the loop and stayed
        // true once set, forcing needless re-pruning of later neighbours.
        bool need_prune = false;
        std::vector<uint32_t> snapshot;
        {
            // Membership test, capacity test and insertion form one critical section, so two
            // threads cannot both append or push the list past the bound.
            LockGuard guard(locks_[cur_node]);
            std::vector<uint32_t> &adjacency = bipartite_graph_[cur_node];
            if (std::find(adjacency.begin(), adjacency.end(), src_node) == adjacency.end()) {
                if (adjacency.size() < degree_bound) {
                    adjacency.push_back(src_node);
                } else {
                    need_prune = true;
                    snapshot = adjacency;
                }
            }
        }
        if (!need_prune) {
            continue;
        }

        const uint32_t cate_cur_node = is_base ? cur_node - u32_nd_ : cur_node;
        std::vector<Neighbor> simulate_pool;
        simulate_pool.reserve(snapshot.size());
        for (const uint32_t id : snapshot) {
            const uint32_t cate_id = is_base ? id : id - u32_nd_;
            const float distance = distance_->compare(cur_data + cate_id * dimension_,
                                                      opposite_data + cate_cur_node * dimension_, dimension_);
            simulate_pool.push_back(Neighbor(id, distance, false));
        }

        std::vector<uint32_t> inside_pruned_list;
        PruneCandidates(simulate_pool, cur_node, parameters, inside_pruned_list, visited);
        {
            LockGuard guard(locks_[cur_node]);
            bipartite_graph_[cur_node] = inside_pruned_list;
        }
    }
}
void IndexRetrAtten::SearchRetrAttenbyBase(const float *query, uint32_t gid, const Parameters &parameters,
                                           SimpleNeighbor *simple_graph, NeighborPriorityQueue &queue,
                                           boost::dynamic_bitset<> &visited, std::vector<Neighbor> &full_retset) {
    uint32_t M_sq = parameters.Get<uint32_t>("M_sq");
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<uint32_t> dis(0, u32_nd_sq_ - 1);
    
    std::vector<uint32_t> init_ids;
    uint32_t start = dis(gen);
    
    init_ids.push_back(u32_nd_ + start);  
    init_ids.push_back(u32_nd_ + dis(gen));
    if (!sq_en_set_.empty()) {  
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_int_distribution<uint32_t> dis(0, sq_en_set_.size() - 1);
        {
            start = *(std::next(sq_en_set_.begin(), dis(gen)));
        }

        if (start < u32_nd_) {
            std::cout << "Error: start less than sampled query no, exited. start: " << start << " advance: "
                      << "sq_en_set_ size(): " << sq_en_set_.size() << std::endl;
            for (auto &i : sq_en_set_) {
                std::cout << i << " ";
            }
            exit(1);
        }
        if (init_ids[0] != start) {
            init_ids.push_back(start);
        }
    }

    while (init_ids.size() < queue.capacity()) {
        init_ids.push_back(u32_nd_ + dis(gen));
    }
    for (auto id : init_ids) {
        float distance;
        uint32_t cur_local_id = id - u32_nd_;
        
        distance = distance_->compare(data_sq_ + cur_local_id * dimension_, query, (unsigned)dimension_);
        Neighbor nn = Neighbor(id, distance, false);
        queue.insert(nn);
        visited.set(id);
        full_retset.push_back(nn);
        
    }

    while (queue.has_unexpanded_node()) {
        auto cur_check_node = queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;

        std::vector<uint32_t> nbr_ids;
        bool fast_check = false;
        uint32_t first_hop_rank_1 = bipartite_graph_[cur_id].size() == 0 ? rand() % nd_ : bipartite_graph_[cur_id][0];
        if (bipartite_graph_[cur_id].size() < M_sq) {
            fast_check = false;
        }
        float first_hop_min_dist = 1000;
        for (size_t j = 0; j < bipartite_graph_[cur_id].size(); ++j) {
            
            uint32_t nbr = bipartite_graph_[cur_id][j];
            if (nbr == gid) {
                continue;
            }
            for (size_t i = 0; i < bipartite_graph_[nbr].size(); ++i) {
                if (visited.test(bipartite_graph_[nbr][i])) {
                    continue;
                }
                uint32_t cate_id = bipartite_graph_[nbr][i] - u32_nd_;
                float distance = distance_->compare(data_sq_ + cate_id * dimension_, query, (unsigned)dimension_);

                if (fast_check) {
                    if (first_hop_min_dist > distance) {
                        first_hop_min_dist = distance;
                        first_hop_rank_1 = nbr;
                    }
                }

                visited.set(bipartite_graph_[nbr][i]);
                Neighbor nn = Neighbor(bipartite_graph_[nbr][i], distance, false);
                queue.insert(nn);
                full_retset.push_back(nn);
                if (fast_check) {
                    break;
                }
            }
        }

        if (fast_check) {
            for (size_t i = 0; i < bipartite_graph_[first_hop_rank_1].size(); ++i) {
                auto nbr = bipartite_graph_[first_hop_rank_1][i];
                if (visited.test(nbr)) {
                    continue;
                }
                uint32_t cate_id = nbr - u32_nd_;
                float distance = distance_->compare(data_sq_ + cate_id * dimension_, query, (unsigned)dimension_);

                visited.set(nbr);
                Neighbor nn = Neighbor(nbr, distance, false);
                queue.insert(nn);
                full_retset.push_back(nn);
            }
        }
    }
}
void IndexRetrAtten::SearchRetrAttenbyQuery(const float *query, uint32_t gid, const Parameters &parameters,
                                            SimpleNeighbor *simple_graph, NeighborPriorityQueue &queue,
                                            boost::dynamic_bitset<> &visited, std::vector<Neighbor> &full_retset) {
    
    uint32_t M_bp = parameters.Get<uint32_t>("M_bp");
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<uint32_t> dis(0, u32_nd_ - 1);
    uint32_t start = dis(gen);
    
    std::vector<uint32_t> init_ids;
    init_ids.push_back(start);
    init_ids.push_back(dis(gen));
    if (!bp_en_set_.empty()) {
        std::random_device rd;
        std::mt19937 gen(rd());
        {
            std::uniform_int_distribution<uint32_t> dis(0, bp_en_set_.size() - 1);
            
            LockGuard guard(bp_set_mutex_);
            start = *(std::next(bp_en_set_.begin(), dis(gen)));
        }
        if (start > u32_nd_) {
            std::cout << "Error: start greater than base no, exited" << std::endl;
            exit(1);
        }
        if (init_ids[0] != start) {
            init_ids.push_back(start);
        }
    }

    while (init_ids.size() < queue.capacity()) {
        init_ids.push_back(dis(gen));
    }

    for (auto &id : init_ids) {
        float distance;

        distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);

        Neighbor nn = Neighbor(id, distance, false);
        visited.set(id);
        queue.insert(nn);
        full_retset.push_back(nn);
    }

    while (queue.has_unexpanded_node()) {
        auto cur_check_node = queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;

        bool fast_check = false;
        uint32_t first_hop_rank_1 =
            bipartite_graph_[cur_id].size() == 0 ? (rand() % nd_sq_) + nd_ : bipartite_graph_[cur_id][0];
        if (bipartite_graph_[cur_id].size() < (M_bp)) {
            fast_check = false;
        }
        float first_hop_min_dist = 1000;

        for (size_t j = 0; j < bipartite_graph_[cur_id].size(); ++j) {
            uint32_t nbr = bipartite_graph_[cur_id][j];
            if (nbr == gid) {
                continue;
            }
            for (size_t i = 0; i < bipartite_graph_[nbr].size(); ++i) {
                uint32_t ns_nbr = bipartite_graph_[nbr][i];
                if (visited.test(ns_nbr)) {
                    continue;
                }

                float distance = distance_->compare(data_bp_ + ns_nbr * dimension_, query, (unsigned)dimension_);

                if (fast_check) {
                    if (first_hop_min_dist > distance) {
                        first_hop_min_dist = distance;
                        first_hop_rank_1 = nbr;
                    }
                }
                Neighbor nn = Neighbor(ns_nbr, distance, false);
                queue.insert(nn);
                visited.set(nbr);
                full_retset.push_back(nn);

                if (fast_check) {
                    break;
                }
            }
        }

        if (fast_check) {
            for (size_t i = 0; i < bipartite_graph_[first_hop_rank_1].size(); ++i) {
                auto nbr = bipartite_graph_[first_hop_rank_1][i];
                if (visited.test(nbr)) {
                    continue;
                }

                float distance = distance_->compare(data_bp_ + nbr * dimension_, query, (unsigned)dimension_);

                visited.set(nbr);
                Neighbor nn = Neighbor(nbr, distance, false);
                queue.insert(nn);
                full_retset.push_back(nn);
            }
        }
    }
}

void IndexRetrAtten::PruneLocalJoinCandidates(uint32_t node, const Parameters &parameters, uint32_t candi) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");

    NeighborPriorityQueue search_pool;
    search_pool.reserve(M_pjbp + 1);

    for (auto nbr : projection_graph_[node]) {
        if (nbr == node) {
            continue;
        }
        float distance = distance_->compare(data_bp_ + dimension_ * nbr, data_bp_ + dimension_ * node, dimension_);
        search_pool.insert({nbr, distance, false});
    }

    float distance = distance_->compare(data_bp_ + dimension_ * candi, data_bp_ + dimension_ * node, dimension_);
    search_pool.insert({candi, distance, false});

    std::vector<uint32_t> result;
    result.reserve(M_pjbp * PROJECTION_SLACK);
    uint32_t start = 0;
    if (search_pool[start].id == node) {
        start++;
    }
    result.push_back(search_pool[start].id);
    ++start;

    while (result.size() < M_pjbp && (++start) < search_pool.size()) {
        Neighbor &p = search_pool[start];
        bool occlude = false;
        for (size_t t = 0; t < result.size(); ++t) {
            if (p.id == result[t]) {
                occlude = true;
                break;
            }
            float djk = distance_->compare(data_bp_ + dimension_ * p.id, data_bp_ + dimension_ * result[t], dimension_);
            if (djk < p.distance) {
                occlude = true;
                break;
            }
        }
        if (!occlude) {
            if (p.id != node) {
                result.push_back(p.id);
            }
        }
    }

    for (size_t i = 0; i < search_pool.size() && result.size() < M_pjbp; ++i) {
        if (std::find(result.begin(), result.end(), search_pool[i].id) == result.end()) {
            result.push_back(search_pool[i].id);
        }
    }

    {
        LockGuard guard(locks_[node]);
        projection_graph_[node] = result;
    }
}

// Sizes projection_graph_ to one adjacency list per base point (u32_nd_ lists) and reserves
// "M_pjbp" * PROJECTION_SLACK slots in each of them.
void IndexRetrAtten::RetrAttenProjectionReserveSpace(const Parameters &parameters) {
    const uint32_t max_projection_degree = parameters.Get<uint32_t>("M_pjbp");

    projection_graph_.resize(u32_nd_);
    for (auto &adjacency : projection_graph_) {
        adjacency.reserve(max_projection_degree * PROJECTION_SLACK);
    }
}

void IndexRetrAtten::TrainingLink2Projection(const Parameters &parameters, SimpleNeighbor *simple_graph) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    uint32_t L_pjpq = parameters.Get<uint32_t>("L_pjpq");
    omp_set_num_threads(parameters.Get<uint32_t>("num_threads"));
    std::vector<uint32_t> vis_order;
    std::vector<uint32_t> vis_order_sq;
    for (uint32_t i = 0; i < u32_nd_; ++i) {
        vis_order.push_back(i);
    }
    for (uint32_t i = 0; i < u32_nd_sq_; ++i) {
        vis_order_sq.push_back(i);
    }

    #pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t it_sq = 0; it_sq < u32_nd_sq_; ++it_sq) {
        uint32_t sq = vis_order_sq[it_sq];
        
        auto &nn_base = learn_base_knn_[sq];
        if (nn_base.size() > 100) {
            nn_base.resize(100);
            nn_base.shrink_to_fit();
        }
        uint32_t choose_tgt = 0;

        uint32_t cur_tgt = nn_base[choose_tgt];
        std::vector<Neighbor> full_retset;
        for (size_t i = 0; i < nn_base.size(); ++i) {
            if (nn_base[i] == cur_tgt) {
                continue;
            }
            float distance = distance_->compare(data_bp_ + dimension_ * (uint64_t)nn_base[i], data_bp_ + dimension_ * (uint64_t)cur_tgt,
                                                (unsigned)dimension_);
            full_retset.push_back(Neighbor(nn_base[i], distance, false));
        }
        std::vector<uint32_t> pruned_list;
        pruned_list.reserve(M_pjbp * PROJECTION_SLACK);
        PruneBiSearchBaseGetBase(full_retset, data_bp_ + dimension_ * cur_tgt, cur_tgt, parameters, pruned_list);
        {
            LockGuard guard(locks_[cur_tgt]);
            projection_graph_[cur_tgt] = pruned_list;
        }
        ProjectionAddReverse(cur_tgt, parameters);
        if (sq % 1000 == 0) {
            std::cout << "\r" << (100.0 * sq) / u32_nd_sq_ << "% of projection search bipartite by base completed."
                      << std::flush;
        }
    }

    std::cout << std::endl;
    std::atomic<uint32_t> degree_cnt(0);
    std::atomic<uint32_t> zero_cnt(0);
#pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t i = 0; i < vis_order.size(); ++i) {
        uint32_t node = vis_order[i];
        if (projection_graph_[node].size() < M_pjbp) {
            
            degree_cnt.fetch_add(1);
            if (projection_graph_[node].size() == 0) {
                zero_cnt.fetch_add(1);
            }
        }
        ProjectionAddReverse(node, parameters);
    }
    std::cout << "Warning: " << degree_cnt.load() << " nodes have less than M_pjbp neighbors." << std::endl;
    std::cout << "Warning: " << zero_cnt.load() << " nodes have no neighbors." << std::endl;

#pragma omp parallel for schedule(dynamic, 100)
    for (size_t i = 0; i < projection_graph_.size(); ++i) {
        std::vector<uint32_t> ok_insert;
        ok_insert.reserve(M_pjbp);
        for (size_t j = 0; j < supply_nbrs_[i].size(); ++j) {
            if (ok_insert.size() >= M_pjbp * 2) {
                break;
            }
            if (std::find(projection_graph_[i].begin(), projection_graph_[i].end(), supply_nbrs_[i][j]) ==
                projection_graph_[i].end()) {
                ok_insert.push_back(supply_nbrs_[i][j]);
            }
        }
        projection_graph_[i].insert(projection_graph_[i].end(), ok_insert.begin(), ok_insert.end());
    }
}

void IndexRetrAtten::LinkProjection(const Parameters &parameters) {
    
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    uint32_t L_pjpq = parameters.Get<uint32_t>("L_pjpq");
    uint32_t Nq = parameters.Get<uint32_t>("M_sq");

    omp_set_num_threads(parameters.Get<uint32_t>("num_threads"));
    std::vector<uint32_t> vis_order(u32_nd_);
    std::vector<uint32_t> vis_order_sq(u32_nd_sq_);
    
    for (uint32_t i = 0; i < u32_nd_; ++i) {
        vis_order[i] = i;
    }
    for (uint32_t i = 0; i < u32_nd_sq_; ++i) {
        vis_order_sq[i] = i;
    }
#pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t it_sq = 0; it_sq < u32_nd_sq_; ++it_sq) {
        uint32_t sq = vis_order_sq[it_sq];

        auto &nn_base = learn_base_knn_[sq];
        if (nn_base.size() > Nq) {
            nn_base.resize(Nq);
            nn_base.shrink_to_fit();
        }
        uint32_t choose_tgt = 0;
        uint32_t cur_tgt = nn_base[choose_tgt];
        std::vector<Neighbor> full_retset;
        for (size_t i = 0; i < nn_base.size(); ++i) {
            if (nn_base[i] == cur_tgt) {
                continue;
            }
            
            _mm_prefetch((const char *)(data_bp_ + dimension_ * (uint64_t)nn_base[i]), _MM_HINT_T0);
            float distance = l2_distance_->compare(data_bp_ + dimension_ * (uint64_t)nn_base[i], data_bp_ + dimension_ * (uint64_t)cur_tgt,
                                                (unsigned)dimension_);
            full_retset.push_back(Neighbor(nn_base[i], distance, false));
        }
        std::vector<uint32_t> pruned_list;
        pruned_list.reserve(M_pjbp * PROJECTION_SLACK);
        PruneBiSearchBaseGetBase(full_retset, data_bp_ + dimension_ * cur_tgt, cur_tgt, parameters, pruned_list);
        {
            LockGuard guard(locks_[cur_tgt]);
            projection_graph_[cur_tgt] = pruned_list;
            projection_graph_[cur_tgt].reserve(M_pjbp * PROJECTION_SLACK);
        }
        ProjectionAddReverse(cur_tgt, parameters);
    }
    for (size_t i = 0; i < projection_graph_.size(); ++i) {
        
        supply_nbrs_[i] = projection_graph_[i];
        supply_nbrs_[i].reserve(M_pjbp * 2 * PROJECTION_SLACK);
        
    }
#pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t i = 0; i < u32_nd_; ++i) {
        uint32_t node = vis_order[i];
        boost::dynamic_bitset<> visited{u32_nd_, false};
        std::vector<Neighbor> full_retset;
        full_retset.reserve(L_pjpq);
        NeighborPriorityQueue search_pool;
        SearchProjectionGraphInternal(search_pool, data_bp_ + dimension_ * node, node, parameters, visited,
                                      full_retset);
        std::vector<uint32_t> pruned_list;
        pruned_list.reserve(M_pjbp * PROJECTION_SLACK);
        PruneProjectionBaseSearchCandidates(full_retset, data_bp_ + dimension_ * node, node, parameters, pruned_list);
        {
            LockGuard guard(locks_[node]);
            if (pruned_list.size() != 0) {
                supply_nbrs_[node] = pruned_list;
                
            }
        }
        SupplyAddReverse(node, parameters);
    }
#pragma omp parallel for schedule(dynamic, 100)
    for (size_t i = 0; i < projection_graph_.size(); ++i) {
        std::vector<uint32_t> ok_insert;
        ok_insert.reserve(M_pjbp);
        for (size_t j = 0; j < supply_nbrs_[i].size(); ++j) {
            if (ok_insert.size() >= M_pjbp * 2) {
                break;
            }
            if (std::find(projection_graph_[i].begin(), projection_graph_[i].end(), supply_nbrs_[i][j]) ==
                projection_graph_[i].end()) {
                ok_insert.push_back(supply_nbrs_[i][j]);
            }
        }
        projection_graph_[i].insert(projection_graph_[i].end(), ok_insert.begin(), ok_insert.end());
    }
}

void IndexRetrAtten::LinkProjectionNoConn(const Parameters &parameters) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    uint32_t L_pjpq = parameters.Get<uint32_t>("L_pjpq");
    uint32_t Nq = parameters.Get<uint32_t>("M_sq");

    omp_set_num_threads(parameters.Get<uint32_t>("num_threads"));
    std::vector<uint32_t> vis_order;
    std::vector<uint32_t> vis_order_sq;
    for (uint32_t i = 0; i < u32_nd_; ++i) {
        vis_order.push_back(i);
    }
    for (uint32_t i = 0; i < u32_nd_sq_; ++i) {
        vis_order_sq.push_back(i);
    }
    std::chrono::high_resolution_clock::time_point t1 = std::chrono::high_resolution_clock::now();

#pragma omp parallel for schedule(static, 100)
    for (uint32_t it_sq = 0; it_sq < u32_nd_sq_; ++it_sq) {
        uint32_t sq = vis_order_sq[it_sq];

        auto &nn_base = learn_base_knn_[sq];
        if (nn_base.size() > Nq) {
            nn_base.resize(Nq);
            nn_base.shrink_to_fit();
        }
        uint32_t choose_tgt = 0;
        uint32_t cur_tgt = nn_base[choose_tgt];
        std::vector<Neighbor> full_retset;
        for (size_t i = 0; i < nn_base.size(); ++i) {
            if (nn_base[i] == cur_tgt) {
                continue;
            }
            float distance = distance_->compare(data_bp_ + dimension_ * (uint64_t)nn_base[i], data_bp_ + dimension_ * (uint64_t)cur_tgt,
                                                (unsigned)dimension_);
            full_retset.push_back(Neighbor(nn_base[i], distance, false));
        }
        std::vector<uint32_t> pruned_list;
        pruned_list.reserve(M_pjbp * PROJECTION_SLACK);
        PruneBiSearchBaseGetBase(full_retset, data_bp_ + dimension_ * cur_tgt, cur_tgt, parameters, pruned_list);
        {
            LockGuard guard(locks_[cur_tgt]);
            projection_graph_[cur_tgt] = pruned_list;
        }
        ProjectionAddReverse(cur_tgt, parameters);
        if (sq % 1000 == 0) {
            std::cout << "\r" << (100.0 * sq) / u32_nd_sq_ << "% of projection search bipartite by base completed."
                      << std::flush;
        }
    }

    std::cout << std::endl;
#pragma omp parallel for schedule(static, 100)
    for (uint32_t i = 0; i < vis_order.size(); ++i) {
        uint32_t node = vis_order[i];
        ProjectionAddReverse(node, parameters);
    }
#pragma omp parallel for schedule(static, 2048)
    for (uint32_t i = 0; i < vis_order.size(); ++i) {
        size_t node = (size_t)vis_order[i];
        if (projection_graph_[node].size() > M_pjbp) {
            std::vector<Neighbor> full_retset;
            tsl::robin_set<uint32_t> visited;
            for (size_t j = 0; j < projection_graph_[node].size(); ++j) {
                if (visited.find(projection_graph_[node][j]) != visited.end()) {
                    continue;
                }
                float distance = distance_->compare(data_bp_ + dimension_ * (size_t)projection_graph_[node][j],
                                                    data_bp_ + dimension_ * (size_t)node, dimension_);
                visited.insert(projection_graph_[node][j]);
                full_retset.push_back(Neighbor(projection_graph_[node][j], distance, false));
            }
            for (unsigned j = 0; j < full_retset.size(); j++) {
                if (full_retset[j].id == (unsigned)node) {
                    full_retset.erase(full_retset.begin() + j);
                    j--;
                }
            }
            std::vector<uint32_t> prune_list;
            PruneBiSearchBaseGetBase(full_retset, data_bp_ + dimension_ * (size_t)node, node, parameters, prune_list);
            {
                LockGuard guard(locks_[node]);
                projection_graph_[node].clear();
                projection_graph_[node] = prune_list;
            }
        }
    }

    std::chrono::high_resolution_clock::time_point t2 = std::chrono::high_resolution_clock::now();
    auto projection_time = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();

    std::atomic<uint32_t> degree_cnt(0);
    std::atomic<uint32_t> zero_cnt(0);
#pragma omp parallel for schedule(static, 100)
    for (uint32_t i = 0; i < vis_order.size(); ++i) {
        uint32_t node = vis_order[i];
        if (projection_graph_[node].size() < M_pjbp) {
            
            degree_cnt.fetch_add(1);
            if (projection_graph_[node].size() == 0) {
                zero_cnt.fetch_add(1);
            }
        }
    }
    std::cout << "Projection time: " << projection_time << std::endl;
    std::cout << "Warning: " << degree_cnt.load() << " nodes have less than M_pjbp neighbors." << std::endl;
    std::cout << "Warning: " << zero_cnt.load() << " nodes have no neighbors." << std::endl;
    float avg_degree = 0;
    uint64_t total_degree = 0;
    uint32_t max_degree = 0;
    uint32_t min_degree = std::numeric_limits<uint32_t>::max();
    for (uint32_t i = 0; i < u32_nd_; ++i) {
        if (projection_graph_[i].size() > max_degree) {
            max_degree = projection_graph_[i].size();
        }
        if (projection_graph_[i].size() < min_degree) {
            min_degree = projection_graph_[i].size();
        }
        avg_degree += static_cast<float>(projection_graph_[i].size());
        total_degree += projection_graph_[i].size();
    }
    std::cout << "total degree: " << total_degree << std::endl;
    avg_degree /= (float)u32_nd_;
    std::cout << "After projection, average degree of projection graph: " << avg_degree << std::endl;
    std::cout << "After projection, max degree of projection graph: " << max_degree << std::endl;
    std::cout << "After projection, min degree of projection graph: " << min_degree << std::endl;

    std::cout << std::endl;
}

// Best-first search over supply_nbrs_ for `query`, used while building the graph.
//
// search_queue : working queue; its capacity is reserved to "L_pjpq".
// query        : vector the candidates are compared against.
// tgt          : node id that is never enqueued as a neighbor during expansion
//                (the seeds are not filtered against it).
// visited      : bitset of ids already enqueued; updated in place.
// full_retset  : receives every node popped for expansion, in pop order.
//
// Seeds are projection_ep_ and its two adjacent ids (when they exist).
// NOTE(review): seed distances use distance_ while expanded neighbors use
// l2_distance_; if those two metrics differ the queue mixes scales - confirm
// that this is intended.
void IndexRetrAtten::SearchProjectionGraphInternal(NeighborPriorityQueue &search_queue, const float *query,
                                                   uint32_t tgt, const Parameters &parameters,
                                                   boost::dynamic_bitset<> &visited,
                                                   std::vector<Neighbor> &full_retset) {
    const uint32_t L_pq = parameters.Get<uint32_t>("L_pjpq");
    search_queue.reserve(L_pq);

    // tgt is only compared against, never dereferenced, so an invalid value is
    // reported once and the search proceeds.
    if (!check_valid_range(tgt)) {
        std::cout << "In file " << __FILE__ << " at line " << __LINE__ << " out of range. total_pts:" << total_pts_ << ", cur access: " << tgt << std::endl << std::flush;
    }

    std::vector<uint32_t> init_ids;
    init_ids.reserve(3);
    init_ids.push_back(projection_ep_);
    // Written as `ep + 1 < n` so that an empty base set cannot underflow `n - 1`.
    if (projection_ep_ + 1 < u32_nd_) {
        init_ids.push_back(projection_ep_ + 1);
    }
    if (projection_ep_ > 0) {
        init_ids.push_back(projection_ep_ - 1);
    }

    for (const uint32_t id : init_ids) {
        if (!check_valid_range(id)) {
            std::cout << "In file " << __FILE__ << " at line " << __LINE__ << " out of range. total_pts:" << total_pts_ << ", cur access: " << id << std::endl << std::flush;
            // Skip instead of reading the vector of an invalid id, matching the
            // handling of invalid neighbors in the expansion loop below.
            continue;
        }
        const float distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
        search_queue.insert(Neighbor(id, distance, false));
        visited.set(id);
    }

    while (search_queue.has_unexpanded_node()) {
        auto cur_check_node = search_queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;
        full_retset.push_back(cur_check_node);

        for (auto nbr : supply_nbrs_[cur_id]) {
            if (!check_valid_range(nbr)) {
                std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
                continue;
            }
            if (visited.test(nbr) || nbr == tgt) {
                continue;
            }

            visited.set(nbr);
            const float distance = l2_distance_->compare(data_bp_ + nbr * dimension_, query, (unsigned)dimension_);
            search_queue.insert(Neighbor(nbr, distance, false));
        }
    }
}

// Adds the reverse edge des -> src_node to supply_nbrs_ for every current
// supply neighbor `des` of src_node.
//
// A destination with fewer than 2 * M_pjbp neighbors simply gets src_node
// appended. A full destination has its list (plus src_node) re-pruned by
// PruneProjectionInternalReverseCandidates and replaced by the result.
//
// Every access to supply_nbrs_[x] in this function is made while holding
// locks_[x]; only one lock is held at a time.
//
// parameters: reads "M_pjbp".
void IndexRetrAtten::SupplyAddReverse(uint32_t src_node, const Parameters &parameters) {
    const uint32_t max_degree = parameters.Get<uint32_t>("M_pjbp") * 2;

    // Work on a snapshot: other threads may append to or replace
    // supply_nbrs_[src_node] while this loop runs, which would invalidate
    // iteration over the live vector.
    std::vector<uint32_t> forward_nbrs;
    {
        LockGuard guard(locks_[src_node]);
        forward_nbrs = supply_nbrs_[src_node];
    }

    for (const uint32_t des : forward_nbrs) {
        std::vector<uint32_t> candidates;
        {
            LockGuard guard(locks_[des]);
            auto &des_nbrs = supply_nbrs_[des];
            if (std::find(des_nbrs.begin(), des_nbrs.end(), src_node) != des_nbrs.end()) {
                continue;  // reverse edge already present
            }
            if (des_nbrs.size() < max_degree) {
                des_nbrs.push_back(src_node);
                continue;
            }
            // Destination is full: copy its list while still holding the lock so
            // the copy cannot observe a concurrent modification.
            candidates.reserve(des_nbrs.size() + 1);
            candidates.assign(des_nbrs.begin(), des_nbrs.end());
        }

        // Pruning computes distances, so it runs outside the critical section.
        candidates.push_back(src_node);
        PruneProjectionInternalReverseCandidates(des, parameters, candidates);
        {
            LockGuard guard(locks_[des]);
            supply_nbrs_[des] = std::move(candidates);
        }
    }
}

// Adds the reverse edge des -> src_node to projection_graph_ for every current
// neighbor `des` of src_node.
//
// A destination with fewer than M_pjbp neighbors simply gets src_node appended.
// A full destination has its list (plus src_node) re-pruned by
// PruneProjectionReverseCandidates and replaced by the result.
//
// Every access to projection_graph_[x] in this function is made while holding
// locks_[x]; only one lock is held at a time.
//
// parameters: reads "M_pjbp".
void IndexRetrAtten::ProjectionAddReverse(uint32_t src_node, const Parameters &parameters) {
    const uint32_t max_degree = parameters.Get<uint32_t>("M_pjbp");

    // Work on a snapshot: other threads may append to or replace
    // projection_graph_[src_node] while this loop runs, which would invalidate
    // iteration over the live vector.
    std::vector<uint32_t> forward_nbrs;
    {
        LockGuard guard(locks_[src_node]);
        forward_nbrs = projection_graph_[src_node];
    }

    for (const uint32_t des : forward_nbrs) {
        std::vector<uint32_t> candidates;
        {
            LockGuard guard(locks_[des]);
            auto &des_nbrs = projection_graph_[des];
            if (std::find(des_nbrs.begin(), des_nbrs.end(), src_node) != des_nbrs.end()) {
                continue;  // reverse edge already present
            }
            if (des_nbrs.size() < max_degree) {
                des_nbrs.push_back(src_node);
                continue;
            }
            // Destination is full: copy its list under the same lock as the check.
            candidates.reserve(des_nbrs.size() + 1);
            candidates.assign(des_nbrs.begin(), des_nbrs.end());
        }

        // Pruning computes distances, so it runs outside the critical section.
        candidates.push_back(src_node);
        PruneProjectionReverseCandidates(des, parameters, candidates);
        {
            LockGuard guard(locks_[des]);
            projection_graph_[des] = std::move(candidates);
        }
    }
}

void IndexRetrAtten::PruneProjectionInternalReverseCandidates(uint32_t src_node, const Parameters &parameters,
                                                              std::vector<uint32_t> &pruned_list) {

    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp") * 2;
    std::vector<Neighbor> prune_queue(pruned_list.size());
    prune_queue.reserve(pruned_list.size());
    for (size_t i = 0; i < pruned_list.size(); ++i) {
        float distance = distance_->compare(data_bp_ + dimension_ * src_node, data_bp_ + dimension_ * pruned_list[i],
                                            dimension_);

        Neighbor nn = Neighbor(pruned_list[i], distance, false);
        if (std::find(prune_queue.begin(), prune_queue.end(), nn) == prune_queue.end()) {
            prune_queue.push_back(nn);
        }
        
    }
    std::sort(prune_queue.begin(), prune_queue.end());
    
    std::vector<uint32_t> result;
    result.reserve(M_pjbp  * PROJECTION_SLACK);

    uint32_t start = 0;
    if (prune_queue[start].id == src_node) {
        start++;
    }
    if (start >= prune_queue.size()) {
        return;
    }
    result.push_back(prune_queue[start].id);
    while (result.size() < M_pjbp && (++start) < prune_queue.size()) {
        auto &p = prune_queue[start];
        bool occlude = false;
        for (size_t i = 0; i < result.size(); ++i) {
            if (p.id == result[i]) {
                occlude = true;
                break;
            }

            if (!check_valid_range(p.id)) {
                std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
            }

            if (!check_valid_range(result[i])) {
                std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
            }
            float djk = l2_distance_->compare(data_bp_ + dimension_ * p.id, data_bp_ + dimension_ * result[i],
                                           dimension_);
            if (1 * djk < p.distance) {
                occlude = true;
                break;
            }
        }
        if (!occlude) {
            if (p.id != src_node) {
                result.push_back(p.id);
            }
        }
    }
    pruned_list = result;
}

void IndexRetrAtten::PruneProjectionReverseCandidates(uint32_t src_node, const Parameters &parameters,
                                                      std::vector<uint32_t> &pruned_list) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    std::vector<Neighbor> prune_queue;
    prune_queue.reserve(pruned_list.size());
    for (size_t i = 0; i < pruned_list.size(); ++i) {

        if (!check_valid_range(src_node)) {
            std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
        }

        if (!check_valid_range(pruned_list[i])) {
            std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
        }

        float distance = distance_->compare(data_bp_ + dimension_ * src_node, data_bp_ + dimension_ * pruned_list[i],
                                            (unsigned)dimension_);
        Neighbor nn = Neighbor(pruned_list[i], distance, false);
        if (std::find(prune_queue.begin(), prune_queue.end(), nn) == prune_queue.end()) {
            prune_queue.push_back(nn);
        }
    }
    std::sort(prune_queue.begin(), prune_queue.end());
    
    std::vector<uint32_t> result;
    result.reserve(M_pjbp * PROJECTION_SLACK);
    uint32_t start = 0;
    if (prune_queue[start].id == src_node) {
        start++;
    }
    result.push_back(prune_queue[start].id);
    while (result.size() < M_pjbp && (++start) < prune_queue.size()) {
        auto &p = prune_queue[start];
        bool occlude = false;
        for (size_t i = 0; i < result.size(); ++i) {
            if (p.id == result[i]) {
                occlude = true;
                break;
            }

            if (!check_valid_range(p.id)) {
                std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
            }

            if (!check_valid_range(result[i])) {
                std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
            }
            float djk = l2_distance_->compare(data_bp_ + dimension_ * p.id, data_bp_ + dimension_ * result[i],
                                           (unsigned)dimension_);
            if (djk < p.distance) {
                occlude = true;
                break;
            }
        }
        if (!occlude) {
            if (p.id != src_node) {
                result.push_back(p.id);
            }
        }
    }
    for (size_t i = 0; i < pruned_list.size() && result.size() < M_pjbp; ++i) {
        if (std::find(result.begin(), result.end(), pruned_list[i]) == result.end()) {
            result.push_back(pruned_list[i]);
        }
    }
    pruned_list = result;
}

void IndexRetrAtten::PruneBiSearchBaseGetBase(std::vector<Neighbor> &search_pool, const float *query, uint32_t tgt_base,
                                              const Parameters &parameters, std::vector<uint32_t> &pruned_list) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    std::vector<Neighbor> base_pool;
    std::unordered_map<uint32_t, uint32_t> base_id;
    for (auto &b_node : search_pool) {
        if (base_id.find(b_node.id) == base_id.end()) {
            if (b_node.id == tgt_base) {
                continue;
            }
            base_pool.push_back(b_node);
            base_id[b_node.id] = 1;
        }
    }

    std::sort(base_pool.begin(), base_pool.end());
    std::vector<uint32_t> result;
    result.reserve(M_pjbp * PROJECTION_SLACK);
    uint32_t start = 0;
    result.push_back(base_pool[start].id);

    while (result.size() < M_pjbp && (++start) < base_pool.size()) {
        Neighbor &p = base_pool[start];
        bool occlude = false;
        for (size_t t = 0; t < result.size(); ++t) {
            if (p.id == result[t]) {
                occlude = true;
                break;
            }
            float djk = l2_distance_->compare(data_bp_ + dimension_ * p.id, data_bp_ + dimension_ * result[t], dimension_);
            if (1 * djk < p.distance) {
                occlude = true;
                break;
            }
        }
        if (!occlude) {
            if (p.id != tgt_base) {
                result.push_back(p.id);
            }
        }
    }
    for (size_t i = 1; i < base_pool.size() && result.size() < M_pjbp; ++i) {
        if (std::find(result.begin(), result.end(), base_pool[i].id) == result.end()) {
            if (base_pool[i].id != tgt_base) {
                result.push_back(base_pool[i].id);
            }
        }
    }

    pruned_list = result;
}

uint32_t IndexRetrAtten::PruneProjectionRetrAttenCandidates(std::vector<Neighbor> &search_pool, const float *query,
                                                            uint32_t qid, const Parameters &parameters,
                                                            std::vector<uint32_t> &pruned_list) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");

    uint32_t exp_candidate_len =
        search_pool.size() + bipartite_graph_[u32_nd_ + qid].size() + projection_graph_[qid].size();
    
    std::sort(search_pool.begin(), search_pool.end());
    std::vector<uint32_t> candidate_only_id;
    candidate_only_id.reserve(exp_candidate_len);
    for (size_t i = 0; i < search_pool.size(); ++i) {
        candidate_only_id.push_back(search_pool[i].id);
    }
    uint32_t src_node = candidate_only_id[0];
    std::vector<uint32_t> result;
    result.reserve(M_pjbp * PROJECTION_SLACK);
    uint32_t start = 1;  
    result.push_back(candidate_only_id[start]);

    while (result.size() < M_pjbp && (++start) < candidate_only_id.size()) {
        auto &p = candidate_only_id[start];
        float dik = distance_->compare(data_bp_ + dimension_ * p, data_bp_ + dimension_ * src_node, dimension_);
        bool occlude = false;
        for (size_t t = 0; t < result.size(); ++t) {
            if (p == result[t]) {
                occlude = true;
                break;
            }
            float djk = distance_->compare(data_bp_ + dimension_ * p, data_bp_ + dimension_ * result[t], dimension_);
            if (djk < dik) {
                occlude = true;
                break;
            }
        }
        if (!occlude) {
            if (p != src_node) {
                result.push_back(p);
            }
        }
    }
    for (size_t i = 1; i < candidate_only_id.size() && result.size() < M_pjbp; ++i) {
        if (std::find(result.begin(), result.end(), candidate_only_id[i]) == result.end()) {
            if (candidate_only_id[i] != src_node) {
                result.push_back(candidate_only_id[i]);
            }
        }
    }

    pruned_list = result;
    return src_node;
}

uint32_t IndexRetrAtten::PruneProjectionCandidates(std::vector<Neighbor> &search_pool, const float *query, uint32_t qid,
                                                   const Parameters &parameters, std::vector<uint32_t> &pruned_list) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    
    std::vector<uint32_t> candidate_only_id;

    std::sort(search_pool.begin(), search_pool.end());
    for (size_t i = 0; i < search_pool.size(); ++i) {
        candidate_only_id.push_back(search_pool[i].id);
    }
    NeighborPriorityQueue pruned_pool(candidate_only_id.size());
    pruned_pool.reserve(candidate_only_id.size());

    for (size_t i = 0; i < search_pool.size(); ++i) {
        pruned_pool.insert(search_pool[i]);
    }

    for (size_t i = search_pool.size(); i < candidate_only_id.size(); ++i) {
        pruned_pool.insert({candidate_only_id[i], 0.1, false});
    }

    uint32_t src_node = pruned_pool[0].id;  
    std::vector<uint32_t> result;
    result.reserve(M_pjbp * PROJECTION_SLACK);
    uint32_t start = 1;  
    result.push_back(pruned_pool[start].id);

    while (result.size() < M_pjbp && (++start) < pruned_pool.size()) {
        Neighbor &p = pruned_pool[start];
        bool occlude = false;
        float dik = distance_->compare(data_bp_ + dimension_ * p.id, data_bp_ + dimension_ * src_node, dimension_);
        for (size_t t = 0; t < result.size(); ++t) {
            float djk = distance_->compare(data_bp_ + dimension_ * p.id, data_bp_ + dimension_ * result[t], dimension_);
            if (djk < dik) {
                occlude = true;
                break;
            }
        }
        if (!occlude) {
            if (p.id != src_node) {
                result.push_back(p.id);
            }
        }
    }
    for (size_t i = 1; i < candidate_only_id.size() && result.size() < M_pjbp; ++i) {
        if (std::find(result.begin(), result.end(), candidate_only_id[i]) == result.end()) {
            if (candidate_only_id[i] != src_node) {
                result.push_back(candidate_only_id[i]);
            }
        }
    }

    pruned_list = result;
    return src_node;
}

void IndexRetrAtten::PruneProjectionBaseSearchCandidates(std::vector<Neighbor> &search_pool, const float *query,
                                                         uint32_t qid, const Parameters &parameters,
                                                         std::vector<uint32_t> &pruned_list) {
    
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp") * 2;
    uint32_t degree = M_pjbp;
    std::vector<uint32_t> result;
    std::sort(search_pool.begin(), search_pool.end());

    result.reserve(M_pjbp * PROJECTION_SLACK);
    uint32_t start = 0;

    auto &src_nbrs = projection_graph_[qid];
    while (start < search_pool.size() && ((std::find(src_nbrs.begin(), src_nbrs.end(), search_pool[start].id) != src_nbrs.end()) || search_pool[start].id == qid)) {
        ++start;
    }
    if (start >= search_pool.size()) {
        return;
    }
    result.push_back(search_pool[start].id);
    
    while (result.size() < M_pjbp && (++start) < search_pool.size()) {
        Neighbor &p = search_pool[start];
        bool occlude = false;
        for (size_t t = 0; t < result.size(); ++t) {
            if (p.id == result[t]) {
                occlude = true;
                break;
            }

            if (!check_valid_range(p.id)) {
                std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
            }

            if (!check_valid_range(result[t])) {
                std::cout << "In file " << __FILE__ << " at line " << __LINE__ << "out of range." << std::endl << std::flush;
            }
            float djk = l2_distance_->compare(data_bp_ + dimension_ * p.id, data_bp_ + dimension_ * result[t], dimension_);
            if (1 * djk < p.distance) {
                occlude = true;
                break;
            }
        }
        if (!occlude) {
            if (p.id != qid) {
                
                result.push_back(p.id);
                
            }
        }
    }
    pruned_list = result;
}

// Best-first search over projection_graph_ for `query`, recording every
// evaluated point.
//
// Seeds: the entry point projection_ep_, up to (search_pool.capacity() - 1) of
// its out-neighbors, and one uniformly random base id. Every seed is marked in
// `visited`, so no id is evaluated twice.
//
// query       - vector compared against base points via distance_.
// parameters  - not read by this function.
// search_pool - bounded queue driving the expansion; filled in place.
// visited     - bitset of evaluated ids; updated in place.
// full_retset - receives one Neighbor per evaluated point (seeds included).
//
// Fix over the previous version: the seed vector used to be created with
// capacity() + 1 zero-initialised slots and then appended to. That padded the
// seeds with repeated copies of id 0 and made the random-fill loop unreachable.
// The entry point and the random seed were also never marked visited, so they
// could be evaluated again. The unreachable loop is dropped rather than
// enabled, which keeps the effective seed set (entry point, its neighbors, one
// random id) unchanged.
void IndexRetrAtten::SearchProjectionbyQuery(const float *query, const Parameters &parameters,
                                             NeighborPriorityQueue &search_pool, boost::dynamic_bitset<> &visited,
                                             std::vector<Neighbor> &full_retset) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<uint32_t> dis(0, u32_nd_ - 1);

    const size_t pool_capacity = search_pool.capacity();
    const auto &ep_nbrs = projection_graph_[projection_ep_];

    std::vector<uint32_t> init_ids;
    init_ids.reserve(pool_capacity + 1);

    init_ids.push_back(projection_ep_);
    visited.set(projection_ep_);

    // Same neighbor budget as before: at most capacity - 1 neighbors of the entry point.
    for (size_t i = 0; i + 1 < pool_capacity && i < ep_nbrs.size(); ++i) {
        const uint32_t nbr = ep_nbrs[i];
        if (visited.test(nbr)) {
            continue;
        }
        visited.set(nbr);
        init_ids.push_back(nbr);
    }

    // One random seed, skipped when it collides with an id that is already seeded.
    const uint32_t rand_id = dis(gen);
    if (!visited.test(rand_id)) {
        visited.set(rand_id);
        init_ids.push_back(rand_id);
    }

    for (const uint32_t id : init_ids) {
        const float distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
        Neighbor nn(id, distance, false);
        search_pool.insert(nn);
        full_retset.push_back(nn);
    }

    while (search_pool.has_unexpanded_node()) {
        auto cur_check = search_pool.closest_unexpanded();
        auto cur_id = cur_check.id;

        for (size_t j = 0; j < projection_graph_[cur_id].size(); ++j) {
            const uint32_t id = projection_graph_[cur_id][j];
            if (visited.test(id)) {
                continue;
            }
            visited.set(id);
            const float distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
            Neighbor nn(id, distance, false);
            search_pool.insert(nn);
            full_retset.push_back(nn);
        }
    }
}

// Chooses the projection-graph entry point: the base point with the smallest
// squared Euclidean distance to the centroid of all nd_ base vectors. Ties go
// to the lowest index. The result is stored in projection_ep_.
//
// The scratch buffers are std::vector rather than new[]/delete[], so they are
// released on every exit path; the vectors are value-initialised, which makes
// the former memset calls unnecessary.
void IndexRetrAtten::CalculateProjectionep() {
    std::vector<float> center(dimension_, 0.0f);

    for (size_t i = 0; i < nd_; ++i) {
        for (size_t d = 0; d < dimension_; ++d) {
            center[d] += data_bp_[i * dimension_ + d];
        }
    }

    for (size_t d = 0; d < dimension_; ++d) {
        center[d] /= (float)nd_;
    }

    std::vector<float> distances(nd_, 0.0f);
#pragma omp parallel for
    for (size_t i = 0; i < nd_; ++i) {
        const float *cur_data = data_bp_ + i * dimension_;
        float diff = 0;
        for (size_t j = 0; j < dimension_; ++j) {
            diff += ((center[j] - cur_data[j]) * (center[j] - cur_data[j]));
        }
        distances[i] = diff;
    }

    // Strict '<' keeps the first minimum when several points are equally close.
    uint32_t closest = 0;
    for (size_t i = 1; i < nd_; ++i) {
        if (distances[i] < distances[closest]) {
            closest = static_cast<uint32_t>(i);
        }
    }
    projection_ep_ = closest;
}

// No-op: the body is empty and all three arguments are ignored.
// NOTE(review): presumably present only to satisfy the base Index interface - confirm.
void IndexRetrAtten::Build(size_t n, const float *data, const Parameters &parameters) {}

// Writes bipartite_graph_ to `filename` in binary form:
//   uint32 npts (= total_pts_), then for each node i in [0, npts):
//   uint32 degree followed by `degree` uint32 neighbor ids.
//
// Throws std::runtime_error when the file cannot be opened, matching the other
// Save* functions in this file. Previously an unopenable file was silently ignored.
void IndexRetrAtten::Save(const char *filename) {
    std::ofstream out(filename, std::ios::binary | std::ios::out);
    if (!out.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    const uint32_t npts = static_cast<uint32_t>(total_pts_);
    out.write((const char *)&npts, sizeof(npts));
    for (uint32_t i = 0; i < npts; i++) {
        const uint32_t nbr_size = static_cast<uint32_t>(bipartite_graph_[i].size());
        out.write((const char *)&nbr_size, sizeof(nbr_size));
        out.write((const char *)bipartite_graph_[i].data(), nbr_size * sizeof(uint32_t));
    }
    out.close();
}

// Reads bipartite_graph_ from a file in the layout written by Save():
//   uint32 npts, then per node: uint32 degree + `degree` uint32 neighbor ids.
//
// Throws std::runtime_error when the file cannot be opened or ends early.
// Previously a failed read left `npts` / `nbr_size` uninitialised and the graph
// was resized with whatever value happened to be in them.
void IndexRetrAtten::Load(const char *filename) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    uint32_t npts = 0;
    if (!in.read((char *)&npts, sizeof(npts))) {
        throw std::runtime_error("bipartite graph file is truncated");
    }
    bipartite_graph_.resize(npts);
    for (uint32_t i = 0; i < npts; i++) {
        uint32_t nbr_size = 0;
        if (!in.read((char *)&nbr_size, sizeof(nbr_size))) {
            throw std::runtime_error("bipartite graph file is truncated");
        }
        bipartite_graph_[i].resize(nbr_size);
        if (!in.read((char *)bipartite_graph_[i].data(), nbr_size * sizeof(uint32_t))) {
            throw std::runtime_error("bipartite graph file is truncated");
        }
    }
    in.close();
}

// Loads projection_graph_ and projection_ep_ from a file laid out as:
//   uint32 width, uint32 entry point, then a sequence of adjacency records,
//   each a uint32 degree followed by `degree` uint32 neighbor ids.
// The file carries no node count, so records are read until end of file.
// The leading `width` value is read and discarded.
//
// Previously the node count was hard-coded to 1000000: larger graphs were
// truncated, and smaller ones caused reads past the end of the file with an
// uninitialised degree.
//
// Throws std::runtime_error when the file cannot be opened, the header is
// incomplete, or an adjacency record is cut short.
void IndexRetrAtten::LoadNsgGraph(const char *filename) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    uint32_t width = 0;
    in.read((char *)&width, sizeof(width));
    in.read((char *)&projection_ep_, sizeof(uint32_t));
    if (!in) {
        throw std::runtime_error("nsg graph file header is truncated");
    }
    std::cout << "Projection graph, "
              << "ep: " << projection_ep_ << std::endl;

    projection_graph_.clear();
    float out_degree = 0.0;
    uint32_t nbr_size = 0;
    // A failed read of the degree field marks the normal end of the file.
    while (in.read((char *)&nbr_size, sizeof(nbr_size))) {
        projection_graph_.emplace_back();
        auto &nbrs = projection_graph_.back();
        nbrs.resize(nbr_size);
        if (!in.read((char *)nbrs.data(), nbr_size * sizeof(uint32_t))) {
            throw std::runtime_error("nsg graph file is truncated");
        }
        out_degree += static_cast<float>(nbr_size);
    }

    const size_t npts = projection_graph_.size();
    std::cout << "Projection graph, "
              << "avg_degree: " << (npts == 0 ? 0.0f : out_degree / npts) << std::endl;
    in.close();
}

// Loads projection_ep_ and projection_graph_ from the layout written by
// SaveProjectionGraph():
//   uint32 entry point, uint32 npts, then per node: uint32 degree followed by
//   `degree` uint32 neighbor ids.
//
// Throws std::runtime_error when the file cannot be opened or ends early.
// Previously neither case was detected and the graph could be resized from
// uninitialised values. The unused degree accumulator was removed.
void IndexRetrAtten::LoadProjectionGraph(const char *filename) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    uint32_t npts = 0;
    in.read((char *)&projection_ep_, sizeof(uint32_t));
    in.read((char *)&npts, sizeof(npts));
    if (!in) {
        throw std::runtime_error("projection graph file header is truncated");
    }
    projection_graph_.resize(npts);
    for (uint32_t i = 0; i < npts; i++) {
        uint32_t nbr_size = 0;
        if (!in.read((char *)&nbr_size, sizeof(nbr_size))) {
            throw std::runtime_error("projection graph file is truncated");
        }
        projection_graph_[i].resize(nbr_size);
        if (!in.read((char *)projection_graph_[i].data(), nbr_size * sizeof(uint32_t))) {
            throw std::runtime_error("projection graph file is truncated");
        }
    }
    in.close();
}

// Two-hop best-first search over bipartite_graph_ for `query`.
//
// query      - vector compared against base points via distance_.
// k          - number of results to write.
// qid        - row index of this query in `indices`; results go to
//              indices[qid * k .. qid * k + k - 1].
// parameters - supplies "L_pq", the capacity of the search queue.
// indices    - output buffer of result ids.
//
// Returns the number of distance computations made while expanding nodes
// (the seed distance is not counted).
// Throws std::runtime_error if the queue ends with fewer than k entries.
uint32_t IndexRetrAtten::SearchRetrAttenGraph(const float *query, size_t k, size_t &qid, const Parameters &parameters,
                                              unsigned *indices) {
    
    uint32_t L_pq = parameters.Get<uint32_t>("L_pq");
    NeighborPriorityQueue search_queue(L_pq);
    search_queue.reserve(L_pq);
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<uint32_t> dis(0, u32_nd_ - 1);
    uint32_t start = dis(gen);  // single random seed drawn from [0, u32_nd_ - 1]
    std::vector<uint32_t> init_ids;
    init_ids.push_back(start);
    
    // One bit per point across total_pts_ ids.
    boost::dynamic_bitset<> visited{total_pts_, 0};
    for (auto &id : init_ids) {
        float distance;
        
        distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
        Neighbor nn = Neighbor(id, distance, false);
        search_queue.insert(nn);
        visited.set(id);
        
    }

    uint32_t cmps = 0;
    while (search_queue.has_unexpanded_node()) {
        
        auto cur_check_node = search_queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;
        // NOTE(review): reads element 0 without checking that the adjacency list of cur_id is non-empty.
        uint32_t first_hop_rank_1 = bipartite_graph_[cur_id][0];
        // Initial "best probe distance"; only probes closer than 1000 can replace first_hop_rank_1.
        float first_hop_min_dist = 1000;
        
        // Phase 1: for every first-hop neighbor, evaluate only its first unvisited second-hop
        // neighbor, and remember the first-hop neighbor whose probe was closest to the query.
        for (auto nbr : bipartite_graph_[cur_id]) {  

            for (auto ns_nbr : bipartite_graph_[nbr]) {  
                
                if (visited.test(ns_nbr)) {
                    continue;
                }
                visited.set(ns_nbr);
                
                float distance;
                
                distance = distance_->compare(data_bp_ + ns_nbr * dimension_, query, (unsigned)dimension_);
                if (distance < first_hop_min_dist) {
                    
                    first_hop_min_dist = distance;
                    first_hop_rank_1 = nbr;
                }
                ++cmps;
                Neighbor nn = Neighbor(ns_nbr, distance, false);
                search_queue.insert(nn);
                
                break;  // stop after the first unvisited second-hop neighbor of this nbr
            }
        }

        // Phase 2: evaluate every remaining unvisited neighbor of the selected first-hop node.
        for (auto &ns_nbr : bipartite_graph_[first_hop_rank_1]) {
            if (visited.test(ns_nbr)) {
                continue;
            }
            visited.set(ns_nbr);
            
            float distance;
            
            distance = distance_->compare(data_bp_ + ns_nbr * dimension_, query, (unsigned)dimension_);
            ++cmps;
            Neighbor nn = Neighbor(ns_nbr, distance, false);
            search_queue.insert(nn);
            
        }
    }

    if (search_queue.size() < k) {
        std::stringstream ss;
        ss << "not enough results: " << search_queue.size() << ", expected: " << k;
        throw std::runtime_error(ss.str());
    }
    
    // Results are written at the row offset selected by qid.
    for (size_t i = 0; i < k; ++i) {
        indices[qid * k + i] = search_queue[i].id;
    }
    return cmps;
}
// Exhaustive two-hop best-first search over bipartite_graph_ for `query`.
//
// Starts from one random id in [0, u32_nd_ - 1]. For each expanded node, every
// unvisited neighbor-of-neighbor is evaluated against the query and offered to
// a queue of capacity "L_pq".
//
// query      - vector compared against base points via distance_.
// x          - not read.
// k          - number of ids written to indices[0 .. k - 1].
// parameters - supplies "L_pq".
// indices    - output buffer for the k best ids held by the queue.
// res_dists  - not written.
//
// Throws std::runtime_error if the queue ends with fewer than k entries.
//
// Fix: the previous version tracked a "best first hop" (first_hop_rank_1 /
// first_hop_min_dist) and a comparison counter that were never used. Reading
// bipartite_graph_[cur_id][0] for that purpose was undefined behaviour when the
// adjacency list was empty. The dead state has been removed.
void IndexRetrAtten::Search(const float *query, const float *x, size_t k, const Parameters &parameters,
                            unsigned *indices, float* res_dists) {
    const uint32_t L_pq = parameters.Get<uint32_t>("L_pq");
    NeighborPriorityQueue search_queue(L_pq);
    search_queue.reserve(L_pq);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<uint32_t> dis(0, u32_nd_ - 1);
    const uint32_t start = dis(gen);

    boost::dynamic_bitset<> visited{total_pts_, 0};
    {
        const float distance = distance_->compare(data_bp_ + start * dimension_, query, (unsigned)dimension_);
        Neighbor nn = Neighbor(start, distance, false);
        search_queue.insert(nn);
        visited.set(start);
    }

    while (search_queue.has_unexpanded_node()) {
        auto cur_check_node = search_queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;

        for (auto &nbr : bipartite_graph_[cur_id]) {
            for (auto &ns_nbr : bipartite_graph_[nbr]) {
                if (visited.test(ns_nbr)) {
                    continue;
                }
                visited.set(ns_nbr);

                const float distance =
                    distance_->compare(data_bp_ + ns_nbr * dimension_, query, (unsigned)dimension_);
                Neighbor nn = Neighbor(ns_nbr, distance, false);
                search_queue.insert(nn);
            }
        }
    }

    if (unlikely(search_queue.size() < k)) {
        std::stringstream ss;
        ss << "not enough results: " << search_queue.size() << ", expected: " << k;
        throw std::runtime_error(ss.str());
    }

    for (size_t i = 0; i < k; ++i) {
        indices[i] = search_queue[i].id;
    }
}

// Best-first search over projection_graph_ starting from projection_ep_.
//
// query      - vector compared against base points via distance_.
// k          - number of results written to indices[0..k-1] / res_dists[0..k-1].
// qid        - not read.
// parameters - supplies "L_pq", the capacity of the search queue.
// indices    - output ids, in queue order.
// res_dists  - output distances as stored in the queue; must hold at least k entries.
//
// Returns {distance computations during expansion, number of expanded nodes}.
// Throws std::runtime_error if the queue ends with fewer than k entries.
//
// Fix: the look-ahead prefetch used to read cur_nbrs[j + 1] unconditionally,
// which on the last neighbor reads one element past the end of the adjacency
// list. The look-ahead is now skipped on the last neighbor.
std::pair<uint32_t, uint32_t> IndexRetrAtten::SearchRAIndex(const float *query, size_t k, size_t &qid, const Parameters &parameters,
                                               unsigned *indices, std::vector<float>& res_dists) {
    uint32_t L_pq = parameters.Get<uint32_t>("L_pq");
    NeighborPriorityQueue search_queue(L_pq);
    std::vector<uint32_t> init_ids;
    init_ids.push_back(projection_ep_);
    prefetch_vector((char *)(data_bp_ + projection_ep_ * dimension_), dimension_);
    VisitedList *vl = visited_list_pool_->getFreeVisitedList();
    vl_type *visited_array = vl->mass;
    vl_type visited_array_tag = vl->curV;

    for (auto &id : init_ids) {
        float distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
        Neighbor nn = Neighbor(id, distance, false);
        search_queue.insert(nn);
    }
    uint32_t cmps = 0;
    uint32_t hops = 0;
    while (search_queue.has_unexpanded_node()) {
        auto cur_check_node = search_queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;

        const uint32_t *cur_nbrs = projection_graph_[cur_id].data();
        const size_t nbr_cnt = projection_graph_[cur_id].size();
        ++hops;

        for (size_t j = 0; j < nbr_cnt; ++j) {
            const uint32_t nbr = cur_nbrs[j];
            // Warm the cache for the next neighbor, but only if there is one.
            if (j + 1 < nbr_cnt) {
                const uint32_t next_nbr = cur_nbrs[j + 1];
                _mm_prefetch((char *)(visited_array + next_nbr), _MM_HINT_T0);
                _mm_prefetch((char *)(data_bp_ + next_nbr * dimension_), _MM_HINT_T0);
            }
            if (visited_array[nbr] != visited_array_tag) {
                visited_array[nbr] = visited_array_tag;

                float distance = distance_->compare(data_bp_ + nbr * dimension_, query, (unsigned)dimension_);
                ++cmps;
                search_queue.insert({nbr, distance, false});
            }
        }
    }
    visited_list_pool_->releaseVisitedList(vl);

    if (unlikely(search_queue.size() < k)) {
        std::stringstream ss;
        ss << "not enough results: " << search_queue.size() << ", expected: " << k;
        throw std::runtime_error(ss.str());
    }

    for (size_t i = 0; i < k; ++i) {
        indices[i] = search_queue[i].id;
        res_dists[i] = search_queue[i].distance;
    }
    return std::make_pair(cmps, hops);
}

// Best-first search over projection_graph_ that never throws on a short result.
//
// Seeds are projection_ep_ plus its id-adjacent points (ep + 1 and ep - 1) when
// those ids are inside [0, u32_nd_ - 1].
//
// query     - vector compared against base points via distance_.
// k         - number of result slots in `indices` / `res_dists`.
// qid       - not read.
// L_pq      - capacity of the search queue.
// indices   - output ids; the first min(k, queue size) come from the queue.
// res_dists - output scores: the queue distance multiplied by -1.
//
// If the queue holds fewer than k entries, a message is printed and the
// remaining slots are filled with uniformly random ids and a score of 1.
// Returns the number of distance computations made during expansion.
//
// Fixes: (1) the look-ahead prefetch no longer reads cur_nbrs[j + 1] past the
// end of the adjacency list on the last neighbor; (2) the result copy, which
// was performed twice with identical values on the short-result path, is now
// done once.
uint32_t IndexRetrAtten::SearchRAIndexPy(const float *query, size_t k, size_t &qid, uint32_t L_pq,
                                               uint32_t *indices, float* res_dists) {
    NeighborPriorityQueue search_queue(L_pq);
    std::vector<uint32_t> init_ids;
    init_ids.push_back(projection_ep_);
    if (likely(projection_ep_ < u32_nd_ - 1)) {
        init_ids.push_back(projection_ep_ + 1);
    }
    if (likely(projection_ep_ > 0)) {
        init_ids.push_back(projection_ep_ - 1);
    }
    prefetch_vector((char *)(data_bp_ + projection_ep_ * dimension_), dimension_);
    VisitedList *vl = visited_list_pool_->getFreeVisitedList();
    vl_type *visited_array = vl->mass;
    vl_type visited_array_tag = vl->curV;

    for (auto &id : init_ids) {
        float distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
        Neighbor nn = Neighbor(id, distance, false);
        search_queue.insert(nn);
    }
    uint32_t cmps = 0;

    while (search_queue.has_unexpanded_node()) {
        auto cur_check_node = search_queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;

        const uint32_t *cur_nbrs = projection_graph_[cur_id].data();
        const size_t nbr_cnt = projection_graph_[cur_id].size();
        for (size_t j = 0; j < nbr_cnt; ++j) {
            const uint32_t nbr = cur_nbrs[j];
            // Warm the cache for the next neighbor, but only if there is one.
            if (j + 1 < nbr_cnt) {
                const uint32_t next_nbr = cur_nbrs[j + 1];
                _mm_prefetch((char *)(visited_array + next_nbr), _MM_HINT_T0);
                _mm_prefetch((char *)(data_bp_ + next_nbr * dimension_), _MM_HINT_T0);
            }
            if (visited_array[nbr] != visited_array_tag) {
                visited_array[nbr] = visited_array_tag;

                float distance = distance_->compare(data_bp_ + nbr * dimension_, query, (unsigned)dimension_);
                ++cmps;
                search_queue.insert({nbr, distance, false});
            }
        }
    }
    visited_list_pool_->releaseVisitedList(vl);

    const size_t found = std::min(search_queue.size(), k);
    for (size_t i = 0; i < found; ++i) {
        indices[i] = search_queue[i].id;
        res_dists[i] = -1.0 * search_queue[i].distance;
    }

    if (unlikely(found < k)) {
        std::cout << "not enough results: " << search_queue.size() << ", expected: " << k << std::endl;
        // Pad the tail so callers always receive k entries.
        std::random_device rd;
        std::mt19937 gen(rd());
        std::uniform_int_distribution<uint32_t> dis(0, u32_nd_ - 1);
        for (size_t i = found; i < k; ++i) {
            indices[i] = dis(gen);
            res_dists[i] = 1;
        }
    }

    return cmps;
}

// Resumable best-first search over projection_graph_, driven by the state held
// in iterative_search_state (search queue, visited list, init flag).
//
// On the first call (init_search_done is false) a visited list is obtained from
// visited_list_pool_ and the queue is seeded with projection_ep_ and its
// id-adjacent points. Each call then expands nodes until either no unexpanded
// node remains, or none remains among the first min(end_k, queue capacity)
// queue positions.
//
// query     - vector compared against base points via distance_.
// end_k     - in: number of results requested; out: number actually written
//             (min of the request and the queue size).
// indices   - output ids, in queue order.
// res_dists - output scores: the queue distance multiplied by -1.
//
// Returns the number of distance computations made by this call.
// Throws std::runtime_error if the state's queue is null, or is non-empty on
// the first call.
//
// Fix: the look-ahead prefetch no longer reads cur_nbrs[j + 1] past the end of
// the adjacency list on the last neighbor. The unused hop counter and
// max-position tracker were removed.
uint32_t IndexRetrAtten::SearchRAIndexIterativelyPy(const float *query, size_t& end_k, unsigned *indices, float* res_dists) {
    NeighborPriorityQueue* search_queue = iterative_search_state->search_pool;
    if (search_queue == nullptr) {
        throw std::runtime_error("search queue is null");
    }
    if (!iterative_search_state->init_search_done) {
        iterative_search_state->get_visited_list(visited_list_pool_);
    }
    VisitedList *vl = iterative_search_state->vl;
    vl_type *visited_array = vl->mass;
    vl_type visited_array_tag = vl->curV;

    if (!iterative_search_state->init_search_done) {
        if (search_queue->size() != 0) {
            throw std::runtime_error("search queue is not empty");
        }

        std::vector<uint32_t> init_ids;

        init_ids.push_back(projection_ep_);
        if (likely(projection_ep_ < u32_nd_ - 1)) {
            init_ids.push_back(projection_ep_ + 1);
        }
        if (likely(projection_ep_ > 0)) {
            init_ids.push_back(projection_ep_ - 1);
        }
        prefetch_vector((char *)(data_bp_ + projection_ep_ * dimension_), dimension_);
        for (auto &id : init_ids) {
            float distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
            Neighbor nn = Neighbor(id, distance, false);
            search_queue->insert(nn);
        }
    }

    uint32_t cmps = 0;
    size_t closest_unexpanded_pos = 0;
    size_t return_k = std::min(end_k, search_queue->capacity());
    while (search_queue->has_unexpanded_node()) {
        auto cur_check_node = search_queue->closest_unexpanded_with_pos(closest_unexpanded_pos);
        auto cur_id = cur_check_node.id;

        const uint32_t *cur_nbrs = projection_graph_[cur_id].data();
        const size_t nbr_cnt = projection_graph_[cur_id].size();

        for (size_t j = 0; j < nbr_cnt; ++j) {
            const uint32_t nbr = cur_nbrs[j];
            // Warm the cache for the next neighbor, but only if there is one.
            if (j + 1 < nbr_cnt) {
                const uint32_t next_nbr = cur_nbrs[j + 1];
                _mm_prefetch((char *)(visited_array + next_nbr), _MM_HINT_T0);
                _mm_prefetch((char *)(data_bp_ + next_nbr * dimension_), _MM_HINT_T0);
            }
            if (visited_array[nbr] != visited_array_tag) {
                visited_array[nbr] = visited_array_tag;

                float distance = distance_->compare(data_bp_ + nbr * dimension_, query, (unsigned)dimension_);
                ++cmps;
                search_queue->insert({nbr, distance, false});
            }
        }
        // Set only after at least one node has been expanded.
        iterative_search_state->init_search_done = true;
        // Stop early once the first return_k queue positions are all expanded.
        if (!search_queue->has_unexpanded_node_before_n(return_k)) {
            break;
        }
    }
    return_k = std::min(end_k, search_queue->size());
    for (size_t i = 0; i < return_k; ++i) {
        indices[i] = search_queue->get_data()[i].id;
        res_dists[i] = -1.0 * search_queue->get_data()[i].distance;
    }
    end_k = return_k;
    return cmps;
}

// Best-first search over projection_graph_ that reports every point evaluated
// during expansion rather than only the top results.
//
// query             - vector compared against base points via distance_.
// k                 - minimum number of evaluated points required.
// qid               - not read.
// L_pq              - capacity of the search queue.
// all_visited       - receives the id of each point evaluated during expansion,
//                     in evaluation order (seed points are not recorded).
// all_visited_dists - receives the matching distances.
// NOTE(review): nothing here bounds the number of entries written to the two
// output buffers; the caller must size them for the worst case - confirm.
//
// Returns the number of distance computations made during expansion.
// Throws std::runtime_error if fewer than k points were recorded.
//
// Fix: the look-ahead prefetch no longer reads cur_nbrs[j + 1] past the end of
// the adjacency list on the last neighbor.
uint32_t IndexRetrAtten::SearchRAIndexPyReturnFullVisitedSet(const float *query, size_t k, size_t &qid, uint32_t L_pq,
                                               uint32_t *all_visited, float *all_visited_dists) {
    NeighborPriorityQueue search_queue(L_pq);
    std::vector<uint32_t> init_ids;
    init_ids.push_back(projection_ep_);

    uint32_t visited_idx = 0;
    if (likely(projection_ep_ < u32_nd_ - 1)) {
        init_ids.push_back(projection_ep_ + 1);
    }
    if (likely(projection_ep_ > 0)) {
        init_ids.push_back(projection_ep_ - 1);
    }
    prefetch_vector((char *)(data_bp_ + projection_ep_ * dimension_), dimension_);
    VisitedList *vl = visited_list_pool_->getFreeVisitedList();
    vl_type *visited_array = vl->mass;
    vl_type visited_array_tag = vl->curV;

    for (auto &id : init_ids) {
        float distance = distance_->compare(data_bp_ + id * dimension_, query, (unsigned)dimension_);
        Neighbor nn = Neighbor(id, distance, false);
        search_queue.insert(nn);
    }
    uint32_t cmps = 0;

    while (search_queue.has_unexpanded_node()) {
        auto cur_check_node = search_queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;

        const uint32_t *cur_nbrs = projection_graph_[cur_id].data();
        const size_t nbr_cnt = projection_graph_[cur_id].size();
        for (size_t j = 0; j < nbr_cnt; ++j) {
            const uint32_t nbr = cur_nbrs[j];
            // Warm the cache for the next neighbor, but only if there is one.
            if (j + 1 < nbr_cnt) {
                const uint32_t next_nbr = cur_nbrs[j + 1];
                _mm_prefetch((char *)(visited_array + next_nbr), _MM_HINT_T0);
                _mm_prefetch((char *)(data_bp_ + next_nbr * dimension_), _MM_HINT_T0);
            }
            if (visited_array[nbr] != visited_array_tag) {
                visited_array[nbr] = visited_array_tag;

                float distance = distance_->compare(data_bp_ + nbr * dimension_, query, (unsigned)dimension_);

                all_visited[visited_idx] = nbr;
                all_visited_dists[visited_idx] = distance;
                ++visited_idx;
                ++cmps;
                search_queue.insert({nbr, distance, false});
            }
        }
    }
    visited_list_pool_->releaseVisitedList(vl);

    if (visited_idx < k) {
        std::stringstream ss;
        ss << "not enough results: " << visited_idx << ", expected: " << k;
        throw std::runtime_error(ss.str());
    }
    return cmps;
}

// Best-first search over projection_graph_ that scores candidates on the
// scalar-quantised codes instead of the float vectors.
//
// Each candidate's queue distance is -1 * sqdc_->distance_to_code(code), where
// `code` starts at scalar_quant_->codes.data() + id * dimension_.
//
// query     - query vector handed to sqdc_->set_query().
// k         - number of results written.
// qid       - not read.
// L_pq      - capacity of the search queue.
// indices   - output ids, in queue order.
// res_dists - output scores: the queue distance multiplied by -1.
//
// Returns the number of code-distance computations made during expansion.
// Throws std::runtime_error if has_sq is false, or if the queue ends with fewer
// than k entries.
//
// Fix: the look-ahead prefetch no longer reads cur_nbrs[j + 1] past the end of
// the adjacency list on the last neighbor.
uint32_t IndexRetrAtten::SearchRAIndexPySQ(const float *query, size_t k, size_t &qid, uint32_t L_pq,
                                               unsigned *indices, float* res_dists) {
    if (unlikely(!has_sq)) {
        throw std::runtime_error("scalar quantizer is not set");
    }

    NeighborPriorityQueue search_queue(L_pq);
    std::vector<uint32_t> init_ids;
    init_ids.push_back(projection_ep_);
    if (likely(projection_ep_ < u32_nd_ - 1)) {
        init_ids.push_back(projection_ep_ + 1);
    }
    if (likely(projection_ep_ > 0)) {
        init_ids.push_back(projection_ep_ - 1);
    }

    sqdc_->set_query(query);

    prefetch_vector((char *)(scalar_quant_->codes.data() + projection_ep_ * dimension_), dimension_);
    VisitedList *vl = visited_list_pool_->getFreeVisitedList();
    vl_type *visited_array = vl->mass;
    vl_type visited_array_tag = vl->curV;

    for (auto &id : init_ids) {
        float distance = -1.0 * sqdc_->distance_to_code(scalar_quant_->codes.data() + id * dimension_);
        Neighbor nn = Neighbor(id, distance, false);
        search_queue.insert(nn);
    }
    uint32_t cmps = 0;

    while (search_queue.has_unexpanded_node()) {
        auto cur_check_node = search_queue.closest_unexpanded();
        auto cur_id = cur_check_node.id;

        const uint32_t *cur_nbrs = projection_graph_[cur_id].data();
        const size_t nbr_cnt = projection_graph_[cur_id].size();
        for (size_t j = 0; j < nbr_cnt; ++j) {
            const uint32_t nbr = cur_nbrs[j];
            // Warm the cache for the next neighbor, but only if there is one.
            if (j + 1 < nbr_cnt) {
                const uint32_t next_nbr = cur_nbrs[j + 1];
                _mm_prefetch((char *)(visited_array + next_nbr), _MM_HINT_T0);
                _mm_prefetch((char *)(scalar_quant_->codes.data() + next_nbr * dimension_), _MM_HINT_T0);
            }
            if (visited_array[nbr] != visited_array_tag) {
                visited_array[nbr] = visited_array_tag;
                float distance = -1.0 * sqdc_->distance_to_code(scalar_quant_->codes.data() + nbr * dimension_);
                ++cmps;
                search_queue.insert({nbr, distance, false});
            }
        }
    }
    visited_list_pool_->releaseVisitedList(vl);

    if (unlikely(search_queue.size() < k)) {
        std::stringstream ss;
        ss << "not enough results: " << search_queue.size() << ", expected: " << k;
        throw std::runtime_error(ss.str());
    }

    for (size_t i = 0; i < k; ++i) {
        indices[i] = search_queue[i].id;
        res_dists[i] = -1.0 * search_queue[i].distance;
    }

    return cmps;
}
// Attaches the lowest-numbered unflagged base point to the flagged part of
// projection_graph_.
//
// The unflagged point is used as a query for SearchProjectionGraphInternal. The
// first flagged entry of the sorted candidate pool becomes `root`; if the pool
// has no flagged entry, random ids are drawn until a flagged one is found.
// An edge root -> unflagged point is then appended to the graph.
//
// Returns without touching `root` or the graph when every id in [0, nd_) is
// already flagged.
void IndexRetrAtten::findroot(boost::dynamic_bitset<> &flag, unsigned &root, const Parameters &parameters) {
    unsigned orphan = nd_;
    for (unsigned v = 0; v < nd_; ++v) {
        if (!flag[v]) {
            orphan = v;
            break;
        }
    }
    if (orphan == nd_) {
        return;
    }

    std::vector<Neighbor> candidates;
    NeighborPriorityQueue scratch_queue;
    SearchProjectionGraphInternal(scratch_queue, data_bp_ + dimension_ * orphan, orphan, parameters, flag,
                                  candidates);
    std::sort(candidates.begin(), candidates.end());

    const auto linked = std::find_if(candidates.begin(), candidates.end(),
                                     [&flag](const Neighbor &c) { return flag.test(c.id); });
    if (linked != candidates.end()) {
        root = linked->id;
    } else {
        // No flagged candidate came back from the search: pick any flagged point at random.
        for (;;) {
            const unsigned rid = rand() % nd_;
            if (flag[rid]) {
                root = rid;
                break;
            }
        }
    }
    projection_graph_[root].push_back(orphan);
}

// Iterative depth-first traversal of projection_graph_ from `root`.
//
// Every node reached for the first time is marked in `flag` and adds one to
// `cnt`. `root` itself is counted only if it was not flagged on entry.
// Neighbors are explored in adjacency-list order.
void IndexRetrAtten::dfs(boost::dynamic_bitset<> &flag, unsigned root, unsigned &cnt) {
    std::stack<unsigned> path;
    path.push(root);
    if (!flag[root]) {
        ++cnt;
    }
    flag[root] = true;

    while (!path.empty()) {
        // The node being explored is always the top of the stack.
        const unsigned cur = path.top();
        const auto &nbrs = projection_graph_[cur];
        const auto unseen =
            std::find_if(nbrs.begin(), nbrs.end(), [&flag](unsigned v) { return !flag.test(v); });

        if (unseen == nbrs.end()) {
            // Nothing new below this node: backtrack.
            path.pop();
            continue;
        }

        const unsigned next = *unseen;
        flag[next] = true;
        path.push(next);
        ++cnt;
    }
}

void IndexRetrAtten::CollectPoints(const Parameters &parameters) {
    unsigned root = projection_ep_;
    boost::dynamic_bitset<> flags{nd_, 0};
    unsigned unlinked_cnt = 0;
    while (unlinked_cnt < nd_) {
        dfs(flags, root, unlinked_cnt);
        
        if (unlinked_cnt >= nd_) break;
        findroot(flags, root, parameters);
    }
}

// Writes the base vectors to `filename` in binary form:
//   uint32 point count (u32_nd_), uint32 dimension, then nd_ * dimension floats.
// Throws std::runtime_error("cannot open file") if the file cannot be opened.
void IndexRetrAtten::SaveBaseData(const char *filename) {
    std::ofstream writer(filename, std::ios::binary | std::ios::out);
    if (!writer.is_open()) {
        throw std::runtime_error("cannot open file");
    }

    uint32_t dim = dimension_;
    const auto payload_bytes = sizeof(float) * dim * nd_;

    writer.write(reinterpret_cast<const char *>(&u32_nd_), sizeof(uint32_t));
    writer.write(reinterpret_cast<const char *>(&dim), sizeof(uint32_t));
    writer.write(reinterpret_cast<const char *>(data_bp_), payload_bytes);
    writer.close();
}

// Loads base vectors from the layout written by SaveBaseData():
//   uint32 point count, uint32 dimension, then count * dimension floats.
// On success u32_nd_ and nd_ hold the stored point count and data_bp_ points at
// a newly allocated buffer holding the vectors.
//
// Throws std::runtime_error if the file cannot be opened or is shorter than
// its header claims.
//
// Fix: the buffer used to be sized with nd_ as it was *before* the call, while
// only u32_nd_ was updated from the file, so the allocation and read length
// did not match the stored point count. nd_ is now synchronised with the header
// before the buffer is sized.
// NOTE(review): the stored dimension is not compared against dimension_, and
// any buffer previously referenced by data_bp_ is not released here - confirm
// both against the callers.
void IndexRetrAtten::LoadBaseData(const char *filename) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    uint32_t dim = 0;
    in.read((char *)&u32_nd_, sizeof(uint32_t));
    in.read((char *)&dim, sizeof(uint32_t));
    if (!in) {
        throw std::runtime_error("base data file header is truncated");
    }
    nd_ = u32_nd_;

    const size_t n_floats = static_cast<size_t>(dim) * nd_;
    float *buffer = new float[n_floats];
    if (!in.read((char *)buffer, sizeof(float) * n_floats)) {
        delete[] buffer;
        throw std::runtime_error("base data file is truncated");
    }
    data_bp_ = buffer;

    in.close();
}

// Writes the projection graph to `filename` in binary form:
//   uint32 entry point, uint32 node count (u32_nd_), then for each node:
//   uint32 degree followed by `degree` uint32 neighbor ids.
// Throws std::runtime_error("cannot open file") if the file cannot be opened.
void IndexRetrAtten::SaveProjectionGraph(const char *filename) {
    std::ofstream writer(filename, std::ios::binary | std::ios::out);
    if (!writer.is_open()) {
        throw std::runtime_error("cannot open file");
    }

    writer.write(reinterpret_cast<const char *>(&projection_ep_), sizeof(uint32_t));
    writer.write(reinterpret_cast<const char *>(&u32_nd_), sizeof(uint32_t));

    for (uint32_t node = 0; node < u32_nd_; ++node) {
        const auto &nbrs = projection_graph_[node];
        const uint32_t degree = nbrs.size();
        writer.write(reinterpret_cast<const char *>(&degree), sizeof(uint32_t));
        writer.write(reinterpret_cast<const char *>(nbrs.data()), sizeof(uint32_t) * degree);
    }
    writer.close();
}

void IndexRetrAtten::SaveLayerQIndex(const char *filename) {
    std::ofstream out(filename, std::ios::binary | std::ios::out);
    if (!out.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    out.write((char *)&projection_ep_, sizeof(uint32_t));
    out.write((char *)&u32_nd_, sizeof(uint32_t));
    for (uint32_t i = 0; i < u32_nd_; ++i) {
        uint32_t nbr_size = projection_graph_[i].size();
        out.write((char *)&nbr_size, sizeof(uint32_t));
        out.write((char *)projection_graph_[i].data(), sizeof(uint32_t) * nbr_size);
    }
    out.close();

    std::ofstream data_out(std::string(filename) + ".data", std::ios::binary | std::ios::out);
    if (!data_out.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    data_out.write((char *)data_bp_, sizeof(float) * dimension_ * u32_nd_);
    data_out.close();
}

// Restores the projection graph from `filename` and adopts `data_bp` as the
// base-vector buffer.
//
// File layout: uint32 entry point, uint32 node count, then for each node a
// uint32 degree followed by `degree` uint32 neighbor ids. All point counters
// (u32_nd_, nd_, total_pts_, u32_total_pts_) are set to the stored node count.
// The `data_bp` pointer is stored as-is; nothing is copied.
// Throws std::runtime_error("cannot open file") if the file cannot be opened.
void IndexRetrAtten::LoadLayerQIndex(const char *filename, const float *data_bp) {
    std::ifstream reader(filename, std::ios::binary);
    if (!reader.is_open()) {
        throw std::runtime_error("cannot open file");
    }

    reader.read(reinterpret_cast<char *>(&projection_ep_), sizeof(uint32_t));
    reader.read(reinterpret_cast<char *>(&u32_nd_), sizeof(uint32_t));
    nd_ = u32_nd_;
    total_pts_ = u32_nd_;
    u32_total_pts_ = u32_nd_;

    projection_graph_.resize(u32_nd_);
    for (auto &nbrs : projection_graph_) {
        uint32_t degree;
        reader.read(reinterpret_cast<char *>(&degree), sizeof(uint32_t));
        nbrs.resize(degree);
        reader.read(reinterpret_cast<char *>(nbrs.data()), sizeof(uint32_t) * degree);
    }
    reader.close();

    data_bp_ = data_bp;
}
// Loads learn_base_knn_ from a binary file laid out as:
//   uint32 npts, uint32 k_dim, then npts rows of k_dim uint32 ids.
//
// Throws std::runtime_error if the file cannot be opened
// ("cannot open file") or is shorter than its header claims
// ("learn base knn file error").
//
// Fix: the old validity check compared the last row's size with k_dim right
// after resizing it to k_dim, so it could never fail, and it called back() on
// an empty table when npts was 0. The stream state is now checked instead,
// which is what actually detects a truncated file.
void IndexRetrAtten::LoadLearnBaseKNN(const char *filename) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    uint32_t npts = 0;
    uint32_t k_dim = 0;
    in.read((char *)&npts, sizeof(npts));
    in.read((char *)&(k_dim), sizeof(k_dim));
    if (!in) {
        throw std::runtime_error("learn base knn file error");
    }
    std::cout << "learn base knn npts: " << npts << ", k_dim: " << k_dim << std::endl;

    learn_base_knn_.resize(npts);
    for (uint32_t i = 0; i < npts; i++) {
        learn_base_knn_[i].resize(k_dim);
        if (!in.read((char *)learn_base_knn_[i].data(), sizeof(uint32_t) * k_dim)) {
            throw std::runtime_error("learn base knn file error");
        }
    }
    in.close();
}

// Fills learn_base_knn_ from a row-major npts x k_dim array of int64 ids,
// narrowing each id to uint32. Rows are converted in parallel.
//
// Fix: the flat offset was computed as `i * k_dim + j` in 32-bit arithmetic,
// which wraps once npts * k_dim exceeds 2^32 and then reads the wrong elements.
// The row offset is now computed in size_t.
void IndexRetrAtten::SetLearnBaseKNNi64(const int64_t* learn_base_knn, uint32_t npts, uint32_t k_dim) {
    learn_base_knn_.resize(npts);
#pragma omp parallel for
    for (uint32_t i = 0; i < npts; i++) {
        learn_base_knn_[i].resize(k_dim);

        const int64_t *row = learn_base_knn + static_cast<size_t>(i) * k_dim;
        for (uint32_t j = 0; j < k_dim; j++) {
            learn_base_knn_[i][j] = static_cast<uint32_t>(row[j]);
        }
    }
}

// Copies row `i` of learn_base_knn_ into `learn_nn`.
// `learn_nn` must have room for learn_base_knn_[i].size() entries.
void IndexRetrAtten::getithLearnNN(uint32_t i, uint32_t *learn_nn) {
    const auto &row = learn_base_knn_[i];
    std::copy(row.begin(), row.end(), learn_nn);
}

// Fills learn_base_knn_ from a row-major npts x k_dim array of uint32 ids.
// Rows are copied in parallel.
//
// Fix: the row offset `i * k_dim` was computed in 32-bit arithmetic, which
// wraps once npts * k_dim exceeds 2^32 and then copies from the wrong place.
// The offset is now computed in size_t.
void IndexRetrAtten::SetLearnBaseKNN(const uint32_t* learn_base_knn, uint32_t npts, uint32_t k_dim) {
    learn_base_knn_.resize(npts);
#pragma omp parallel for
    for (uint32_t i = 0; i < npts; i++) {
        learn_base_knn_[i].resize(k_dim);
        const uint32_t *row = learn_base_knn + static_cast<size_t>(i) * k_dim;
        memcpy(learn_base_knn_[i].data(), row, sizeof(uint32_t) * k_dim);
    }
}

// Loads base_learn_knn_ from a binary file laid out as:
//   uint32 npts, uint32 k_dim, then npts rows of k_dim uint32 ids.
//
// Throws std::runtime_error if the file cannot be opened
// ("cannot open file") or is shorter than its header claims
// ("base learn knn file error").
//
// Fix: the old validity check compared the last row's size with k_dim right
// after resizing it to k_dim, so it could never fail, and it called back() on
// an empty table when npts was 0. The stream state is now checked instead,
// which is what actually detects a truncated file.
void IndexRetrAtten::LoadBaseLearnKNN(const char *filename) {
    std::ifstream in(filename, std::ios::binary);
    if (!in.is_open()) {
        throw std::runtime_error("cannot open file");
    }
    uint32_t npts = 0;
    uint32_t k_dim = 0;
    in.read((char *)&npts, sizeof(npts));
    in.read((char *)&(k_dim), sizeof(k_dim));
    if (!in) {
        throw std::runtime_error("base learn knn file error");
    }
    std::cout << "base learn knn npts: " << npts << ", k_dim: " << k_dim << std::endl;

    base_learn_knn_.resize(npts);
    for (uint32_t i = 0; i < npts; i++) {
        base_learn_knn_[i].resize(k_dim);
        if (!in.read((char *)base_learn_knn_[i].data(), sizeof(uint32_t) * k_dim)) {
            throw std::runtime_error("base learn knn file error");
        }
    }
    in.close();
}
// Loads the base vectors from `base_file` and records the point counts.
//
// Only the metadata (count, dimension) of `sampled_query_file` is read, and
// only when the path is non-empty; its vectors are not loaded here. When
// need_normalize is set, each base vector is normalised before the buffer is
// passed through data_align() and stored in data_bp_.
//
// Sets nd_, nd_sq_, total_pts_ and their uint32 mirrors.
// Throws std::runtime_error("base and query dimension mismatch") when the two
// files report different dimensions.
void IndexRetrAtten::LoadVectorData(const char *base_file, const char *sampled_query_file) {
    uint32_t base_num = 0;
    uint32_t base_dim = 0;
    uint32_t sq_num = 0;
    uint32_t q_dim = 0;

    load_meta<float>(base_file, base_num, base_dim);
    const bool has_sampled_queries = strlen(sampled_query_file) != 0;
    if (has_sampled_queries) {
        load_meta<float>(sampled_query_file, sq_num, q_dim);
        if (base_dim != q_dim) {
            throw std::runtime_error("base and query dimension mismatch");
        }
    }

    float *base_data = nullptr;
    load_data<float>(base_file, base_num, base_dim, base_data);
    if (need_normalize) {
        std::cout << "Normalizing base data" << std::endl;
        for (size_t row = 0; row < base_num; ++row) {
            normalize<float>(base_data + row * (uint64_t)base_dim, (uint64_t)base_dim);
        }
    }
    data_bp_ = data_align(base_data, base_num, base_dim);

    nd_ = base_num;
    nd_sq_ = sq_num;
    total_pts_ = nd_ + nd_sq_;

    u32_nd_ = static_cast<uint32_t>(nd_);
    u32_nd_sq_ = static_cast<uint32_t>(nd_sq_);
    u32_total_pts_ = static_cast<uint32_t>(total_pts_);
}
void IndexRetrAtten::SimulateRAIndexBatchInsert(const Parameters &parameters, uint32_t& wait_to_add_num, const float *data, std::vector<uint32_t>& closest_q_ids) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    size_t exist_data_num = projection_graph_.size();
    total_pts_ = exist_data_num + wait_to_add_num; 
    u32_total_pts_ = total_pts_;
    u32_nd_ = exist_data_num;
    nd_ = (size_t)exist_data_num;
    locks_ = std::vector<std::mutex>(total_pts_);
    uint32_t begin_no = exist_data_num;

    float * combined_data = new float[(exist_data_num + wait_to_add_num) * dimension_];

    memcpy(combined_data, data_bp_, exist_data_num * dimension_ * sizeof(float));

    memcpy(combined_data + exist_data_num * dimension_, data, wait_to_add_num * dimension_ * sizeof(float));

    delete[] data_bp_;
    data_bp_ = combined_data;
    projection_graph_.reserve(total_pts_);
    projection_graph_.resize(total_pts_);

    for (size_t i = 0; i < total_pts_; ++i) {
        
        projection_graph_[i].reserve(M_pjbp * 2 * PROJECTION_SLACK);
    }

#pragma omp parallel for schedule(dynamic, 100)
    for(size_t i = 0; i < wait_to_add_num; ++i) {
        uint32_t cur_insert_pt_no = begin_no + i;
        uint32_t correspond_exist_query = closest_q_ids[i];
        auto correspond_base = learn_base_knn_[correspond_exist_query];
        std::remove(correspond_base.begin(), correspond_base.end(), cur_insert_pt_no);
        std::vector<Neighbor> wait_to_projection_pool;
        for (size_t j = 0; j < correspond_base.size(); ++j) {
            float distance = distance_->compare(data_bp_ + correspond_base[j] * (size_t) dimension_, data + (size_t) dimension_ * i, (unsigned) dimension_);
            wait_to_projection_pool.push_back(Neighbor(correspond_base[j], distance, false));
        }
        std::vector<uint32_t> pruned_list;
        std::sort(wait_to_projection_pool.begin(), wait_to_projection_pool.end());
        pruned_list.reserve(M_pjbp * PROJECTION_SLACK);
        PruneBiSearchBaseGetBase(wait_to_projection_pool, data + dimension_ * i, cur_insert_pt_no, parameters, pruned_list);
        {
            LockGuard guard(locks_[cur_insert_pt_no]);
            projection_graph_[cur_insert_pt_no] = pruned_list;
        }
        ProjectionAddReverse(cur_insert_pt_no, parameters);
    }

    uint32_t self_loop_count = 0;
    uint32_t insert_zero_degree = 0;
    float inserted_avg_degree = 0.0;
    float inserted_max_degree = 0.0;
    float inserted_min_degree = 0.0;

    for (size_t i = 0; i < projection_graph_.size(); ++i) {
        if (std::find(projection_graph_[i].begin(), projection_graph_[i].end(), i) != projection_graph_[i].end()) {
            
            self_loop_count++;
            std::remove(projection_graph_[i].begin(), projection_graph_[i].end(), i);
        }
        if (i >= begin_no) {
            if (projection_graph_[i].size() == 0) {
                insert_zero_degree++;
                std::vector<uint32_t> zero_degree_nbrs;
                std::random_device rd;
                std::mt19937 gen(rd());
                std::uniform_int_distribution<uint32_t> dis(0, total_pts_ - 1);
                for (size_t j = 0; j < 10; ++j) {
                    uint32_t random_nbr = dis(gen);
                    zero_degree_nbrs.push_back(random_nbr);
                }
            }
        }
    }
#pragma omp parallel for schedule(static, 100)
    for (uint32_t i = begin_no; i < total_pts_; ++i) {
        ProjectionAddReverse(i, parameters);
    }
    supply_nbrs_.resize(projection_graph_.size());

    for (size_t i = 0; i < projection_graph_.size(); ++i) {
        
        supply_nbrs_[i] = projection_graph_[i];
        supply_nbrs_[i].reserve(M_pjbp * 2 * PROJECTION_SLACK);
        
    }
    uint32_t L_pjpq = parameters.Get<uint32_t>("L_pjpq");
#pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t i = begin_no; i < total_pts_; ++i) {
        size_t node = i;
        boost::dynamic_bitset<> visited{total_pts_, false};
        std::vector<Neighbor> full_retset;
        full_retset.reserve(L_pjpq);
        NeighborPriorityQueue search_pool;
        SearchProjectionGraphInternal(search_pool, data_bp_ + dimension_ * node, node, parameters, visited,
                                      full_retset);
        std::vector<uint32_t> pruned_list;
        pruned_list.reserve(M_pjbp * PROJECTION_SLACK);
        for (unsigned j = 0; j < full_retset.size(); j++) {
            if (full_retset[j].id == (unsigned)node) {
                full_retset.erase(full_retset.begin() + j);
                j--;
            }
        }
        PruneProjectionBaseSearchCandidates(full_retset, data_bp_ + dimension_ * node, node, parameters, pruned_list);
        {
            LockGuard guard(locks_[node]);

            supply_nbrs_[node] = pruned_list;
        }
        SupplyAddReverse(node, parameters);
    }
#pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t i = begin_no; i < total_pts_; ++i) {
        size_t node = i;
        if (supply_nbrs_[node].size() > M_pjbp) {
            std::vector<Neighbor> full_retset;
            tsl::robin_set<uint32_t> visited;
            for (size_t j = 0; j < supply_nbrs_[node].size(); ++j) {
                if (visited.find(supply_nbrs_[node][j]) != visited.end()) {
                    continue;
                }
                float distance = distance_->compare(data_bp_ + dimension_ * supply_nbrs_[node][j],
                                                    data_bp_ + dimension_ * node, dimension_);
                visited.insert(supply_nbrs_[node][j]);
                full_retset.push_back(Neighbor(supply_nbrs_[node][j], distance, false));
            }
            std::vector<uint32_t> prune_list;
            PruneProjectionBaseSearchCandidates(full_retset, data_bp_ + dimension_ * node, node, parameters,
                                                prune_list);
            {
                LockGuard guard(locks_[node]);
                supply_nbrs_[node].clear();
                supply_nbrs_[node] = prune_list;
            }
        }
    }
#pragma omp parallel for schedule(dynamic, 100)
    for (size_t i = 0; i < total_pts_; ++i) {
        std::vector<uint32_t> ok_insert;
        ok_insert.reserve(M_pjbp);
        for (size_t j = 0; j < supply_nbrs_[i].size(); ++j) {
            if (ok_insert.size() >= M_pjbp) {
                break;
            }
            if (std::find(projection_graph_[i].begin(), projection_graph_[i].end(), supply_nbrs_[i][j]) ==
                projection_graph_[i].end()) {
                ok_insert.push_back(supply_nbrs_[i][j]);
            }
        }
        projection_graph_[i].insert(projection_graph_[i].end(), ok_insert.begin(), ok_insert.end());
    }
    u32_nd_ = u32_nd_ + wait_to_add_num;
    std::cout << "u32_nd_: " << u32_nd_ << std::endl;
}
void IndexRetrAtten::SimulateRAIndexInsertOneKey(const Parameters &parameters, uint32_t& wait_to_add_num, const float *data, std::vector<uint32_t>& closest_q_ids) {
    uint32_t M_pjbp = parameters.Get<uint32_t>("M_pjbp");
    size_t exist_data_num = projection_graph_.size();
    total_pts_ = exist_data_num + wait_to_add_num; 
    u32_total_pts_ = total_pts_;
    u32_nd_ = exist_data_num;
    nd_ = (size_t)exist_data_num;
    locks_ = std::vector<std::mutex>(total_pts_);
    uint32_t begin_no = exist_data_num;

    float * combined_data = new float[(exist_data_num + wait_to_add_num) * dimension_];

    memcpy(combined_data, data_bp_, exist_data_num * dimension_ * sizeof(float));

    memcpy(combined_data + exist_data_num * dimension_, data, wait_to_add_num * dimension_ * sizeof(float));

    delete[] data_bp_;
    data_bp_ = combined_data;
    projection_graph_.reserve(total_pts_);
    projection_graph_.resize(total_pts_);

    for (size_t i = 0; i < total_pts_; ++i) {
        
        projection_graph_[i].reserve(M_pjbp * 2 * PROJECTION_SLACK);
    }

#pragma omp parallel for schedule(dynamic, 100)
    for(size_t i = 0; i < wait_to_add_num; ++i) {
        uint32_t cur_insert_pt_no = begin_no + i;
        uint32_t correspond_exist_query = closest_q_ids[i];
        auto correspond_base = learn_base_knn_[correspond_exist_query];
        std::remove(correspond_base.begin(), correspond_base.end(), cur_insert_pt_no);
        std::vector<Neighbor> wait_to_projection_pool;
        for (size_t j = 0; j < correspond_base.size(); ++j) {
            float distance = distance_->compare(data_bp_ + correspond_base[j] * (size_t) dimension_, data + (size_t) dimension_ * i, (unsigned) dimension_);
            wait_to_projection_pool.push_back(Neighbor(correspond_base[j], distance, false));
        }
        std::vector<uint32_t> pruned_list;
        std::sort(wait_to_projection_pool.begin(), wait_to_projection_pool.end());
        pruned_list.reserve(M_pjbp * PROJECTION_SLACK);
        PruneBiSearchBaseGetBase(wait_to_projection_pool, data + dimension_ * i, cur_insert_pt_no, parameters, pruned_list);
        {
            LockGuard guard(locks_[cur_insert_pt_no]);
            projection_graph_[cur_insert_pt_no] = pruned_list;
        }
        ProjectionAddReverse(cur_insert_pt_no, parameters);
        if (i % 1000 == 0) {
        }
    }

    uint32_t self_loop_count = 0;
    uint32_t insert_zero_degree = 0;
    float inserted_avg_degree = 0.0;
    float inserted_max_degree = 0.0;
    float inserted_min_degree = 0.0;

    for (size_t i = 0; i < projection_graph_.size(); ++i) {
        if (std::find(projection_graph_[i].begin(), projection_graph_[i].end(), i) != projection_graph_[i].end()) {
            
            self_loop_count++;
            std::remove(projection_graph_[i].begin(), projection_graph_[i].end(), i);
        }
        if (i >= begin_no) {
            if (projection_graph_[i].size() == 0) {
                insert_zero_degree++;
                std::vector<uint32_t> zero_degree_nbrs;
                std::random_device rd;
                std::mt19937 gen(rd());
                std::uniform_int_distribution<uint32_t> dis(0, total_pts_ - 1);
                for (size_t j = 0; j < 10; ++j) {
                    uint32_t random_nbr = dis(gen);
                    zero_degree_nbrs.push_back(random_nbr);
                }
            }
        }
    }
#pragma omp parallel for schedule(static, 100)
    for (uint32_t i = begin_no; i < total_pts_; ++i) {
        ProjectionAddReverse(i, parameters);
    }
    supply_nbrs_.resize(projection_graph_.size());

    for (size_t i = 0; i < projection_graph_.size(); ++i) {
        
        supply_nbrs_[i] = projection_graph_[i];
        supply_nbrs_[i].reserve(M_pjbp * 2 * PROJECTION_SLACK);
        
    }
 std::chrono::high_resolution_clock::time_point t1, t2;
    t1 = std::chrono::high_resolution_clock::now();
    uint32_t L_pjpq = parameters.Get<uint32_t>("L_pjpq");
#pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t i = begin_no; i < total_pts_; ++i) {
        size_t node = i;
        boost::dynamic_bitset<> visited{total_pts_, false};
        std::vector<Neighbor> full_retset;
        full_retset.reserve(L_pjpq);
        NeighborPriorityQueue search_pool;
        SearchProjectionGraphInternal(search_pool, data_bp_ + dimension_ * node, node, parameters, visited,
                                      full_retset);
        std::vector<uint32_t> pruned_list;
        pruned_list.reserve(M_pjbp * PROJECTION_SLACK);
        for (unsigned j = 0; j < full_retset.size(); j++) {
            if (full_retset[j].id == (unsigned)node) {
                full_retset.erase(full_retset.begin() + j);
                j--;
            }
        }
        PruneProjectionBaseSearchCandidates(full_retset, data_bp_ + dimension_ * node, node, parameters, pruned_list);
        {
            LockGuard guard(locks_[node]);

            supply_nbrs_[node] = pruned_list;
        }
        SupplyAddReverse(node, parameters);
        if (node % 1000 == 0) {
        }
    }
#pragma omp parallel for schedule(dynamic, 100)
    for (uint32_t i = begin_no; i < total_pts_; ++i) {
        size_t node = i;
        if (supply_nbrs_[node].size() > M_pjbp) {
            std::vector<Neighbor> full_retset;
            tsl::robin_set<uint32_t> visited;
            for (size_t j = 0; j < supply_nbrs_[node].size(); ++j) {
                if (visited.find(supply_nbrs_[node][j]) != visited.end()) {
                    continue;
                }
                float distance = distance_->compare(data_bp_ + dimension_ * supply_nbrs_[node][j],
                                                    data_bp_ + dimension_ * node, dimension_);
                visited.insert(supply_nbrs_[node][j]);
                full_retset.push_back(Neighbor(supply_nbrs_[node][j], distance, false));
            }
            std::vector<uint32_t> prune_list;
            PruneProjectionBaseSearchCandidates(full_retset, data_bp_ + dimension_ * node, node, parameters,
                                                prune_list);
            {
                LockGuard guard(locks_[node]);
                supply_nbrs_[node].clear();
                supply_nbrs_[node] = prune_list;
            }
        }
    }
#pragma omp parallel for schedule(dynamic, 100)
    for (size_t i = 0; i < total_pts_; ++i) {
        std::vector<uint32_t> ok_insert;
        ok_insert.reserve(M_pjbp);
        for (size_t j = 0; j < supply_nbrs_[i].size(); ++j) {
            if (ok_insert.size() >= M_pjbp) {
                break;
            }
            if (std::find(projection_graph_[i].begin(), projection_graph_[i].end(), supply_nbrs_[i][j]) ==
                projection_graph_[i].end()) {
                ok_insert.push_back(supply_nbrs_[i][j]);
            }
        }
        projection_graph_[i].insert(projection_graph_[i].end(), ok_insert.begin(), ok_insert.end());
    }

    t2 = std::chrono::high_resolution_clock::now();
    auto connectivity_enhancement_time = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
    for (size_t i = begin_no; i < projection_graph_.size(); ++i) {
        inserted_avg_degree += projection_graph_[i].size();
        if (projection_graph_[i].size() > inserted_max_degree) {
            inserted_max_degree = projection_graph_[i].size();
        }
        if (projection_graph_[i].size() < inserted_min_degree) {
            inserted_min_degree = projection_graph_[i].size();
        }
    }

    inserted_avg_degree /= wait_to_add_num;
    u32_nd_ = u32_nd_ + wait_to_add_num;
}
}  