#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <set>
#include <shared_mutex>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/container/set.hpp>
#include <boost/dynamic_bitset.hpp>
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/impl/ScalarQuantizer.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/Quantizer.h>

#include "efanna2e/index.h"
#include "efanna2e/neighbor.h"
#include "efanna2e/parameters.h"
#include "efanna2e/util.h"
#include "visited_list_pool.h"

// Branch-prediction hints: likely(x) / unlikely(x) evaluate to the truth value of x
// (normalized with !!) while telling the optimizer which outcome to expect.
// __builtin_expect is a GCC/Clang builtin; other compilers get a plain
// truth-value fallback so the header still compiles there.
#if defined(__GNUC__) || defined(__clang__)
#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)
#else
#define likely(x) (!!(x))
#define unlikely(x) (!!(x))
#endif

namespace efanna2e {
// Scoped lock over a plain std::mutex.
using LockGuard = std::lock_guard<std::mutex>;
// NOTE(review): despite the name, std::lock_guard acquires the std::shared_mutex through
// lock(), i.e. exclusively. Shared (reader) ownership would require std::shared_lock;
// confirm what the call sites expect before changing this alias.
using SharedLockGuard = std::lock_guard<std::shared_mutex>;


struct IterativeSearchState {
    NeighborPriorityQueue* search_pool = nullptr;
    VisitedList* vl = nullptr;
    size_t max_search_queue_capacity = 0;
    size_t max_search_cmps_budgets = 0;
    size_t check_interval_hops = 0;
    size_t init_search_l = 0;
    bool init_search_done = false;
    bool is_phase_2 = false;
    bool accumulated_cmps = 0;

    IterativeSearchState() = default;
    IterativeSearchState(size_t max_search_queue_capacity, size_t max_search_cmps_budgets, size_t check_interval_hops, size_t init_search_l):
        max_search_queue_capacity(max_search_queue_capacity), max_search_cmps_budgets(max_search_cmps_budgets), check_interval_hops(check_interval_hops), init_search_l(init_search_l) {
        if (search_pool == nullptr) {
            search_pool = new NeighborPriorityQueue(max_search_queue_capacity);
        }
    }

    ~IterativeSearchState() {
        if (search_pool) {
            delete search_pool;
            search_pool = nullptr;
        }
    }

    void get_visited_list(VisitedListPool* visited_list_pool) {
        vl = visited_list_pool->getFreeVisitedList();
    }

    void release_visited_list(VisitedListPool* visited_list_pool) {
        visited_list_pool->releaseVisitedList(vl);
        vl = nullptr;
    }


    void reset(VisitedListPool* visited_list_pool) {
        search_pool->clear();
        init_search_done = false;
        is_phase_2 = false;
        accumulated_cmps = 0;
        release_visited_list(visited_list_pool);
    }
};

/**
 * Index subclass declaring the build, link, prune, search and (de)serialization routines
 * of the "RetrAtten" graph index. Only the small inline helpers are defined here; every
 * other member function is defined out of line.
 */
class IndexRetrAtten : public Index {
    using CompactGraph = std::vector<std::vector<uint32_t>>;

   public:
    explicit IndexRetrAtten(const size_t dimension, const size_t n, Metric m, Index *initializer);
    virtual ~IndexRetrAtten();

    // ---- Index interface overrides ----
    virtual void Save(const char *filename) override;
    virtual void Load(const char *filename) override;
    virtual void Search(const float *query, const float *x, size_t k, const Parameters &parameters,
                        unsigned *indices, float *res_dists) override;

    // ---- Build entry points ----
    void BuildRAIndex(size_t n_sq, const float *sq_data, size_t n_bp, const float *bp_data,
                                const Parameters &parameters);
    void BuildRAIndexwithData(size_t n_sq, const float *sq_data, size_t n_bp, const float *bp_data,
                                const Parameters &parameters);
    void BuildRAIndexwithDataSQ(size_t n_sq, const float *sq_data, size_t n_bp, faiss::IndexScalarQuantizer *sq, const float *bp_data,
                                const Parameters &parameters);
    void SetSQforRAIndex(size_t n_sq, const float *sq_data, size_t n_bp, faiss::IndexScalarQuantizer *sq,
                                const Parameters &parameters);

    void BuildRAIndexwithDatanoConn(size_t n_sq, const float *sq_data, size_t n_bp, const float *bp_data,
                                const Parameters &parameters);
    virtual void Build(size_t n, const float *data, const Parameters &parameters) override;

    /// Intentionally empty: the parameters are accepted and ignored.
    inline void SetRetrAttenParameters(const Parameters &parameters) {}
    uint32_t SearchRetrAttenGraph(const float *query, size_t k, size_t &qid, const Parameters &parameters,
                                  unsigned *indices);
    void LinkRetrAtten(const Parameters &parameters, SimpleNeighbor *simple_graph);
    void LinkOneNode(const Parameters &parameters, uint32_t nid, SimpleNeighbor *simple_graph, bool is_base,
                     boost::dynamic_bitset<> &visited);

