from loguru import logger
import numpy as np
import torch
import faiss

class Faiss:
    """Token-level FAISS IVF index over a corpus of padded items (e.g. graphs).

    Each valid node/token embedding of every corpus item is indexed as its own
    vector.  ``node_to_cg_arr[r]`` records which corpus item indexed row ``r``
    came from, so per-token search hits can be mapped back to corpus items.

    ``conf.hashing.faissmetric`` selects the similarity:
      * ``"cosine"`` - vectors are L2-normalised and searched by inner product
        (higher score = closer);
      * ``"l2"``     - FAISS L2 distance (lower score = closer).
    """

    def __init__(self, conf):
        super().__init__()
        self.metric = conf.hashing.faissmetric
        self.dname = conf.dataset.name
        # Requested number of IVF clusters; clamped to the corpus size at
        # indexing time because FAISS k-means needs at least nlist points.
        self.nlist = 128

    @staticmethod
    def _as_float32(x):
        """Return ``x`` as a C-contiguous float32 numpy array.

        Torch tensors are detached and copied to host memory first.  No copy
        is made when ``x`` already is a contiguous float32 numpy array.
        """
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.ascontiguousarray(x, dtype=np.float32)

    @staticmethod
    def _valid_node_idxs(mask_col):
        """Return the positions where the torch mask column equals 1 (numpy ints)."""
        return torch.nonzero(mask_col == 1, as_tuple=False).squeeze(1).cpu().numpy()

    @classmethod
    def _gather_corpus(cls, corpus_embeds, corpus_mask):
        """Stack the valid token embeddings of every corpus item.

        Returns:
            ``(tokens, owners)``: ``tokens`` is a new ``(num_tokens, dim)``
            float32 array; ``owners[r]`` is the corpus item index of row ``r``.
        """
        embeds = cls._as_float32(corpus_embeds)
        rows = []
        counts = []
        for i in range(embeds.shape[0]):
            idxs = cls._valid_node_idxs(corpus_mask[i, :, 0])
            rows.append(embeds[i][idxs])
            # Count the rows actually taken, so masks with holes (or items
            # with no valid tokens) keep the owner map aligned.
            counts.append(len(idxs))
        if rows:
            tokens = np.concatenate(rows, axis=0)
        else:
            tokens = np.empty((0, embeds.shape[-1]), dtype=np.float32)
        owners = np.repeat(np.arange(len(counts)), np.asarray(counts, dtype=np.int64))
        return tokens, owners

    @staticmethod
    def _select_hits(res_dists, res_cids, n_keep, higher_is_better):
        """Return a boolean mask of search hits within the pooled top ``n_keep``.

        Padding entries (id ``-1``) are never selected and never influence the
        cutoff.  Scores equal to the cutoff are kept, so ties may yield more
        than ``n_keep`` hits.
        """
        valid = res_cids != -1
        scores = res_dists[valid]
        n_keep = min(n_keep, scores.size)
        if n_keep <= 0:
            return np.zeros_like(valid)
        ranked = np.sort(scores)
        if higher_is_better:
            return valid & (res_dists >= ranked[-n_keep])
        return valid & (res_dists <= ranked[n_keep - 1])

    def index_corpus(self, corpus_embeds, corpus_mask):
        """Build and populate the IVF index from every valid corpus token.

        Args:
            corpus_embeds: per-item token embeddings indexed as
                ``corpus_embeds[item][token]``, feature dimension last
                (numpy array or torch tensor).  Not modified.
            corpus_mask: torch tensor; token ``t`` of item ``i`` is indexed
                iff ``corpus_mask[i, t, 0] == 1``.  Valid tokens need not
                form a contiguous prefix.

        Raises:
            ValueError: if the configured metric is unknown or the mask
                selects no tokens at all.
        """
        if self.metric == "cosine":
            faiss_metric = faiss.METRIC_INNER_PRODUCT
        elif self.metric == "l2":
            faiss_metric = faiss.METRIC_L2
        else:
            raise ValueError(f"Invalid metric {self.metric}. Use 'cosine' or 'l2'.")

        corpus, self.node_to_cg_arr = self._gather_corpus(corpus_embeds, corpus_mask)
        self.num_corpus_tokens = corpus.shape[0]
        if self.num_corpus_tokens == 0:
            raise ValueError("corpus_mask selects no tokens; nothing to index.")
        if self.metric == "cosine":
            # In place on our freshly gathered copy; the caller's array is untouched.
            faiss.normalize_L2(corpus)

        dim = corpus.shape[1]
        # The coarse quantizer must use the same metric as the IVF index,
        # otherwise cluster assignment and search disagree on "nearest".
        if faiss_metric == faiss.METRIC_INNER_PRODUCT:
            quantizer = faiss.IndexFlatIP(dim)
        else:
            quantizer = faiss.IndexFlatL2(dim)
        nlist = min(self.nlist, self.num_corpus_tokens)
        self.index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss_metric)
        self.index.train(corpus)
        self.index.add(corpus)

    def retrieve(self, qid, query_embeds, query_mask, top_K):
        """Retrieve corpus items whose tokens are near the query's tokens.

        Each valid query token is searched for its ``min(3 * top_K, N)``
        nearest corpus tokens (N = number of indexed tokens).  The hits of
        all query tokens are pooled, and only the globally best
        ``num_query_tokens * top_K`` are kept (ties at the cutoff included).
        The IVF ``nprobe`` is left at its FAISS default.

        Args:
            qid: unused; kept for interface compatibility.
            query_embeds: per-token query embeddings (numpy array or torch
                tensor), feature dimension last.  Not modified.
            query_mask: torch tensor; token ``t`` is used iff
                ``query_mask[t, 0] == 1``.
            top_K: per-query-token retrieval budget.

        Returns:
            ``(cid_union, cid_inter)``: the corpus item ids hit by any query
            token, and the ids hit by every query token.  Both are empty when
            the query has no valid tokens, when ``top_K <= 0``, or when nothing
            is found.
        """
        q_node_idxs = self._valid_node_idxs(query_mask[:, 0])
        if q_node_idxs.size == 0 or top_K <= 0:
            return [], []

        # Fancy indexing yields a private copy, so normalising it is safe.
        queries = self._as_float32(query_embeds)[q_node_idxs]
        if self.metric == "cosine":
            # Scores are compared across query tokens, so queries must be
            # normalised too for them to be true cosine similarities.
            faiss.normalize_L2(queries)

        res_dists, res_cids = self.index.search(
            queries, min(3 * top_K, self.num_corpus_tokens))  # 3K limit

        keep = self._select_hits(
            res_dists, res_cids, q_node_idxs.size * top_K,
            higher_is_better=self.metric == "cosine")

        res_union = set()
        res_inter = None
        for row_keep, row_cids in zip(keep, res_cids):
            cids = set(self.node_to_cg_arr[row_cids[row_keep]].tolist())
            res_union |= cids
            res_inter = cids if res_inter is None else res_inter & cids

        return list(res_union), list(res_inter)

            
            