from lsh.ghash2 import Ghash2
from loguru import logger
import numpy as np
import os
import pickle


# Keys left out of each config section when building the pickle file names.
# These must stay in sync with the training scripts that write the pickles
# (train_fmaps.py / train_ghash_hashcode.py); otherwise the paths won't match.
_HASHING_CONF_EXCLUDE = frozenset({'device', 'embed_dim', 'subset_size', 'classPath', 'subset_type'})
_TRAINING_CONF_EXCLUDE = frozenset({'model_name', 'classPath', 'device', 'hidden_layers'})

# Task-name prefix for each supported dataset.rel_mode.
_REL_MODE_PREFIX = {"sub_iso": "", "ged": "G,", "uneq_ged": "UG,"}


def _rel_mode_prefix(rel_mode):
    """Return the task-name prefix for ``rel_mode``.

    Raises:
        ValueError: if ``rel_mode`` is not one of sub_iso, ged, uneq_ged.
    """
    if rel_mode not in _REL_MODE_PREFIX:
        raise ValueError(f" rel_mode {rel_mode} should be either sub_iso or ged or uneq_ged")
    return _REL_MODE_PREFIX[rel_mode]


def _conf_to_str(section, exclude):
    """Join ``f"{key}{value}"`` for every item of ``section`` not in ``exclude`` with commas."""
    return ",".join(f"{key}{value}" for key, value in section.items() if key not in exclude)


def _arch_str(hidden_layers):
    """Encode hidden layer sizes, e.g. ``[64, 32]`` -> ``"LRL_64_RL_32_"``."""
    return "L" + "".join(f"RL_{dim}_" for dim in hidden_layers)


def _training_suffix(training_conf):
    """Return ``"<training conf str>,<arch str>"`` for a training config section."""
    return _conf_to_str(training_conf, _TRAINING_CONF_EXCLUDE) + "," + _arch_str(training_conf.hidden_layers)


def _task_base(conf):
    """Return ``"<rel_mode prefix><dataset name>,<hashing conf str>"``, shared by all task names.

    Raises:
        ValueError: if ``conf.dataset.rel_mode`` is unsupported.
    """
    prefix = _rel_mode_prefix(conf.dataset.rel_mode)
    return prefix + conf.dataset.name + "," + _conf_to_str(conf.hashing, _HASHING_CONF_EXCLUDE)


class Ghash2_Trained(Ghash2):
    """Ghash2 variant that serves pre-trained fmaps and hashcodes loaded from pickle dumps.

    File names are derived from the config so they match what the training
    scripts wrote under ``<base_dir>allPklDumps/``.
    """

    def __init__(self, conf):
        super().__init__(conf)
        self.gauss_hplanes_cos = None  # set to None for safety
        self.init_hcodes()

    def init_fmaps(self):
        """Load trained fmaps into ``corpus_fmaps``, ``query_fmaps`` and ``query_fmaps_all``.

        Only the "test" split of the query fmaps is exposed as ``query_fmaps``.

        Raises:
            ValueError: if ``conf.dataset.rel_mode`` is unsupported.
            FileNotFoundError: if the fmap pickle does not exist.
        """
        # hashing.name is part of the file name. The fmap pickle is named with
        # "Ghash2" (presumably the name used at fmap-training time), so swap it in
        # temporarily and always restore it, even if building the name fails.
        original_name = self.conf.hashing.name
        self.conf.hashing.name = "Ghash2"
        try:
            # NOTE: naming must match __main__ in train_fmaps.py
            curr_task = _task_base(self.conf) + "," + _training_suffix(self.conf.fmap_training)
        finally:
            self.conf.hashing.name = original_name

        pathname = self.conf.base_dir + "allPklDumps/fmapPickles/" + curr_task + "_fmap_mat.pkl"
        if not os.path.exists(pathname):
            raise FileNotFoundError(f"Trained fmaps not found: {pathname}")

        # Pickle is only safe here because these are locally produced training artifacts.
        with open(pathname, "rb") as f:
            all_fmaps = pickle.load(f)

        self.corpus_fmaps = all_fmaps['corpus'].cpu().numpy()
        # Reference (not a copy) to every query split, values kept exactly as loaded.
        self.query_fmaps_all = all_fmaps['query']
        # IMP: We are using test set fmaps for query. no checks for now against this
        self.query_fmaps = all_fmaps['query']["test"].cpu().numpy()
        logger.info(f"Loaded fmaps from {pathname}. Initialized corpus_fmaps and  >>>> TEST <<<<<  query_fmaps")

    def init_hcodes(self):
        """Load trained hashcodes into ``corpus_hcodes``, ``query_hcodes`` and ``query_hcodes_all``.

        The full ``..._hashcode_mat.pkl`` name is tried first, then the shorter
        ``..._hcmat`` name. Only the "test" query split is exposed as ``query_hcodes``.

        Raises:
            ValueError: if ``conf.dataset.rel_mode`` is unsupported.
            FileNotFoundError: if neither hashcode pickle exists.
        """
        # NOTE: naming must match __main__ in train_ghash_hashcode.py. The task name
        # is the hashcode-training part followed by the fmap-training config and arch.
        curr_task = (
            _task_base(self.conf)
            + "," + _training_suffix(self.conf.hashcode_training)
            + "," + _training_suffix(self.conf.fmap_training)
        )

        pathname = self.conf.base_dir + "allPklDumps/hashcodePickles/" + curr_task + "_hashcode_mat.pkl"
        short_pname = self.conf.base_dir + "allPklDumps/hashcodePickles/" + curr_task + "_hcmat"

        if os.path.exists(pathname):
            load_path = pathname
        elif os.path.exists(short_pname):
            load_path = short_pname
        else:
            raise FileNotFoundError(f"Trained hashcodes not found: tried {pathname} and {short_pname}")

        # Pickle is only safe here because these are locally produced training artifacts.
        with open(load_path, "rb") as f:
            all_hcodes = pickle.load(f)

        self.corpus_hcodes = all_hcodes['corpus'].cpu().numpy()
        # Reference (not a copy) to every query split, values kept exactly as loaded.
        self.query_hcodes_all = all_hcodes['query']
        # IMP: We are using test set hcodes for query. no checks for now against this
        self.query_hcodes = all_hcodes['query']["test"].cpu().numpy()
        logger.info(f"Loaded hcodes from {load_path}. Initialized corpus_hcodes and  >>>> TEST <<<<<  query_hcodes")

    def fetch_RH_hashcodes(self, embeds, isQuery, qid=None):
        """Return pre-loaded hashcodes instead of computing them.

        Args:
            embeds: unused; kept for backward compatibility.
            isQuery: if true, return ``self.query_hcodes[qid]`` (test split);
                otherwise return the whole ``self.corpus_hcodes`` array.
            qid: index into ``self.query_hcodes``; only used when ``isQuery``.
        """
        if isQuery:
            return self.query_hcodes[qid]
        return self.corpus_hcodes
