# from lsh.lsh_base import BaseLSH
from loguru import logger
import time 
import psutil
from lsh.lsh_utils import fetch_gaussian_hyperplanes, fetch_n_omega_samples, get_smoothing_factor
import numpy as np
from collections import defaultdict
from utils.model_utils import nanl_fast_inference_without_model
import torch
import heapq

def user_time():
    """Return the user-mode CPU time that psutil reports for the current process."""
    current_process = psutil.Process()
    return current_process.cpu_times().user


class Ghash2_Raw(object):
    """LSH index over set embeddings with one hash table per embedding dimension.

    Every item is a set of ``max_set_size`` vectors of size ``embed_dim``. For each
    embedding dimension, the set's values along that dimension are sorted and mapped
    to a feature vector built from ``cos``/``sin`` of ``value * ws`` (see
    ``generate_fmap``); queries and corpus items use different (asymmetric) maps.
    The features are projected onto the hyperplanes in ``gauss_hplanes_cos`` and
    sign-binarised, and ``subset_size`` of those bits (chosen per table by
    ``init_hash_functions``) give the bucket id in that dimension's hash table.
    """

    def __init__(self, conf):
        """Read hashing hyper-parameters from ``conf`` and build all static state.

        ``conf`` must expose (at least) ``dataset.rel_mode``,
        ``dataset.actual_max_node_set_size``, ``hashing.subset_size``,
        ``hashing.subset_type``, ``hashing.hcode_dim``, ``hashing.device``,
        ``hashing.m_use``, ``hashing.T``, ``hashing.Sm``, ``gmn.filters_3`` and
        ``DEBUG``; it is also forwarded to ``fetch_n_omega_samples`` and
        ``get_smoothing_factor``.
        """
        super().__init__()
        self.num_hash_tables = None
        # No. of buckets in a hash table is at most 2^subset_size
        self.rel_mode = conf.dataset.rel_mode
        self.subset_size = conf.hashing.subset_size
        self.subset_type = conf.hashing.subset_type  # TODO: check if we need to add to hashing_config_name_removal_set
        self.hcode_dim = conf.hashing.hcode_dim
        self.device = conf.hashing.device
        self.DEBUG = conf.DEBUG  # Additional print statements
        self.conf = conf
        self.embed_dim = conf.gmn.filters_3
        self.max_set_size = conf.dataset.actual_max_node_set_size
        self.m_use = conf.hashing.m_use
        self.T = conf.hashing.T
        self.some_preprocessing_for_speedup()

        # generate_fmap yields 4 * m_use * max_set_size features per (item, dimension);
        # the GED modes concatenate two such maps (embed and -embed).
        fmap_dim = self.m_use * self.max_set_size * 4
        if self.rel_mode in ["ged", "uneq_ged"]:
            fmap_dim = 2 * fmap_dim
        self.gauss_hplanes_cos = fetch_gaussian_hyperplanes(self.hcode_dim, fmap_dim)

        # Bit weights (MSB first) for reading the subset_size selected bits as a bucket id.
        self.powers_of_two = 1 << np.arange(self.subset_size - 1, -1, -1)
        # NOTE: We have one hash table per dimension
        self.num_hash_tables = self.embed_dim

        # generates self.hash_functions, containing bit indices for each hash table
        self.init_hash_functions()

        ### Timing methods and funcs
        self.timing_methods = ["real", "user", "process_time"]
        self.timing_funcs = {"real": time.time, "user": user_time, "process_time": time.process_time}

        # Square roots of the edit costs that weight the two halves of the GED feature
        # map, and the constant the concatenated map is divided by.
        if self.rel_mode == "uneq_ged":
            self.sqrt_add_cost = np.sqrt(1).astype(np.float32)
            self.sqrt_del_cost = np.sqrt(2).astype(np.float32)
            self.normalizing_const = np.sqrt(3).astype(np.float32)
        elif self.rel_mode == "ged":
            self.sqrt_add_cost = np.sqrt(1).astype(np.float32)
            self.sqrt_del_cost = np.sqrt(1).astype(np.float32)
            self.normalizing_const = np.sqrt(2).astype(np.float32)
        else:  # the "sub_iso" code path does not use any of these
            self.add_cost = None
            self.del_cost = None
            self.sqrt_add_cost = None
            self.sqrt_del_cost = None
            self.normalizing_const = None

    def init_hash_functions(self):
        """Choose which hashcode bits each hash table reads for bucketing.

        Sets ``self.hash_functions`` to an int64 array of shape
        (num_hash_tables, subset_size); row ``t`` lists the bit indices (into the
        ``hcode_dim``-bit hashcode) used by table ``t``. With ``subset_type == "same"``
        all tables share one random subset; with ``"different"`` each table draws its
        own. When ``subset_size == hcode_dim`` every table uses all bits, in order.

        Raises:
            ValueError: if sampling is needed and ``subset_type`` is neither
                "same" nor "different".
        """
        if self.subset_size == self.hcode_dim:
            # No sampling: all bits are used, in order, by every table
            self.hash_functions = np.tile(np.arange(self.hcode_dim, dtype=np.int64), (self.num_hash_tables, 1))
            return

        if self.subset_type == "same":
            sampled_subset = np.random.choice(self.hcode_dim, self.subset_size, replace=False).astype(np.int64)
            hash_functions = [sampled_subset] * self.num_hash_tables
        elif self.subset_type == "different":
            hash_functions = [
                np.random.choice(self.hcode_dim, self.subset_size, replace=False).astype(np.int64)
                for _ in range(self.num_hash_tables)
            ]
        else:
            raise ValueError(f"Unknown subset_type {self.subset_type!r}; expected 'same' or 'different'")

        self.hash_functions = np.stack(hash_functions)

    @staticmethod
    def _complex_multiply(a_re, a_im, b_re, b_im):
        """Return (real, imag) of (a_re + i*a_im) * (b_re + i*b_im), elementwise."""
        return a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re

    def some_preprocessing_for_speedup(self):
        """Sample frequencies and precompute the per-frequency weights used by the feature maps.

        Sets ``ws``/``pdfs`` (from ``fetch_n_omega_samples``, max_set_size * m_use
        samples), ``sqrt_pdfs``, the real/imaginary weights ``R_G``/``I_G`` computed
        from ``ws`` and ``T`` (multiplied by the complex smoothing factor when
        ``conf.hashing.Sm != 'none'``), and their signs and square-rooted magnitudes.
        """
        self.ws, self.pdfs = fetch_n_omega_samples(self.conf, self.max_set_size * self.conf.hashing.m_use)
        assert len(self.ws) == len(self.pdfs), f"Size mismatch{len(self.pdfs) }, {len(self.ws)}"

        self.sqrt_pdfs = np.sqrt(self.pdfs)
        self.R_G = 2 * (np.sin(self.ws * self.T/2))**2 / self.ws**2 + \
                self.T * np.sin(self.ws * self.T) / self.ws
        self.I_G = np.sin(self.ws * self.T) / self.ws**2 - \
                self.T * np.cos(self.ws * self.T) / self.ws

        if self.conf.hashing.Sm != 'none':
            Re_s, Im_s = get_smoothing_factor(self.conf, self.ws, self.T)
            # (R_G + i I_G) * (Re_s + i Im_s): both parts must be computed from the
            # *unsmoothed* R_G/I_G, so they are updated together.
            self.R_G, self.I_G = self._complex_multiply(self.R_G, self.I_G, Re_s, Im_s)

        self.sign_RG = np.sign(self.R_G)
        self.sign_IG = np.sign(self.I_G)
        self.sqrt_abs_RG = np.sqrt(np.abs(self.R_G))
        self.sqrt_abs_IG = np.sqrt(np.abs(self.I_G))

        # Not used inside this class; kept for external callers.
        self.temp1 = (self.sign_RG * self.sqrt_abs_RG) / self.sqrt_pdfs
        self.temp2 = (self.sign_IG * self.sqrt_abs_IG) / self.sqrt_pdfs
        self.temp3 = - self.temp2
        self.concat_temp = np.hstack([self.temp1, self.temp1, self.temp3, self.temp2])

    def generate_fmap(self, m_use, embeds, isQuery):
        """Map each row of a 2-D ``embeds`` array to its query- or corpus-side features.

        Each value is repeated ``m_use`` times along the last axis and multiplied
        elementwise by ``self.ws`` (so ``embeds.shape[-1] * m_use`` must broadcast
        against ``ws``), giving angles theta. The output is
        ``[fmap1 | fmap2 | fmap3 | fmap4]`` of width ``4 * embeds.shape[-1] * m_use``.
        The two sides are built so that, per frequency,
        query . corpus = (R_G * cos(theta_x - theta_q) + I_G * sin(theta_x - theta_q)) / pdf.
        """
        embeds_rep = np.repeat(embeds, m_use, axis=-1)
        thetas = embeds_rep * self.ws

        cos_theta_by_sqrt_pdf = np.cos(thetas) / self.sqrt_pdfs
        sin_theta_by_sqrt_pdf = np.sin(thetas) / self.sqrt_pdfs
        if isQuery:
            fmap1 = self.sign_RG * self.sqrt_abs_RG * cos_theta_by_sqrt_pdf
            fmap2 = self.sign_RG * self.sqrt_abs_RG * sin_theta_by_sqrt_pdf
            fmap3 = - self.sign_IG * self.sqrt_abs_IG * sin_theta_by_sqrt_pdf
            fmap4 = self.sign_IG * self.sqrt_abs_IG * cos_theta_by_sqrt_pdf
        else:
            fmap1 = self.sqrt_abs_RG * cos_theta_by_sqrt_pdf
            fmap2 = self.sqrt_abs_RG * sin_theta_by_sqrt_pdf
            fmap3 = self.sqrt_abs_IG * cos_theta_by_sqrt_pdf
            fmap4 = self.sqrt_abs_IG * sin_theta_by_sqrt_pdf

        return np.hstack([fmap1, fmap2, fmap3, fmap4])

    def gen_ghash_fmaps(self, embeds, isQuery):
        """Compute per-dimension feature maps for a batch of sets.

        ``embeds`` is (B, N, D). It is transposed to (B, D, N), each dimension's N
        values are sorted ascending, and ``generate_fmap`` is applied to every
        (item, dimension) row. Returns an array of shape (B, D, 4 * N * m_use).
        """
        transposed_input = np.moveaxis(embeds, -2, -1)
        sorted_input = np.sort(transposed_input, axis=-1)
        flattened_input = sorted_input.reshape(-1, sorted_input.shape[-1])
        output = self.generate_fmap(self.m_use, flattened_input, isQuery)
        return output.reshape(sorted_input.shape[0], sorted_input.shape[1], output.shape[-1])

    def fetch_fmaps_for_embeds(self, embeds, isQuery=False):
        """Return the feature maps used for hashing, according to ``rel_mode``.

        - "sub_iso": ``gen_ghash_fmaps(embeds)``.
        - "ged": ``[phi(embed), phi(-embed)] / normalizing_const`` (concatenated on axis 2).
        - "uneq_ged": ``[sqrt_del_cost * phi(embed), sqrt_add_cost * phi(-embed)] / normalizing_const``.

        Raises:
            NotImplementedError: for any other ``rel_mode``.
        """
        if self.rel_mode == "sub_iso":
            return self.gen_ghash_fmaps(embeds, isQuery)
        elif self.rel_mode == "ged":
            fmaps_del = self.gen_ghash_fmaps(embeds, isQuery)  # associated cost is 1
            fmaps_add = self.gen_ghash_fmaps(-embeds, isQuery)  # associated cost is 1
            fmaps = np.concatenate([fmaps_del, fmaps_add], axis=2)
            return fmaps / self.normalizing_const
        elif self.rel_mode == "uneq_ged":
            fmaps_del = self.gen_ghash_fmaps(embeds, isQuery)  # weighted by sqrt(del_cost)
            fmaps_add = self.gen_ghash_fmaps(-embeds, isQuery)  # weighted by sqrt(add_cost)
            fmaps = np.concatenate([self.sqrt_del_cost * fmaps_del, self.sqrt_add_cost * fmaps_add], axis=2)
            return fmaps / self.normalizing_const
        else:
            raise NotImplementedError("Invalid rel_mode")

    def fetch_RH_hashcodes(self, embeds, isQuery, qid=None, batch_size=5000):
        """Compute sign hashcodes for a batch of set embeddings.

        Args:
            embeds: (num_items, max_set_size, embed_dim) array.
            isQuery: use the query-side feature map if True, corpus-side otherwise.
            qid: unused; kept for interface compatibility.
            batch_size: items are processed in chunks of this many.

        Returns:
            (num_items, embed_dim, n_hyperplanes) array of ``np.sign`` of the
            projections (values in {-1, 0, 1}); ``n_hyperplanes`` is
            ``gauss_hplanes_cos.shape[-1]`` (presumably ``hcode_dim``).
        """
        num_items = embeds.shape[0]
        if num_items == 0:
            # np.vstack([]) would fail; return a correctly shaped empty result instead.
            return np.zeros((0, embeds.shape[-1], self.gauss_hplanes_cos.shape[-1]))

        hcode_list = []
        for start in range(0, num_items, batch_size):
            batch_item = embeds[start:start + batch_size]
            fmaps = self.fetch_fmaps_for_embeds(batch_item, isQuery)
            projections = fmaps.reshape(-1, fmaps.shape[-1]) @ self.gauss_hplanes_cos
            # back to per-item, per-dimension projections
            projections = projections.reshape(fmaps.shape[0], fmaps.shape[1], projections.shape[-1])
            hcode_list.append(np.sign(projections))

        hashcodes = np.vstack(hcode_list)
        assert hashcodes.shape[0] == num_items
        return hashcodes

    def index_corpus(self, corpus_embeds):
        """Hash every corpus item and build the per-dimension hash tables.

        ``corpus_embeds`` must be (num_corpus_items, max_set_size, embed_dim), matching
        the configured sizes. Sets ``corpus_embeds``, ``num_corpus_items``,
        ``corpus_hashcodes``, ``hashcode_mat`` and ``all_hash_tables``.
        """
        s = time.time()
        self.corpus_embeds = corpus_embeds
        self.num_corpus_items, max_set_size, embed_dim = corpus_embeds.shape
        # self.embed_dim / self.max_set_size come from conf; make sure they are consistent
        assert embed_dim == self.embed_dim
        assert max_set_size == self.max_set_size

        self.corpus_hashcodes = self.fetch_RH_hashcodes(self.corpus_embeds, isQuery=False)
        assert self.corpus_embeds.shape[0] == self.corpus_hashcodes.shape[0], (
            f"Size mismatch {self.corpus_embeds.shape[0]} {self.corpus_hashcodes.shape[0]}"
        )
        # +1/-1 hashcodes used for bucketing
        self.hashcode_mat = self.preprocess_hashcodes(self.corpus_hashcodes)

        # Assign corpus items to buckets in each of the tables
        self.bucketify()
        logger.info(f"Corpus indexed. Time taken {time.time()-s:.3f} sec")

    def preprocess_hashcodes(self, all_hashcodes):
        """Return ``np.sign(all_hashcodes)`` with any 0 entries replaced by 1.

        The input array is not modified.
        """
        signs = np.sign(all_hashcodes)
        zero_mask = signs == 0
        if zero_mask.any():  # edge case: a projection was exactly 0
            logger.info("Hashcode had 0 bits. replacing all with 1")
            signs[zero_mask] = 1
        return signs

    def assign_bucket(self, function_id, node_hash_code):
        """Return the bucket id of one hashcode in hash table ``function_id``.

        Picks the bits listed in ``self.hash_functions[function_id]`` from
        ``node_hash_code`` (expected to be a 1-D array of +/-1), maps negative
        entries to 0 and reads the result as a binary number, first selected bit
        most significant.
        """
        bits = np.take(node_hash_code, self.hash_functions[function_id])
        bits[bits < 0] = 0  # np.take returns a copy, so the caller's array is untouched
        return (self.powers_of_two @ bits).astype(self.powers_of_two.dtype)

    def bucketify(self):
        """Fill ``self.all_hash_tables``, one table per embedding dimension.

        Each table is a ``defaultdict(list)`` mapping bucket id -> corpus item ids
        (in ascending order), computed from ``self.hashcode_mat[item][dim_id]``.
        """
        self.all_hash_tables = []
        for dim_id in range(self.num_hash_tables):
            hash_table = defaultdict(list)
            for item in range(self.num_corpus_items):
                bucket_id = self.assign_bucket(dim_id, self.hashcode_mat[item][dim_id])
                hash_table[bucket_id].append(item)
            self.all_hash_tables.append(hash_table)

    def pretty_print_hash_tables(self, topk):
        """Print, per hash table, its ``topk`` largest bucket sizes joined by '|'.

        Useful to visualize how the corpus is distributed across buckets.
        """
        for table_id in range(self.num_hash_tables):
            bucket_sizes = sorted((len(v) for v in self.all_hash_tables[table_id].values()), reverse=True)
            print('|'.join(str(size) for size in bucket_sizes[:topk]))

    def heapify(self, q_embed, candidate_list, K):
        """
            use q_embed , candidate_list, corpus_embeds to fetch top K items
        """
        raise NotImplementedError("For now decided to do scoring/heapification separately")

    def _timestamps(self):
        """Read every clock named in ``timing_methods``; returns {method: reading}."""
        return {tm: self.timing_funcs[tm]() for tm in self.timing_methods}

    # INFO: minor difference from the function in lsh_base.py
    def retrieve(self, q_embed, K, no_bucket=False, return_candidate_list=False, qid=None):
        """Find candidate corpus items for one query and (via ``heapify``) the top K.

        Candidates are the union, over all hash tables, of the corpus items sharing
        the query's bucket in that table. The hash tables are not modified.

        Args:
            q_embed: a single query, shape (1, max_set_size, embed_dim).
            K: number of top items ``heapify`` should return.
            no_bucket: if True, skip hashing and treat every corpus item as a candidate.
            return_candidate_list: if True, return only the candidate corpus ids.
            qid: forwarded to ``fetch_RH_hashcodes``.

        Returns:
            The candidate id list if ``return_candidate_list``; otherwise
            ``(len(candidate_list), scores, corpus_ids, time_dict)`` where ``time_dict``
            comes from ``heapify`` and gains per-timing-method
            'candidate_list_gen_time' and 'hashcode_gen_time'. NOTE: ``heapify``
            currently raises NotImplementedError.

        Raises:
            ValueError: if ``q_embed`` holds more than one query.
        """
        # Query hashcode: identifies the query's bucket in every table
        if no_bucket:
            start_hashcode_gen = dict.fromkeys(self.timing_methods, 0)
            end_hashcode_gen = dict.fromkeys(self.timing_methods, 0)
        else:
            start_hashcode_gen = self._timestamps()
            q_hashcodes = self.preprocess_hashcodes(self.fetch_RH_hashcodes(q_embed, isQuery=True, qid=qid))
            if q_hashcodes.shape[0] != 1:
                raise ValueError(f"retrieve expects a single query, got {q_hashcodes.shape[0]}")
            # (embed_dim, hcode_dim); explicit [0] instead of squeeze() so that
            # embed_dim == 1 or hcode_dim == 1 keeps the expected rank.
            q_hashcode = q_hashcodes[0]
            end_hashcode_gen = self._timestamps()

        start_candidate_list_gen = self._timestamps()
        if no_bucket:
            # We consider all corpus items
            candidate_list = list(range(self.num_corpus_items))
        else:
            candidates = set()
            # NOTE: There is one hash table (and one query hashcode row) per dimension
            for dim_id in range(self.num_hash_tables):
                bucket_id = self.assign_bucket(dim_id, q_hashcode[dim_id])
                # .get: indexing the defaultdict would insert empty buckets into the index
                candidates.update(self.all_hash_tables[dim_id].get(bucket_id, ()))
            candidate_list = list(candidates)
        end_candidate_list_gen = self._timestamps()

        if not no_bucket and self.DEBUG:
            print("No. of candidates found", len(candidate_list))

        if return_candidate_list:
            return candidate_list

        scores, corpus_ids, time_dict = self.heapify(q_embed, candidate_list, K)

        for tm in self.timing_methods:
            time_dict[tm]['candidate_list_gen_time'] = end_candidate_list_gen[tm] - start_candidate_list_gen[tm]
            time_dict[tm]['hashcode_gen_time'] = end_hashcode_gen[tm] - start_hashcode_gen[tm]
        return len(candidate_list), scores, corpus_ids, time_dict