#include <omp.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <algorithm>
#include <boost/dynamic_bitset.hpp>
#include <boost/program_options.hpp>
#include <chrono>
#include <cmath>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <filesystem>
#include <malloc.h>
#include <unistd.h>
#include <chrono>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/AutoTune.h>
#include <faiss/index_factory.h>
#include <faiss/gpu/GpuIndex.h>
#include <faiss/gpu/GpuAutoTune.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/impl/ScalarQuantizer.h>
#include <faiss/IndexScalarQuantizer.h>

#include "index_ra.h"
#include "efanna2e/distance.h"
#include "efanna2e/neighbor.h"
#include "efanna2e/parameters.h"
#include "efanna2e/util.h"

namespace py = pybind11;

template <typename T>



class LayerQHeadRAIndex {
    public:
        int query_head_num;
        int kv_head_num;
        int query_group_num;
        uint32_t dim;
        std::string index_type = "RAIndex";
        std::vector<efanna2e::IndexRetrAtten *> index_array;
        std::vector<faiss::IndexFlatIP*> flat_array;
        std::vector<faiss::gpu::GpuIndexFlatIP *> gpu_flat_array;
        std::vector<faiss::gpu::StandardGpuResources *> gpu_res_array;
        std::vector<float*> key_data_array;
        std::vector<faiss::IndexScalarQuantizer*> sq_array;
        
        uint32_t prefill_num = 0;
        uint32_t index_num = 0;
        uint32_t flat_num = 0;
        uint32_t one_head_base_num = 0;
        bool graph_is_built = false;
        bool use_sq = false;

        LayerQHeadRAIndex(const int query_head_num, const int kv_head_num, const uint32_t dim) {
            this->query_head_num = query_head_num;
            this->kv_head_num = kv_head_num;
            this->dim = dim;
            this->query_group_num = query_head_num / kv_head_num;
            index_array.resize(query_head_num);

            gpu_flat_array.resize(kv_head_num);
            gpu_res_array.resize(kv_head_num);
            key_data_array.resize(kv_head_num);

            for (int i = 0; i < query_head_num; i++) {
                efanna2e::IndexRetrAtten *index = new efanna2e::IndexRetrAtten(dim, 0, efanna2e::Metric::INNER_PRODUCT, nullptr);
                index_array[i] = index;
            }

            #pragma omp parallel for num_threads(kv_head_num)
            for (int i = 0; i < kv_head_num; i++) {
                key_data_array[i] = nullptr;
            }
        }

        ~LayerQHeadRAIndex() {
            for (int i = 0; i < query_head_num; i++) {
                delete index_array[i];
            }
            for (int i = 0; i < kv_head_num; i++) {
                if (use_sq) {
                    delete sq_array[i];
                }
                if (!use_sq) {
                    delete[] key_data_array[i];
                }
            }
            malloc_trim(0);
        }

        void free_gpu_res() {
            for (int i = 0; i < kv_head_num; i++) {
                delete gpu_res_array[i];
                delete gpu_flat_array[i];
            }
        }


