from lsh.ghash2_raw import Ghash2_Raw
from loguru import logger
from lsh.lsh_utils import fetch_gaussian_hyperplanes
import numpy as np
import os
import pickle


class Ghash2(Ghash2_Raw):
    """Ghash2 variant whose random-hyperplane hashcodes are computed from
    pre-trained fmaps loaded from disk (see ``init_fmaps``), not from embeddings.

    NOTE(review): ``self.conf`` and ``self.hcode_dim`` are read here but are
    presumably set by ``Ghash2_Raw.__init__`` — confirm against the base class.
    """

    # Task-name prefix per relevance mode; must match train_fmaps.py.
    _REL_MODE_PREFIX = {"sub_iso": "", "ged": "G,", "uneq_ged": "UG,"}
    # Max number of fmap items projected at once in fetch_RH_hashcodes; only
    # bounds the size of the intermediate projection matrix.
    _RH_BATCH_SIZE = 5000

    def __init__(self, conf):
        super().__init__(conf)
        # Hyperplanes used by fetch_RH_hashcodes to sign-hash the fmaps.
        self.gauss_hplanes_cos = fetch_gaussian_hyperplanes(self.hcode_dim, self.conf.fmap_training.tr_fmap_dim)
        self.init_fmaps()

    def _fmap_pickle_path(self):
        """Return the path of the fmaps pickle for the current config.

        NOTE: the task-name construction below must stay identical to the
        ``__main__`` function in train_fmaps.py, otherwise the pickle will not
        be found.

        Raises:
            ValueError: if ``conf.dataset.rel_mode`` is not a known mode.
        """
        rel_mode = self.conf.dataset.rel_mode
        if rel_mode not in self._REL_MODE_PREFIX:
            raise ValueError(f" rel_mode {self.conf.dataset.rel_mode} should be either sub_iso or ged or uneq_ged")
        prefix = self._REL_MODE_PREFIX[rel_mode]

        arch_str = "L" + "".join(f"RL_{dim}_" for dim in self.conf.fmap_training.hidden_layers)
        # Config keys that are not part of the task name.
        hashing_excluded = {'device', 'embed_dim', 'subset_size', 'classPath', 'subset_type'}
        hashing_conf_str = ",".join(
            f"{key}{value}" for key, value in self.conf.hashing.items() if key not in hashing_excluded
        )
        fmap_training_excluded = {'model_name', 'classPath', 'device', 'hidden_layers'}
        fmap_training_conf_str = ",".join(
            f"{key}{value}" for key, value in self.conf.fmap_training.items() if key not in fmap_training_excluded
        )
        curr_task = prefix + self.conf.dataset.name + "," + hashing_conf_str + "," + fmap_training_conf_str + "," + arch_str
        # Plain concatenation (not os.path.join) to reproduce the exact path
        # that train_fmaps.py writes; base_dir is expected to end with "/".
        return self.conf.base_dir + "allPklDumps/fmapPickles/" + curr_task + "_fmap_mat.pkl"

    def init_fmaps(self):
        """Load the pre-trained fmaps pickle and cache corpus/query fmaps on ``self``.

        Sets:
            corpus_fmaps: ``all_fmaps['corpus']`` converted with ``.cpu().numpy()``.
            query_fmaps_all: the ``all_fmaps['query']`` mapping as loaded
                (values presumably torch tensors keyed by split name — verify).
            query_fmaps: ``all_fmaps['query']['test']`` converted with ``.cpu().numpy()``.

        Raises:
            FileNotFoundError: if the fmaps pickle for this config does not exist.
            ValueError: if ``conf.dataset.rel_mode`` is not a known mode.
        """
        pathname = self._fmap_pickle_path()
        # Explicit raise (not assert): survives `python -O` and reports the path.
        if not os.path.exists(pathname):
            raise FileNotFoundError(f"Trained fmaps pickle not found: {pathname}")

        # NOTE(review): pickle.load can execute arbitrary code; only load fmaps
        # produced by this project's own training script.
        with open(pathname, "rb") as fmap_file:
            all_fmaps = pickle.load(fmap_file)

        self.corpus_fmaps = all_fmaps['corpus'].cpu().numpy()
        # Keep the full per-split query fmaps (unconverted) for future use.
        self.query_fmaps_all = all_fmaps['query']
        # IMP: We are using test set fmaps for query. no checks for now against this
        self.query_fmaps = all_fmaps['query']["test"].cpu().numpy()
        logger.info(f"Loaded fmaps from {pathname}. Initialized corpus_fmaps and  >>>> TEST <<<<<  query_fmaps")

    def fetch_fmaps(self, isQuery, qid=None):
        """Return precomputed fmaps.

        Args:
            isQuery: if true, return the test-query fmap ``query_fmaps[qid]``
                with a new leading axis of size 1; otherwise return all
                corpus fmaps.
            qid: index into the test query fmaps; required when ``isQuery``.

        Raises:
            ValueError: if ``isQuery`` is true and ``qid`` is None.
        """
        if not isQuery:
            return self.corpus_fmaps
        if qid is None:
            raise ValueError("qid is required when isQuery is True")
        return self.query_fmaps[qid][None, :, :]

    def _sign_projections(self, fmaps):
        """Return ``sign(fmaps @ gauss_hplanes_cos)``, projecting the last axis
        and preserving all leading axes of ``fmaps``."""
        flattened = fmaps.reshape(-1, fmaps.shape[-1])
        projections = flattened @ self.gauss_hplanes_cos
        return np.sign(projections.reshape(fmaps.shape[:-1] + (projections.shape[-1],)))

    def fetch_RH_hashcodes(self, embeds, isQuery, qid=None):
        """Compute random-hyperplane hashcodes as signs of projected fmaps.

        The fmaps are precomputed (see ``fetch_fmaps``), so ``embeds`` does not
        enter the computation; it is kept for interface compatibility.

        Args:
            embeds: (num_items, max_set_size, embed_dim); unused here.
            isQuery: hash the single query fmap for ``qid`` if true, else all
                corpus fmaps.
            qid: query index, forwarded to ``fetch_fmaps``.

        Returns:
            Array of shape ``fmaps.shape[:-1] + (gauss_hplanes_cos.shape[-1],)``
            holding ``np.sign`` of the projections.
        """
        fmaps = self.fetch_fmaps(isQuery, qid)
        num_items = fmaps.shape[0]
        batch_sz = self._RH_BATCH_SIZE

        # Fetch once and slice per batch so every item is hashed exactly once.
        hcode_list = [
            self._sign_projections(fmaps[start:start + batch_sz])
            for start in range(0, num_items, batch_sz)
        ]
        if not hcode_list:
            # Zero items: return a correctly-shaped empty array; np.vstack
            # would raise on an empty list.
            return self._sign_projections(fmaps)
        return np.vstack(hcode_list)