from loguru import logger
import numpy as np
import torch
import diskannpy as dap
import os

class DiskANN(object):
    """Token-level DiskANN index over a corpus of graphs.

    Every valid node (token) embedding of every corpus graph is stored as one
    point of an in-memory DiskANN index, and ``node_to_cg_arr`` maps each point
    ID back to the index of the corpus graph it came from.  ``retrieve`` runs a
    nearest-neighbour search for every valid query node and aggregates the hits
    per corpus graph (union and intersection over query nodes).
    """

    def __init__(self, conf):
        super().__init__()
        # Distance metric passed to diskannpy; also part of the index file prefix.
        self.metric = conf.hashing.diskannmetric
        # Dataset name; used only to build the index file prefix.
        self.dname = conf.dataset.name

    def _index_prefix(self):
        """Return the prefix shared by the cache check, the build and the load."""
        return f"diskann_idx_data/{self.dname}_diskann_{self.metric}_multi_v3"

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a NumPy array, copying torch tensors to host memory."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    @classmethod
    def _collect_corpus_tokens(cls, corpus_embeds, corpus_mask):
        """Stack the valid token embeddings of all corpus graphs.

        Args:
            corpus_embeds: per-graph token embeddings, indexed as
                ``corpus_embeds[i][token_idxs]`` (assumed shape
                (num_graphs, max_nodes, dim) — TODO confirm). NumPy array or
                torch tensor.
            corpus_mask: torch tensor indexed as ``corpus_mask[i, :, 0]``; a
                token is valid where that entry equals 1.

        Returns:
            ``(corpus, node_to_cg_arr)`` where ``corpus`` stacks the valid rows
            of every graph in order, and ``node_to_cg_arr[p]`` is the graph
            index that row ``p`` came from.

        Raises:
            ValueError: if ``corpus_embeds`` contains no graphs at all.
        """
        chunks = []
        counts = []
        for i in range(corpus_embeds.shape[0]):
            idxs = torch.nonzero(
                corpus_mask[i, :, 0] == 1,
                as_tuple=False).squeeze(1).cpu().numpy()
            chunks.append(cls._to_numpy(corpus_embeds[i])[idxs])
            # Count the rows actually appended, so the map stays aligned even
            # for empty or non-contiguous masks.
            counts.append(len(idxs))

        if not chunks:
            raise ValueError("corpus_embeds contains no corpus graphs to index")

        corpus = np.concatenate(chunks, axis=0)
        node_to_cg_arr = np.repeat(np.arange(len(counts)), counts)
        return corpus, node_to_cg_arr

    def index_corpus(self, corpus_embeds, corpus_mask):
        """Build (or load from disk) the DiskANN index over all valid corpus tokens.

        Sets ``num_corpus_tokens``, ``node_to_cg_arr`` and ``node_index``.

        NOTE(review): an existing ``*_vectors.bin`` for this dataset/metric is
        treated as a valid cached index, even if the corpus passed in differs
        from the one it was built from.
        """
        corpus, self.node_to_cg_arr = self._collect_corpus_tokens(
            corpus_embeds, corpus_mask)
        self.num_corpus_tokens = corpus.shape[0]

        prefix = self._index_prefix()
        if not os.path.exists(f"{prefix}_vectors.bin"):
            # The index files are written under a sub-directory of './', which
            # must exist before diskannpy writes into it.
            os.makedirs(os.path.dirname(prefix), exist_ok=True)
            dap.build_memory_index(
                data=corpus,
                distance_metric=self.metric,
                index_directory='./',
                graph_degree=16,
                complexity=32,
                alpha=1.2,
                num_threads=0,
                use_pq_build=False,
                num_pq_bytes=8,
                use_opq=False,
                filter_complexity=32,
                index_prefix=prefix,
            )

        self.node_index = dap.StaticMemoryIndex(
            index_directory='./',
            num_threads=16,
            initial_search_complexity=2 ** 21,
            index_prefix=prefix,
        )

    def retrieve(self, qid, query_embeds, query_mask, top_K):
        """Retrieve candidate corpus graphs for one query graph.

        Each valid query node searches ``min(3 * top_K, num_corpus_tokens)``
        neighbours. Only neighbours whose distance is at most the global
        cutoff are kept. The cutoff is the ``(num_query_nodes * top_K)``-th
        smallest distance (0-based), clamped to the largest distance
        returned.

        Args:
            qid: unused; kept for interface compatibility.
            query_embeds: per-node query embeddings (NumPy array or torch
                tensor), indexed by node.
            query_mask: torch tensor; node ``n`` is valid where
                ``query_mask[n, 0] == 1``.
            top_K: per-node retrieval budget.

        Returns:
            ``(cid_union, cid_inter)``: lists of corpus graph indices hit by
            any / by every valid query node. Both are empty when the query has
            no valid nodes.
        """
        q_node_idxs = torch.nonzero(
            query_mask[:, 0] == 1,
            as_tuple=False).squeeze(1).cpu().numpy()
        if q_node_idxs.size == 0:
            return ([], [])

        k_neighbors = min(3 * top_K, self.num_corpus_tokens)  # 3K limit
        # num valid nodes x k_neighbors. diskannpy requires complexity >= k.
        res_cids, res_dists = self.node_index.batch_search(
            self._to_numpy(query_embeds)[q_node_idxs],
            k_neighbors=k_neighbors,
            complexity=max(5, k_neighbors),
            num_threads=16,
        )

        # Global distance cutoff. Clamp the rank so it stays in bounds when
        # fewer than num_nodes * top_K + 1 distances were returned, e.g. when
        # num_corpus_tokens <= top_K.
        res_dists_flat = res_dists.flatten()
        cutoff_pt = min(res_cids.shape[0] * top_K, res_dists_flat.size - 1)
        cutoff_dist = np.partition(res_dists_flat, cutoff_pt)[cutoff_pt]

        # Keep in-range IDs within the cutoff distance.
        # NOTE(review): ID 0 is dropped as padding, which also discards the real
        # first corpus token whenever it is a genuine neighbour.
        keep = ((res_cids != 0)
                & (res_cids < self.num_corpus_tokens)
                & (res_dists <= cutoff_dist))

        per_node_cgs = [
            set(self.node_to_cg_arr[row_cids[row_keep]])
            for row_cids, row_keep in zip(res_cids, keep)
        ]
        res_union = set().union(*per_node_cgs)
        res_inter = set.intersection(*per_node_cgs)

        return (list(res_union), list(res_inter))

            
            