        void build(uint32_t sq_num, uint32_t k_dim, uint32_t base_num, uint32_t M_sq, uint32_t M_pjbp, 
                        uint32_t L_pjpq, uint32_t num_threads,
                        py::array_t<float, py::array::c_style | py::array::forcecast> &query_data,
                        py::array_t<float, py::array::c_style | py::array::forcecast> &base_data) {
            
            omp_set_nested(1);
            one_head_base_num = base_num;
            auto query_data_proxy = query_data.unchecked();
            auto base_data_proxy = base_data.unchecked();
            int outer_parallel_num = omp_get_max_threads() / num_threads;

            int64_t* query_nn = new int64_t[query_head_num * sq_num * k_dim];
            float* temp_q_res_dists = new float[query_head_num * sq_num * k_dim];
            efanna2e::Parameters parameters;
            parameters.Set<uint32_t>("M_sq", M_sq);
            parameters.Set<uint32_t>("M_pjbp", M_pjbp);
            parameters.Set<uint32_t>("L_pjpq", L_pjpq);
            parameters.Set<uint32_t>("num_threads", num_threads);
            #pragma omp parallel for schedule(dynamic) num_threads(outer_parallel_num)
            for (int i = 0; i < query_head_num; i++) {
                omp_set_num_threads(num_threads);
                const float* query_data_ptr = &(query_data_proxy(i, 0, 0));
                int64_t* learn_base_nn_ptr = query_nn + i * sq_num * k_dim;
                float* learn_base_nn_dists = temp_q_res_dists + i * sq_num * k_dim;
                int key = i / query_group_num;
                const float* base_data_ptr = &(base_data_proxy(key, 0, 0));
                #pragma omp critical
                {
                    if (key_data_array[key] == nullptr) {
                            key_data_array[key] = new float[base_num * dim];
                            memcpy(key_data_array[key], base_data_ptr, base_num * dim * sizeof(float));
                    }
                }

                if (key_data_array[key] == nullptr) {
                    std::cout << "key_data_array[" << key << "] is nullptr" << std::endl;
                    std::cerr << "key_data_array[" << key << "] is nullptr" << std::endl;
                    throw std::runtime_error("key_data_array is nullptr");
                }

                auto gpu_res = new faiss::gpu::StandardGpuResources();
                auto gpu_index = new faiss::gpu::GpuIndexFlatIP(gpu_res, dim);
                gpu_index->add(base_num, key_data_array[key]);
                gpu_index->search(sq_num, query_data_ptr, k_dim, learn_base_nn_dists, learn_base_nn_ptr);
                
                std::vector<int64_t> nn(sq_num);
                std::vector<float> dists(sq_num);
                int64_t max_dist_id = 0;
                float max_dist = -99999;
                for (int j = 0; j < sq_num; j++) {
                    if (learn_base_nn_dists[j * k_dim] > max_dist) {
                        max_dist = learn_base_nn_dists[j * k_dim];
                        max_dist_id = learn_base_nn_ptr[j * k_dim];
                    }
                }
                index_array[i]->set_projection_ep(uint32_t(max_dist_id));

                index_array[i]->SetLearnBaseKNNi64(learn_base_nn_ptr, sq_num, k_dim);
                index_array[i]->BuildRAIndexwithData(sq_num, nullptr, base_num, key_data_array[key], parameters);
                index_array[i]->InitVisitedListPool(2);
                gpu_res->noTempMemory();
                delete gpu_res;
                delete gpu_index;
            }
            delete[] query_nn;
            delete[] temp_q_res_dists;

            prefill_num = base_num;
            index_num = base_num;
            graph_is_built = true;
        }

        void buildOneHead(uint32_t sq_num, uint32_t k_dim, uint32_t base_num, uint32_t M_sq, uint32_t M_pjbp, 
                        uint32_t L_pjpq, uint32_t num_threads,
                        py::array_t<float, py::array::c_style | py::array::forcecast> &query_data,
                        py::array_t<float, py::array::c_style | py::array::forcecast> &base_data, int head_idx) {
            omp_set_max_active_levels(4);
            auto query_data_proxy = query_data.unchecked();
            auto base_data_proxy = base_data.unchecked();
            int outer_parallel_num = omp_get_max_threads() / num_threads;

            int64_t* query_nn = new int64_t[1 * sq_num * k_dim];
            float* temp_q_res_dists = new float[1 * sq_num * k_dim];
            efanna2e::Parameters parameters;
            parameters.Set<uint32_t>("M_sq", M_sq);
            parameters.Set<uint32_t>("M_pjbp", M_pjbp);
            parameters.Set<uint32_t>("L_pjpq", L_pjpq);
            parameters.Set<uint32_t>("num_threads", num_threads);

                int i = 0;
                omp_set_num_threads(num_threads);
                const float* query_data_ptr = &(query_data_proxy(i, 0, 0));
                int64_t* learn_base_nn_ptr = query_nn + i * sq_num * k_dim;
                float* learn_base_nn_dists = temp_q_res_dists + i * sq_num * k_dim;
                int key = head_idx / query_group_num;
                const float* base_data_ptr = &(base_data_proxy(0, 0, 0));
                    if (key_data_array[key] == nullptr) {
                            key_data_array[key] = new float[base_num * dim];
                            memcpy(key_data_array[key], base_data_ptr, base_num * dim * sizeof(float));
                    }

                if (key_data_array[key] == nullptr) {
                    std::cout << "key_data_array[" << key << "] is nullptr" << std::endl;
                    std::cerr << "key_data_array[" << key << "] is nullptr" << std::endl;
                    throw std::runtime_error("key_data_array is nullptr");
                }

                auto gpu_res = new faiss::gpu::StandardGpuResources();
                auto gpu_index = new faiss::gpu::GpuIndexFlatIP(gpu_res, dim);
                gpu_index->add(base_num, key_data_array[key]);
                gpu_index->search(sq_num, query_data_ptr, k_dim, learn_base_nn_dists, learn_base_nn_ptr);
                int64_t max_dist_id = 0;
                float max_dist = -99999;
                for (int j = 0; j < sq_num; j++) {
                    if (learn_base_nn_dists[j * k_dim] > max_dist) {
                        max_dist = learn_base_nn_dists[j * k_dim];
                        max_dist_id = learn_base_nn_ptr[j * k_dim];
                    }
                }
                index_array[head_idx]->set_projection_ep(uint32_t(max_dist_id));

                index_array[head_idx]->SetLearnBaseKNNi64(learn_base_nn_ptr, sq_num, k_dim);
                index_array[head_idx]->BuildRAIndexwithData(sq_num, nullptr, base_num, key_data_array[key], parameters);
                index_array[head_idx]->InitVisitedListPool(2);
                
                gpu_res->noTempMemory();
                delete gpu_res;
                delete gpu_index;
            delete[] query_nn;
            delete[] temp_q_res_dists;
            malloc_trim(0);
            prefill_num = base_num;
            index_num = base_num;
            graph_is_built = true;
        }
        