    void SearchRetrAttenbyBase(const float *query, uint32_t nid, const Parameters &parameters,
                               SimpleNeighbor *simple_graph, NeighborPriorityQueue &search_pool,
                               boost::dynamic_bitset<> &visited, std::vector<Neighbor> &full_retset);

    void SearchRetrAttenbyQuery(const float *query, uint32_t nid, const Parameters &parameters,
                                SimpleNeighbor *simple_graph, NeighborPriorityQueue &search_pool,
                                boost::dynamic_bitset<> &visited, std::vector<Neighbor> &full_retset);

    void LoadVectorData(const char *base_file, const char *sampled_query_file);

    /// Mutable access to the bipartite adjacency lists.
    CompactGraph &GetRetrAttenGraph() { return bipartite_graph_; }
    /// Resizes the bipartite graph to hold total_pts_ adjacency lists.
    inline void InitRetrAttenGraph() { bipartite_graph_.resize(total_pts_); }

    /// Thin forwarder to LoadVectorData().
    inline void LoadSearchNeededData(const char *base_file, const char *sampled_query_file) {
        LoadVectorData(base_file, sampled_query_file);
    }

    void PruneCandidates(std::vector<Neighbor> &search_pool, uint32_t tgt_id, const Parameters &parameters,
                         std::vector<uint32_t> &pruned_list, boost::dynamic_bitset<> &visited);
    void AddReverse(NeighborPriorityQueue &search_pool, uint32_t src_node, std::vector<uint32_t> &pruned_list,
                    const Parameters &parameters, boost::dynamic_bitset<> &visited);

    // ---- Projection graph construction ----
    void RetrAttenProjectionReserveSpace(const Parameters &parameters);

    void CalculateProjectionep();

    void LinkProjection(const Parameters &parameters);
    void LinkProjectionNoConn(const Parameters &parameters);

    // void LinkProjectionSQ(const Parameters &parameters);
    // void LinkBase(const Parameters &parameters, SimpleNeighbor *simple_graph);

    void TrainingLink2Projection(const Parameters &parameters, SimpleNeighbor *simple_graph);

    void SearchProjectionbyQuery(const float *query, const Parameters &parameters, NeighborPriorityQueue &search_pool,
                                 boost::dynamic_bitset<> &visited, std::vector<Neighbor> &full_retset);

    uint32_t PruneProjectionCandidates(std::vector<Neighbor> &search_pool, const float *query, uint32_t qid,
                                       const Parameters &parameters, std::vector<uint32_t> &pruned_list);

    void PruneProjectionBaseSearchCandidates(std::vector<Neighbor> &search_pool, const float *query, uint32_t qid,
                                             const Parameters &parameters, std::vector<uint32_t> &pruned_list);

    void ProjectionAddReverse(uint32_t src_node, const Parameters &parameters);

    void SupplyAddReverse(uint32_t src_node, const Parameters &parameters);

    void PruneProjectionReverseCandidates(uint32_t src_node, const Parameters &parameters,
                                          std::vector<uint32_t> &pruned_list);

    void PruneProjectionInternalReverseCandidates(uint32_t src_node, const Parameters &parameters,
                                                                  std::vector<uint32_t> &pruned_list);

    std::pair<uint32_t, uint32_t> SearchRAIndex(const float *query, size_t k, size_t &qid, const Parameters &parameters,
                                   unsigned *indices, std::vector<float>& res_dists);

    // ---- Persistence and externally supplied KNN data ----
    void SaveProjectionGraph(const char *filename);

    void LoadProjectionGraph(const char *filename);

    void LoadNsgGraph(const char *filename);

    void LoadLearnBaseKNN(const char *filename);

    void SetLearnBaseKNNi64(const int64_t* learn_base_knn, uint32_t npts, uint32_t k_dim);
    void SetLearnBaseKNN(const uint32_t* learn_base_knn, uint32_t npts, uint32_t k_dim);

    void SaveLayerQIndex(const char *filename);
    void LoadLayerQIndex(const char *filename, const float* data_bp);

    // ---- Search entry points taking plain pointers / scalars instead of Parameters ----
    uint32_t SearchRAIndexPy(const float *query, size_t k, size_t &qid, uint32_t L_pq, uint32_t *indices, float* res_dists);

    uint32_t SearchRAIndexIterativelyPy(const float *query, size_t& end_k, unsigned *indices, float* res_dists);

    uint32_t SearchRAIndexPyReturnFullVisitedSet(const float *query, size_t k, size_t &qid, uint32_t L_pq, uint32_t *all_visited, float *all_visited_dists);

    uint32_t SearchRAIndexPySQ(const float *query, size_t k, size_t &qid, uint32_t L_pq, unsigned *indices, float* res_dists);

    void LoadBaseLearnKNN(const char *filename);

    /// Mutable access to the projection graph adjacency lists.
    inline std::vector<std::vector<uint32_t>> &GetProjectionGraph() { return projection_graph_; }