        void buildSQ(uint32_t sq_num, uint32_t k_dim, uint32_t base_num, uint32_t M_sq, uint32_t M_pjbp, 
                        uint32_t L_pjpq, uint32_t num_threads,
                        py::array_t<float, py::array::c_style | py::array::forcecast> &query_data,
                        py::array_t<float, py::array::c_style | py::array::forcecast> &base_data) {
            use_sq = true;
            one_head_base_num = base_num;
            
            omp_set_nested(1);
            auto query_data_proxy = query_data.unchecked();
            auto base_data_proxy = base_data.unchecked();
            
            int outer_parallel_num = omp_get_max_threads() / num_threads;
            int64_t* query_nn = new int64_t[query_head_num * sq_num * k_dim];
            float* temp_q_res_dists = new float[query_head_num * sq_num * k_dim];
            efanna2e::Parameters parameters;
            parameters.Set<uint32_t>("M_sq", M_sq);
            parameters.Set<uint32_t>("M_pjbp", M_pjbp);
            parameters.Set<uint32_t>("L_pjpq", L_pjpq);
            parameters.Set<uint32_t>("num_threads", num_threads);
            
            #pragma omp parallel for schedule(dynamic) num_threads(outer_parallel_num)
            for (int i = 0; i < query_head_num; i++) {
                omp_set_num_threads(num_threads);
                const float* query_data_ptr = &(query_data_proxy(i, 0, 0));
                int64_t* learn_base_nn_ptr = query_nn + i * sq_num * k_dim;
                float* learn_base_nn_dists = temp_q_res_dists + i * sq_num * k_dim;
                int key = i / query_group_num;
                
                const float* base_data_ptr = &(base_data_proxy(key, 0, 0));
                
                #pragma omp critical
                {
                    if (key_data_array[key] == nullptr) {
                            
                            key_data_array[key] = new float[base_num * dim];
                            memcpy(key_data_array[key], base_data_ptr, base_num * dim * sizeof(float));
                    }
                }
                auto gpu_res = faiss::gpu::StandardGpuResources();
                auto gpu_index = faiss::gpu::GpuIndexFlatIP(&gpu_res, dim);
                gpu_index.add(base_num, key_data_array[key]);
                
                gpu_index.search(sq_num, query_data_ptr, k_dim, learn_base_nn_dists, learn_base_nn_ptr);
                index_array[i]->SetLearnBaseKNNi64(learn_base_nn_ptr, sq_num, k_dim);
                std::vector<int64_t> nn(sq_num);
                std::vector<float> dists(sq_num);
                int64_t max_dist_id = 0;
                float max_dist = -99999;
                for (int j = 0; j < sq_num; j++) {
                    if (learn_base_nn_dists[j * k_dim] > max_dist) {
                        max_dist = learn_base_nn_dists[j * k_dim];
                        max_dist_id = learn_base_nn_ptr[j * k_dim];
                    }
                }
                index_array[i]->set_projection_ep(uint32_t(max_dist_id));
                index_array[i]->BuildRAIndexwithData(sq_num, nullptr, base_num, key_data_array[key], parameters);
                index_array[i]->InitVisitedListPool(2);
                
                gpu_res.noTempMemory();
            }
            delete[] query_nn;
            delete[] temp_q_res_dists;

            for (int i = 0; i < kv_head_num; ++i) {
                delete[] key_data_array[i];
                key_data_array[i] = nullptr;
            }
            #pragma omp parallel for schedule(dynamic) num_threads(outer_parallel_num)
            for (int i = 0; i < query_head_num; i++) {
                omp_set_num_threads(num_threads);
                const float* query_data_ptr = &(query_data_proxy(i, 0, 0));
                int64_t* learn_base_nn_ptr = query_nn + i * sq_num * k_dim;
                float* learn_base_nn_dists = temp_q_res_dists + i * sq_num * k_dim;
                int key = i / query_group_num;
                const float* base_data_ptr = &(base_data_proxy(key, 0, 0));
                
                #pragma omp critical
                {
                    if (sq_array[key] == nullptr) {
                            sq_array[key] = new faiss::IndexScalarQuantizer(dim, faiss::ScalarQuantizer::QuantizerType::QT_8bit, faiss::MetricType::METRIC_INNER_PRODUCT);
                            sq_array[key]->train(base_num, base_data_ptr);
                            sq_array[key]->add(base_num, base_data_ptr);
                    }
                }

                index_array[i]->SetSQforRAIndex(sq_num, nullptr, base_num, sq_array[key], parameters);
            }

            malloc_trim(0);
            prefill_num = base_num;
            index_num = base_num;
            graph_is_built = true;
        }

        void buildOneHeadSQ(uint32_t sq_num, uint32_t k_dim, uint32_t base_num, uint32_t M_sq, uint32_t M_pjbp, 
                        uint32_t L_pjpq, uint32_t num_threads,
                        py::array_t<float, py::array::c_style | py::array::forcecast> &query_data,
                        py::array_t<float, py::array::c_style | py::array::forcecast> &base_data, int head_idx) {
            use_sq = true;
            omp_set_nested(1);
            auto query_data_proxy = query_data.unchecked();
            auto base_data_proxy = base_data.unchecked();
            
            int outer_parallel_num = omp_get_max_threads() / num_threads;
            int64_t* query_nn = new int64_t[sq_num * k_dim];
            float* temp_q_res_dists = new float[sq_num * k_dim];
            efanna2e::Parameters parameters;
            parameters.Set<uint32_t>("M_sq", M_sq);
            parameters.Set<uint32_t>("M_pjbp", M_pjbp);
            parameters.Set<uint32_t>("L_pjpq", L_pjpq);
            parameters.Set<uint32_t>("num_threads", num_threads);
                int i = 0;
                omp_set_num_threads(num_threads);
                const float* query_data_ptr = &(query_data_proxy(i, 0, 0));
                int64_t* learn_base_nn_ptr = query_nn + i * sq_num * k_dim;
                float* learn_base_nn_dists = temp_q_res_dists + i * sq_num * k_dim;
                int key = head_idx / query_group_num;
                const float* base_data_ptr = &(base_data_proxy(0, 0, 0));
                    if (key_data_array[key] == nullptr) {
                            
                            key_data_array[key] = new float[base_num * dim];
                    }
                auto gpu_res = new faiss::gpu::StandardGpuResources();
                auto gpu_index = new faiss::gpu::GpuIndexFlatIP(gpu_res, dim);
                gpu_index->add(base_num, key_data_array[key]);
                
                gpu_index->search(sq_num, query_data_ptr, k_dim, learn_base_nn_dists, learn_base_nn_ptr);
                index_array[head_idx]->SetLearnBaseKNNi64(learn_base_nn_ptr, sq_num, k_dim);
                index_array[head_idx]->BuildRAIndexwithData(sq_num, nullptr, base_num, key_data_array[key], parameters);
                index_array[head_idx]->InitVisitedListPool(2);
                
                gpu_res->noTempMemory();
                delete gpu_res;
                delete gpu_index;
            delete[] query_nn;
            delete[] temp_q_res_dists;

            malloc_trim(0);
                i = 0;
                omp_set_num_threads(num_threads);
                    if (sq_array[key] == nullptr) {
                            sq_array[key] = new faiss::IndexScalarQuantizer(dim, faiss::ScalarQuantizer::QuantizerType::QT_8bit, faiss::MetricType::METRIC_INNER_PRODUCT);
                            sq_array[key]->train(base_num, base_data_ptr);
                            sq_array[key]->add(base_num, base_data_ptr);
                    }
                index_array[head_idx]->SetSQforRAIndex(sq_num, nullptr, base_num, sq_array[key], parameters);
            if (head_idx == query_head_num - 1) {
                for (int j = 0; j < kv_head_num; ++j) {
                    if (key_data_array[j] != nullptr) {
                        delete[] key_data_array[j];
                        key_data_array[j] = nullptr;
                    }
                }
            }

            prefill_num = base_num;
            index_num = base_num;
            graph_is_built = true;
        }