    uint32_t PruneProjectionRetrAttenCandidates(std::vector<Neighbor> &search_pool, const float *query, uint32_t qid,
                                                const Parameters &parameters, std::vector<uint32_t> &pruned_list);

    void SearchProjectionGraphInternal(NeighborPriorityQueue &search_queue, const float *query, uint32_t tgt,
                                       const Parameters &parameters, boost::dynamic_bitset<> &visited,
                                       std::vector<Neighbor> &full_retset);

    void PruneBiSearchBaseGetBase(std::vector<Neighbor> &search_pool, const float *query, uint32_t qid,
                                  const Parameters &parameters, std::vector<uint32_t> &pruned_list);

    void PruneLocalJoinCandidates(uint32_t node, const Parameters &parameters, uint32_t candi);

    void CollectPoints(const Parameters &parameters);

    void dfs(boost::dynamic_bitset<> &flag, unsigned root, unsigned &cnt);

    void findroot(boost::dynamic_bitset<> &flag, unsigned &root, const Parameters &parameter);

    /// (Re)creates the visited-list pool with `num_threads` as the first constructor argument
    /// and max(u32_nd_, u32_total_pts_) as the second. Any previous pool is destroyed first.
    void InitVisitedListPool(uint32_t num_threads) {
        delete visited_list_pool_;     // deleting nullptr is a no-op
        visited_list_pool_ = nullptr;  // do not keep a dangling pointer if the allocation below throws
        visited_list_pool_ = new VisitedListPool(num_threads, static_cast<int>(std::max(u32_nd_, u32_total_pts_)));
    }

    void getithLearnNN(uint32_t i, uint32_t *learn_nn);

    void qbaseNNbipartite(const Parameters &parameters);

    void SimulateRAIndexInsertOneKey(const Parameters &parameters, uint32_t& wait_to_add_num, const float *key_data, std::vector<uint32_t>& closest_q_ids);

    void SimulateRAIndexBatchInsert(const Parameters &parameters, uint32_t& wait_to_add_num, const float *key_data, std::vector<uint32_t>& closest_q_ids);

    std::pair<uint32_t, uint32_t> SearchRetrAttenGraph(const float *query, size_t k, size_t &qid, const Parameters &parameters,
                                              unsigned *indices, std::vector<float>& dists);

    void SaveBaseData(const char *filename);

    void LoadBaseData(const char *filename);

    /// Returns true when `x` is below total_pts_.
    bool check_valid_range(uint32_t x) const {
        // x is unsigned, so only the upper bound can be violated; the former
        // `x < 0` test could never fire and has been dropped.
        return likely(x < total_pts_);
    }

    /// Mutable access to learn_base_knn_.
    std::vector<std::vector<uint32_t>>& get_learn_base_knn() {
        return learn_base_knn_;
    }

    /// Overrides the stored projection entry point id.
    void set_projection_ep(uint32_t ep) { projection_ep_ = ep; }

    // Public data members. Ownership of the raw pointers is decided by the out-of-line
    // constructor/destructor and is not visible from this declaration.
    Index *initializer_ = nullptr;
    TimeMetric dist_cmp_metric;
    TimeMetric memory_access_metric;
    TimeMetric block_metric;
    VisitedListPool *visited_list_pool_{nullptr};  // replaced by InitVisitedListPool()
    bool need_normalize = false;
    bool has_sq = false;

    IterativeSearchState *iterative_search_state = nullptr;

   protected:
    // Adjacency-list graphs and KNN tables, each indexed by node id.
    std::vector<std::vector<uint32_t>> bipartite_graph_;
    std::vector<std::vector<uint32_t>> final_graph_;
    std::vector<std::vector<uint32_t>> projection_graph_;
    std::vector<std::vector<uint32_t>> supply_nbrs_;
    std::vector<std::vector<uint32_t>> learn_base_knn_;
    std::vector<std::vector<uint32_t>> base_learn_knn_;

   private:
    // Scalars and raw pointers carry default initializers so they never hold indeterminate
    // values; a constructor initializer list still takes precedence over these defaults.
    const size_t total_pts_const_;  // const: must be set by the constructor
    size_t total_pts_ = 0;          // upper bound used by check_valid_range() and InitRetrAttenGraph()
    Distance *l2_distance_ = nullptr;
    uint32_t width_ = 0;
    std::set<uint32_t> sq_en_set_;
    std::set<uint32_t> bp_en_set_;
    std::mutex sq_set_mutex_;  // presumably guards sq_en_set_ -- confirm in the implementation
    std::mutex bp_set_mutex_;  // presumably guards bp_en_set_ -- confirm in the implementation
    std::vector<std::mutex> locks_;
    uint32_t u32_nd_ = 0;
    uint32_t u32_nd_sq_ = 0;
    uint32_t u32_total_pts_ = 0;
    uint32_t projection_ep_ = 0;

    // scalar quantizer store
    faiss::IndexScalarQuantizer *scalar_quant_ = nullptr;
    faiss::FlatCodesDistanceComputer *sqdc_ = nullptr;
};
}  // namespace efanna2e