        void saveAllIndex(const std::string &dir) {
            for (int i = 0; i < query_head_num; i++) {
                std::string file = dir + "/ra_index_" + std::to_string(i);
                index_array[i]->SaveLayerQIndex(file.c_str());
            }
        }

        void loadAllIndex(const std::string &dir) {
            for (int i = 0; i < query_head_num; i++) {
                int key = i / query_group_num;
                    if (key_data_array[key] == nullptr) {
                        std::string data_file_name = dir + "/ra_index_" + std::to_string(i) + ".data";
                        std::ifstream data_in(data_file_name, std::ios::binary);
                        if (!data_in.is_open()) {
                            throw std::runtime_error("cannot open file");
                        }
                        
                        data_in.seekg(0, data_in.end);
                        size_t file_size = data_in.tellg();
                        data_in.seekg(0, data_in.beg);
                        size_t num_float = file_size / sizeof(float);
                        key_data_array[key] = new float[num_float];
                        data_in.read((char *)key_data_array[key], sizeof(float) * num_float);
                        data_in.close();
                    }
            }

            for (int i = 0; i < query_head_num; i++) {
                std::string file = dir + "/ra_index_" + std::to_string(i);
                
                int key = i / query_group_num;
                index_array[i]->LoadLayerQIndex(file.c_str(), key_data_array[key]);
            }
        }
        

        void searchRAIndex(py::array_t<float, py::array::c_style | py::array::forcecast> &q_data, size_t k, uint32_t L_pq,
                py::array_t<uint32_t, py::array::c_style | py::array::forcecast> &res_id, py::array_t<float, py::array::c_style | py::array::forcecast> &res_dist,
                size_t q_num, uint32_t num_threads) {
            auto q_data_proxy = q_data.unchecked();
            
            auto res_id_proxy = res_id.mutable_unchecked();
            auto res_dist_proxy = res_dist.mutable_unchecked();
            omp_set_nested(1);
        #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
            for (size_t i = 0; i < query_head_num; i++) {
                    index_array[i]->SearchRAIndexPy(&q_data_proxy(i, 0, 0), k, i, L_pq, &res_id_proxy(i, 0, 0), &res_dist_proxy(i, 0, 0));
            }
        }


        void searchRAIndexGetCmps(py::array_t<float, py::array::c_style | py::array::forcecast> &q_data, size_t k, uint32_t L_pq,
                py::array_t<uint32_t, py::array::c_style | py::array::forcecast> &res_id, py::array_t<float, py::array::c_style> &res_dist,
                size_t q_num, uint32_t num_threads, py::array_t<uint32_t, py::array::c_style | py::array::forcecast> &cmps_array) {
            auto q_data_proxy = q_data.unchecked<3>();
            auto res_id_proxy = res_id.mutable_unchecked<3>();
            auto res_dist_proxy = res_dist.mutable_unchecked();
            auto cmps_array_proxy = cmps_array.mutable_unchecked();
            uint32_t* cmps_array_ptr = &cmps_array_proxy(0);
        #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
            for (size_t i = 0; i < (size_t)query_head_num; i++) {
                    if (unlikely(q_num > 1)) {
                        for (size_t j = 0; j < q_num; j++) {
                            auto cmps = index_array[i]->SearchRAIndexPy(&q_data_proxy(i, j, 0), k, j, L_pq, &res_id_proxy(i, j, 0), &res_dist_proxy(i, j, 0));
                            #pragma omp atomic
                            cmps_array_ptr[i] += cmps;
                        }
                    } else {
                        cmps_array_ptr[i] = index_array[i]->SearchRAIndexPy(&q_data_proxy(i, 0, 0), k, i, L_pq, &res_id_proxy(i, 0, 0), &res_dist_proxy(i, 0, 0));
                    }
            }
        }
        


        void searchRAIndexGetCmpsAllVisited(py::array_t<float, py::array::c_style | py::array::forcecast> &q_data, size_t k, uint32_t L_pq,
                py::array_t<uint32_t, py::array::c_style | py::array::forcecast> &res_id, py::array_t<float, py::array::c_style> &res_dist,
                size_t q_num, uint32_t num_threads, py::array_t<uint32_t, py::array::c_style | py::array::forcecast> &cmps_array) {
            auto q_data_proxy = q_data.unchecked<3>();
            auto res_id_proxy = res_id.mutable_unchecked<3>();
            auto res_dist_proxy = res_dist.mutable_unchecked();
            auto cmps_array_proxy = cmps_array.mutable_unchecked();
            uint32_t* cmps_array_ptr = &cmps_array_proxy(0);
        #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
            for (size_t i = 0; i < (size_t)query_head_num; i++) {
                    if (unlikely(q_num > 1)) {
                        for (size_t j = 0; j < q_num; j++) {
                            auto cmps = index_array[i]->SearchRAIndexPyReturnFullVisitedSet(&q_data_proxy(i, j, 0), k, j, L_pq, &res_id_proxy(i, j, 0), &res_dist_proxy(i, j, 0));
                            #pragma omp atomic
                            cmps_array_ptr[i] += cmps;
                        }
                    } else {
                        cmps_array_ptr[i] = index_array[i]->SearchRAIndexPyReturnFullVisitedSet(&q_data_proxy(i, 0, 0), k, i, L_pq, &res_id_proxy(i, 0, 0), &res_dist_proxy(i, 0, 0));
                    }
            }
        }

        void searchRAIndexGetCmpsSQ(py::array_t<float, py::array::c_style | py::array::forcecast> &q_data, size_t k, uint32_t L_pq,
                py::array_t<uint32_t, py::array::c_style | py::array::forcecast> &res_id, py::array_t<float, py::array::c_style> &res_dist,
                size_t q_num, uint32_t num_threads, py::array_t<uint32_t, py::array::c_style | py::array::forcecast> &cmps_array) {
            if (sq_array[0] == nullptr) {
                throw std::runtime_error("sq_array is nullptr");
            }
            if (index_array[0]->has_sq == false) {
                throw std::runtime_error("index_array[i]->has_sq is false");
            }
            auto q_data_proxy = q_data.unchecked<3>();
            auto res_id_proxy = res_id.mutable_unchecked<3>();
            auto res_dist_proxy = res_dist.mutable_unchecked();
            auto cmps_array_proxy = cmps_array.mutable_unchecked();
            uint32_t* cmps_array_ptr = &cmps_array_proxy(0);
        #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
            for (size_t i = 0; i < (size_t)query_head_num; i++) {
                
                    if (unlikely(q_num > 1)) {
                        for (size_t j = 0; j < q_num; j++) {
                            auto cmps = index_array[i]->SearchRAIndexPySQ(&q_data_proxy(i, j, 0), k, j, L_pq, &res_id_proxy(i, j, 0), &res_dist_proxy(i, j, 0));
                            #pragma omp atomic
                            cmps_array_ptr[i] += cmps;
                        }
                    } else {
                        cmps_array_ptr[i] = index_array[i]->SearchRAIndexPySQ(&q_data_proxy(i, 0, 0), k, i, L_pq, &res_id_proxy(i, 0, 0), &res_dist_proxy(i, 0, 0));
                    }
            }
        }


        py::tuple searchFlat(py::array_t<float, py::array::c_style | py::array::forcecast> &input, size_t k, size_t group, uint32_t num_threads) {
            auto input_buf = input.request();
            float* input_ptr = static_cast<float*>(input_buf.ptr);
            k = flat_num < k ? flat_num : k;
            int head = input.shape(0);
            int length = input.shape(1);
            int features = input.shape(2);
            
            int total_results = head * length * k;
            faiss::idx_t* merged_I = new faiss::idx_t[total_results];
            float* merged_D = new float[total_results];
            
        #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
            for (int row = 0; row < head; ++row) {
                int true_position = row / group;
                faiss::IndexFlat* index = flat_array[true_position];

                const float* input_row_ptr = input_ptr + row * length * features;
                faiss::idx_t* output_I_ptr = merged_I + row * length * k;
                float* output_D_ptr = merged_D + row * length * k;
                
                index->search(length, input_row_ptr, k, output_D_ptr, output_I_ptr);
            }

            py::capsule free_when_done_l(merged_D, [](void* f) {
                delete[] f;
            });
            
            py::capsule free_when_done_d(merged_I, [](void* f) {
                delete[] f;
            });
            
            return py::make_tuple(py::array(total_results, merged_D, free_when_done_l), py::array(total_results, merged_I, free_when_done_d));
        };


        void setThreads(uint32_t num_threads) {
            for (int i = 0; i < query_head_num; i++) {
                index_array[i]->InitVisitedListPool(num_threads);
            }
        }

        uint32_t get_flat_num() {
            return flat_num;
        }

        uint32_t get_index_num() {
            return index_num;
        }

        bool is_graph_built() {
            return graph_is_built;
        }

        std::string& get_index_type() {
            return index_type;
        }

        bool is_use_sq() {
            return use_sq;
        }

        float* get_key_data(uint32_t key) {
            return key_data_array[key];
        }

        py::array_t<float, py::array::c_style | py::array::forcecast> get_key_data_array(uint32_t key) {
            float* key_data_ptr = new float[one_head_base_num * dim];
            memcpy(key_data_ptr, key_data_array[key], one_head_base_num * dim * sizeof(float));
            py::capsule free_when_done(key_data_ptr, [](void* f) {
                delete[] f;
            });

            return py::array(one_head_base_num * dim, key_data_ptr, free_when_done);
        }

};



PYBIND11_MODULE(RAIndex, m) {
    // Module-level docstring.
    m.doc() = "pybind11 RAIndex plugin";  

    // Distance metric enum; export_values() also exposes the members at module scope.
    py::enum_<efanna2e::Metric> metric(m, "Metric");
    metric.value("L2", efanna2e::Metric::L2);
    metric.value("IP", efanna2e::Metric::INNER_PRODUCT);
    metric.value("COSINE", efanna2e::Metric::COSINE);
    metric.export_values();

    using Index = LayerQHeadRAIndex;
    py::class_<Index> layer(m, "LayerQHeadRAIndex");

    // Constructor: (query_head_num, kv_head_num, dim).
    layer.def(py::init<const int, const int, const uint32_t>());

    // Index construction.
    layer.def("build", &Index::build);
    layer.def("buildSQ", &Index::buildSQ);
    layer.def("buildOneHead", &Index::buildOneHead);
    layer.def("buildOneHeadSQ", &Index::buildOneHeadSQ);

    // Persistence.
    layer.def("saveAllIndex", &Index::saveAllIndex);
    layer.def("loadAllIndex", &Index::loadAllIndex);

    // Search entry points.
    layer.def("searchRAIndex", &Index::searchRAIndex);
    layer.def("searchRAIndexGetCmps", &Index::searchRAIndexGetCmps);
    layer.def("searchFlat", &Index::searchFlat);
    layer.def("searchRAIndexGetCmpsSQ", &Index::searchRAIndexGetCmpsSQ);
    layer.def("searchRAIndexGetCmpsAllVisited", &Index::searchRAIndexGetCmpsAllVisited);

    // Configuration and introspection. The Python name "graph_is_built" maps to is_graph_built().
    layer.def("setThreads", &Index::setThreads);
    layer.def("get_flat_num", &Index::get_flat_num);
    layer.def("get_index_num", &Index::get_index_num);
    layer.def("graph_is_built", &Index::is_graph_built);
    layer.def("get_index_type", &Index::get_index_type);
    layer.def("is_use_sq", &Index::is_use_sq);
    layer.def("get_key_data", &Index::get_key_data);
    layer.def("get_key_data_array", &Index::get_key_data_array);
}