import torch
torch.backends.cuda.matmul.allow_tf32 = False
import torch.nn as nn
from lsh.scoring import pairwise_ranking_loss_similarity
import time
from loguru import logger
from src.embeddings_loader import fetch_graph_corpus_embeddings, fetch_graph_query_embeddings, fetch_graph_ground_truths, fetch_all_info_for_scoring
from lsh.fhash import Fhash
import numpy as np
import tqdm
from lsh.scoring import dot_sim, hinge_sim, pairwise_cosine_sim 
import os
from utils.training_utils import EarlyStoppingModule
import random
import pickle
from sklearn.utils import shuffle
from omegaconf import OmegaConf
from lsh.ghash import Ghash
from lsh.ghash2 import Ghash2

import itertools

from utils.model_utils import  nanl_fast_inference_without_model


class GhashHashCodeTrainer_L(nn.Module):
    """Hashcode trainer with one independent linear layer per input slice.

    :meth:`forward` indexes its input as ``fmaps[:, d, :]`` for every
    ``d in range(conf.dataset.embed_dim)``.  Slice ``d`` is projected by its
    own ``Linear(conf.fmap_training.tr_fmap_dim, conf.hashing.hcode_dim)``
    layer (registered under the attribute name ``init_net_{d}``) and squashed
    with ``tanh(TANH_TEMP * x)``.

    :meth:`computeLoss` handles the ``"query_agnostic-L"`` and
    ``"query_aware-L"`` loss types only.
    """

    def __init__(self, conf):
        """
            :param conf  config object exposing the ``hashcode_training``,
                         ``hashing``, ``dataset`` and ``fmap_training`` sections
            :raises AssertionError  if the loss type has no ``-L`` suffix, hidden
                         layers are requested, or the hashing name is unsupported
        """
        super(GhashHashCodeTrainer_L, self).__init__()
        self.LOSS_TYPE = conf.hashcode_training.LOSS_TYPE
        self.FENCE_LAMBDA = conf.hashcode_training.FENCE
        self.DECORR_LAMBDA = conf.hashcode_training.DECORR
        self.C1_LAMBDA = conf.hashcode_training.C1
        self.QA_MARGIN = conf.hashcode_training.QA_MARGIN
        self.hashing_name = conf.hashing.name
        self.num_dim = conf.dataset.embed_dim

        # The assertion messages are plain strings: the previous
        # ``assert cond, print(...)`` form evaluated to ``None`` as the message.
        # TODO: sad hack. will fix later
        assert "-L" in conf.hashcode_training.LOSS_TYPE, \
            f"LOSS_TYPE {conf.hashcode_training.LOSS_TYPE} must contain '-L' for this trainer."
        # Supports only one layer for now
        assert conf.hashcode_training.hidden_layers == [], \
            "GhashHashCodeTrainer_L supports only a single linear layer (hidden_layers must be [])."
        assert self.hashing_name in ["Ghash_Trained", "Ghash2_Trained"], \
            f"You should not be here with {self.hashing_name}. This is for trained hashcode variants."

        in_features = conf.fmap_training.tr_fmap_dim
        out_features = conf.hashing.hcode_dim

        # One linear layer per slice. The ``init_net_{d}`` attribute names are
        # part of the state_dict keys, so they must not change.
        for d in range(self.num_dim):
            setattr(self, f"init_net_{d}", torch.nn.Linear(in_features, out_features))

        self.tanh = nn.Tanh()
        self.TANH_TEMP = conf.hashcode_training.TANH_TEMP

    @staticmethod
    def _bit_balance_loss(hcodes):
        """Sum of absolute per-column sums of a 2-D code matrix, divided by its entry count."""
        return torch.sum(torch.abs(torch.sum(hcodes, dim=0))) / (hcodes.shape[0] * hcodes.shape[1])

    @staticmethod
    def _fence_sitting_loss(hcodes):
        """L1 norm of ``|hcodes| - 1``, divided by the entry count of the 2-D code matrix."""
        return torch.norm(hcodes.abs() - 1, p=1) / (hcodes.shape[0] * hcodes.shape[1])

    @staticmethod
    def _decorrelation_loss(hcodes):
        """Absolute value of the mean of ``hcodes.T @ hcodes`` after zeroing its diagonal."""
        return torch.abs(torch.mean((hcodes.T @ hcodes).fill_diagonal_(0)))

    def forward(self, fmaps):
        """
            :param  Fmaps   tensor whose second-to-last axis has ``num_dim`` entries;
                            it is sliced as ``fmaps[:, d, :]``
            :return  Hcodes per-slice layer outputs stacked on axis 1, passed
                            through ``tanh(TANH_TEMP * x)``
        """
        assert self.num_dim == fmaps.shape[-2]
        code_list = [
            getattr(self, f"init_net_{d}")(fmaps[:, d, :])
            for d in range(self.num_dim)
        ]
        # Stack on axis 1 so that the slice axis of the output lines up with
        # the slice axis of the input (stacking on the last axis was a bug).
        code = torch.stack(code_list, dim=1)
        return self.tanh(self.TANH_TEMP * code)

    def computeLoss(self, cfmaps, qfmaps, targets):
        """Compute the training loss selected by ``self.LOSS_TYPE``.

            :param cfmaps   corpus feature maps; for ``"query_agnostic-L"`` this is
                            the only input used and holds query and corpus
                            representations together
            :param qfmaps   query feature maps (``"query_aware-L"`` only)
            :param targets  per-pair labels (``"query_aware-L"`` only); entries
                            ``> 0.5`` select positive pairs, ``< 0.5`` negatives
            :return  tuple ``(loss, bit_balance_loss, third, fence_sitting_loss)``,
                     each averaged over the slices on axis 1 of the hashcodes;
                     ``third`` is the decorrelation loss for the query-agnostic
                     variant and the ranking loss for the query-aware variant
            :raises AssertionError  for any other loss type
        """
        if self.LOSS_TYPE == "query_agnostic-L":
            # Note that here all query and corpus embeddings are sent in one chunk.
            # Naming the argument cfmaps is slightly misleading -- it contains both
            # query and corpus representations. This was used in PermGNN.
            all_hcodes = self.forward(cfmaps)
            assert self.num_dim == all_hcodes.shape[1], \
                f"all_hcodes shape: {all_hcodes.shape} should be Batch_size x num_dim x hcode_dim"
            all_dim_loss = []
            all_dim_bit_balance_loss = []
            all_dim_decorrelation_loss = []
            all_dim_fence_sitting_loss = []
            for d in range(all_hcodes.shape[1]):
                dim_hcodes = all_hcodes[:, d, :]
                bit_balance_loss = self._bit_balance_loss(dim_hcodes)
                decorrelation_loss = self._decorrelation_loss(dim_hcodes)
                fence_sitting_loss = self._fence_sitting_loss(dim_hcodes)
                # Bit balance receives whatever weight the other two terms leave over.
                loss = self.FENCE_LAMBDA * fence_sitting_loss +\
                        self.DECORR_LAMBDA * decorrelation_loss +\
                        (1 - self.FENCE_LAMBDA - self.DECORR_LAMBDA) * bit_balance_loss
                all_dim_loss.append(loss)
                all_dim_bit_balance_loss.append(bit_balance_loss)
                all_dim_decorrelation_loss.append(decorrelation_loss)
                all_dim_fence_sitting_loss.append(fence_sitting_loss)
            return (
                torch.stack(all_dim_loss).mean(),
                torch.stack(all_dim_bit_balance_loss).mean(),
                torch.stack(all_dim_decorrelation_loss).mean(),
                torch.stack(all_dim_fence_sitting_loss).mean(),
            )

        elif self.LOSS_TYPE == "query_aware-L":
            q_hcodes = self.forward(qfmaps)
            c_hcodes = self.forward(cfmaps)
            all_hcodes = torch.cat([q_hcodes, c_hcodes])
            # Per-slice similarity of each (query, corpus) pair: dot product over the code axis.
            preds = (q_hcodes * c_hcodes).sum(-1)

            all_dim_loss = []
            all_dim_bit_balance_loss = []
            all_dim_ranking_loss = []
            all_dim_fence_sitting_loss = []

            for d in range(all_hcodes.shape[1]):
                dim_all_hcodes = all_hcodes[:, d, :]
                dim_preds = preds[:, d]

                fence_sitting_loss = self._fence_sitting_loss(dim_all_hcodes)
                bit_balance_loss = self._bit_balance_loss(dim_all_hcodes)

                predPos = dim_preds[targets > 0.5]
                predNeg = dim_preds[targets < 0.5]

                ranking_loss = pairwise_ranking_loss_similarity(predPos.unsqueeze(1), predNeg.unsqueeze(1), self.QA_MARGIN)

                loss = self.FENCE_LAMBDA * fence_sitting_loss +\
                    self.C1_LAMBDA * ranking_loss +\
                    (1 - self.FENCE_LAMBDA - self.C1_LAMBDA) * bit_balance_loss

                all_dim_loss.append(loss)
                all_dim_bit_balance_loss.append(bit_balance_loss)
                all_dim_ranking_loss.append(ranking_loss)
                all_dim_fence_sitting_loss.append(fence_sitting_loss)
            return (
                torch.stack(all_dim_loss).mean(),
                torch.stack(all_dim_bit_balance_loss).mean(),
                torch.stack(all_dim_ranking_loss).mean(),
                torch.stack(all_dim_fence_sitting_loss).mean(),
            )

        else:
            # Raised explicitly (not ``assert False``) so the failure survives
            # ``python -O`` instead of silently returning None.
            raise AssertionError(f"Unknown loss type {self.LOSS_TYPE}")




class GhashHashCodeTrainer(nn.Module):
    """Hashcode trainer that applies one shared MLP to every input slice.

    The network is a ``Sequential`` of ``Linear`` layers separated by ``ReLU``
    (no activation after the last layer) with widths
    ``[tr_fmap_dim] + hidden_layers + [hcode_dim]``; its output is squashed
    with ``tanh(TANH_TEMP * x)``.

    :meth:`computeLoss` handles the ``"query_agnostic"`` and ``"query_aware"``
    loss types only.
    """

    def __init__(self, conf):
        """
            :param conf  config object exposing the ``hashcode_training``,
                         ``hashing`` and ``fmap_training`` sections
            :raises AssertionError  if ``conf.hashing.name`` is not a trained
                         Ghash variant
        """
        super(GhashHashCodeTrainer, self).__init__()
        self.LOSS_TYPE = conf.hashcode_training.LOSS_TYPE
        self.FENCE_LAMBDA = conf.hashcode_training.FENCE
        self.DECORR_LAMBDA = conf.hashcode_training.DECORR
        self.C1_LAMBDA = conf.hashcode_training.C1
        self.QA_MARGIN = conf.hashcode_training.QA_MARGIN
        self.hashing_name = conf.hashing.name

        if self.hashing_name not in ["Ghash_Trained", "Ghash2_Trained"]:
            # Raised explicitly with a real message: the previous
            # ``assert False, print(...)`` carried ``None`` as its message and
            # vanished under ``python -O``.
            raise AssertionError(
                f"You should not be here with {self.hashing_name}. This is for trained hashcode variants."
            )
        self.inner_hs = [conf.fmap_training.tr_fmap_dim] + conf.hashcode_training.hidden_layers + [conf.hashing.hcode_dim]

        # Collect layers in a local list; only the final Sequential is
        # registered on the module (state_dict keys stay ``init_net.<i>.*``).
        layers = []
        for h0, h1 in zip(self.inner_hs, self.inner_hs[1:]):
            layers.append(torch.nn.Linear(h0, h1))
            layers.append(torch.nn.ReLU())
        layers.pop()  # no ReLU after the last linear layer
        self.init_net = torch.nn.Sequential(*layers)
        self.tanh = nn.Tanh()
        self.TANH_TEMP = conf.hashcode_training.TANH_TEMP

    @staticmethod
    def _bit_balance_loss(hcodes):
        """Sum of absolute per-column sums of a 2-D code matrix, divided by its entry count."""
        return torch.sum(torch.abs(torch.sum(hcodes, dim=0))) / (hcodes.shape[0] * hcodes.shape[1])

    @staticmethod
    def _fence_sitting_loss(hcodes):
        """L1 norm of ``|hcodes| - 1``, divided by the entry count of the 2-D code matrix."""
        return torch.norm(hcodes.abs() - 1, p=1) / (hcodes.shape[0] * hcodes.shape[1])

    @staticmethod
    def _decorrelation_loss(hcodes):
        """Absolute value of the mean of ``hcodes.T @ hcodes`` after zeroing its diagonal."""
        return torch.abs(torch.mean((hcodes.T @ hcodes).fill_diagonal_(0)))

    def forward(self, fmaps):
        """
            :param  Fmaps   tensor whose last axis has ``tr_fmap_dim`` entries
            :return  Hcodes ``tanh(TANH_TEMP * init_net(fmaps))``
        """
        code = self.init_net(fmaps)
        return self.tanh(self.TANH_TEMP * code)

    def computeLoss(self, cfmaps, qfmaps, targets):
        """Compute the training loss selected by ``self.LOSS_TYPE``.

        The hashcodes are indexed as ``hcodes[:, d, :]``, so the inputs must
        produce 3-D hashcodes.

            :param cfmaps   corpus feature maps; for ``"query_agnostic"`` this is
                            the only input used and holds query and corpus
                            representations together
            :param qfmaps   query feature maps (``"query_aware"`` only)
            :param targets  per-pair labels (``"query_aware"`` only); entries
                            ``> 0.5`` select positive pairs, ``< 0.5`` negatives
            :return  tuple ``(loss, bit_balance_loss, third, fence_sitting_loss)``,
                     each averaged over axis 1 of the hashcodes; ``third`` is the
                     decorrelation loss for the query-agnostic variant and the
                     ranking loss for the query-aware variant
            :raises AssertionError  for any other loss type
        """
        if self.LOSS_TYPE == "query_agnostic":
            # Note that here all query and corpus embeddings are sent in one chunk.
            # Naming the argument cfmaps is slightly misleading -- it contains both
            # query and corpus representations. This was used in PermGNN.
            all_hcodes = self.forward(cfmaps)
            all_dim_loss = []
            all_dim_bit_balance_loss = []
            all_dim_decorrelation_loss = []
            all_dim_fence_sitting_loss = []
            for d in range(all_hcodes.shape[1]):
                dim_hcodes = all_hcodes[:, d, :]
                bit_balance_loss = self._bit_balance_loss(dim_hcodes)
                decorrelation_loss = self._decorrelation_loss(dim_hcodes)
                fence_sitting_loss = self._fence_sitting_loss(dim_hcodes)
                # Bit balance receives whatever weight the other two terms leave over.
                loss = self.FENCE_LAMBDA * fence_sitting_loss +\
                        self.DECORR_LAMBDA * decorrelation_loss +\
                        (1 - self.FENCE_LAMBDA - self.DECORR_LAMBDA) * bit_balance_loss
                all_dim_loss.append(loss)
                all_dim_bit_balance_loss.append(bit_balance_loss)
                all_dim_decorrelation_loss.append(decorrelation_loss)
                all_dim_fence_sitting_loss.append(fence_sitting_loss)
            return (
                torch.stack(all_dim_loss).mean(),
                torch.stack(all_dim_bit_balance_loss).mean(),
                torch.stack(all_dim_decorrelation_loss).mean(),
                torch.stack(all_dim_fence_sitting_loss).mean(),
            )

        elif self.LOSS_TYPE == "query_aware":
            q_hcodes = self.forward(qfmaps)
            c_hcodes = self.forward(cfmaps)
            all_hcodes = torch.cat([q_hcodes, c_hcodes])
            # Per-slice similarity of each (query, corpus) pair: dot product over the code axis.
            preds = (q_hcodes * c_hcodes).sum(-1)

            all_dim_loss = []
            all_dim_bit_balance_loss = []
            all_dim_ranking_loss = []
            all_dim_fence_sitting_loss = []

            for d in range(all_hcodes.shape[1]):
                dim_all_hcodes = all_hcodes[:, d, :]
                dim_preds = preds[:, d]

                fence_sitting_loss = self._fence_sitting_loss(dim_all_hcodes)
                bit_balance_loss = self._bit_balance_loss(dim_all_hcodes)

                predPos = dim_preds[targets > 0.5]
                predNeg = dim_preds[targets < 0.5]

                ranking_loss = pairwise_ranking_loss_similarity(predPos.unsqueeze(1), predNeg.unsqueeze(1), self.QA_MARGIN)

                loss = self.FENCE_LAMBDA * fence_sitting_loss +\
                    self.C1_LAMBDA * ranking_loss +\
                    (1 - self.FENCE_LAMBDA - self.C1_LAMBDA) * bit_balance_loss

                all_dim_loss.append(loss)
                all_dim_bit_balance_loss.append(bit_balance_loss)
                all_dim_ranking_loss.append(ranking_loss)
                all_dim_fence_sitting_loss.append(fence_sitting_loss)
            return (
                torch.stack(all_dim_loss).mean(),
                torch.stack(all_dim_bit_balance_loss).mean(),
                torch.stack(all_dim_ranking_loss).mean(),
                torch.stack(all_dim_fence_sitting_loss).mean(),
            )

        else:
            # Raised explicitly (not ``assert False``) so the failure survives
            # ``python -O`` instead of silently returning None.
            raise AssertionError(f"Unknown loss type {self.LOSS_TYPE}")
            
            
class GhashHashcodeDataLoader(object):
    """Loads embeddings, feature maps and ground truth, and serves batches.

    Two batching schemes are offered:

    * ``create_fmap_batches`` / ``fetch_fmap_batched_data_by_id`` -- query and
      corpus feature maps concatenated and split into chunks of ``BATCH_SIZE``
      rows (used by the query-agnostic losses);
    * ``create_per_query_batches`` / ``fetch_batched_data_by_id_optimized`` --
      one batch per query, pairing that query with every corpus item (used by
      the query-aware losses).
    """

    # Hashing names for which feature maps are loaded.
    _TRAINED_HASHINGS = ("Ghash_Trained", "Ghash2_Trained")

    def __init__(self, conf):
        """Fetch all data described by ``conf`` and precompute per-query batches.

        ``conf.hashing.name`` and ``conf.hashing.FUNC`` are overwritten while
        the underlying LSH object is built and restored afterwards.

        :raises NotImplementedError: if ``conf.model.name`` is not ``"NANL"``, or
            a query-aware loss is requested for an unsupported hashing name.
        """
        self.device = conf.hashcode_training.device
        self.LOSS_TYPE = conf.hashcode_training.LOSS_TYPE
        self.BATCH_SIZE = 1024  # hardcoded for now

        if conf.model.name == "NANL":
            self.scoring_function = nanl_fast_inference_without_model
        else:
            raise NotImplementedError
        self.hashing_name = conf.hashing.name

        corpus_embeds_fetch_start = time.time()
        self.corpus_embeds = self._to_device_tensor(fetch_graph_corpus_embeddings(conf))
        corpus_embeds_fetch_time = time.time() - corpus_embeds_fetch_start
        logger.info(f"Corpus embeds shape: {self.corpus_embeds.shape}, time={corpus_embeds_fetch_time}")

        # The LSH classes are resolved here (not in a class-level table) so
        # the lookup happens only when the matching variant is requested.
        if self.hashing_name == "Ghash_Trained":
            self._init_lsh(conf, "Ghash", Ghash)
        elif self.hashing_name == "Ghash2_Trained":
            self._init_lsh(conf, "Ghash2", Ghash2)

        uses_fmaps = self.hashing_name in self._TRAINED_HASHINGS

        self.query_embeds = {}
        if uses_fmaps:
            self.query_fmaps = {}
        self.ground_truth = {}
        self.list_total_arranged_per_query = {}

        for mode in ["train", "val", "test"]:
            self.query_embeds[mode] = self._to_device_tensor(fetch_graph_query_embeddings(conf, mode))
            if uses_fmaps:
                self.query_fmaps[mode] = self._to_device_tensor(self.lsh.query_fmaps_all[mode])

            if self.LOSS_TYPE in ["query_aware", "query_aware-L"]:
                if not uses_fmaps:
                    raise NotImplementedError
                self.ground_truth[mode] = self._score_based_ground_truth(conf, mode)
            else:
                self.ground_truth[mode] = fetch_graph_ground_truths(conf, mode)

            num_queries = self.query_embeds[mode].shape[0]
            num_corpus = self.corpus_embeds.shape[0]

            # Dense 0/1 relevance matrix: row q has ones at the ground-truth
            # corpus ids of query q.
            gt_indicator = torch.zeros((num_queries, num_corpus), device=self.device)
            for q in range(num_queries):
                gt_indicator[q][self.ground_truth[mode][q]] = 1.0

            # Flat list of ((query_id, corpus_id), label) in row-major order,
            # i.e. all corpus items of query 0 first, then query 1, ...
            self.list_total_arranged_per_query[mode] = list(zip(
                itertools.product(range(num_queries), range(num_corpus)),
                gt_indicator.flatten().tolist(),
            ))

        logger.info('Query embeds fetched and fmaps generated.')
        logger.info('Ground truth fetched.')
        self.preprocess_create_per_query_batches()

    def _to_device_tensor(self, data):
        """Convert a numpy array to a tensor on ``self.device``; pass anything else through."""
        if isinstance(data, np.ndarray):
            return torch.from_numpy(data).to(self.device)
        return data

    def _init_lsh(self, conf, base_name, lsh_cls):
        """Build ``self.lsh`` with the untrained hashing name and load corpus feature maps.

        ``conf.hashing.name`` is set to ``base_name`` and ``conf.hashing.FUNC``
        to ``"None"`` for the duration of the constructor call; both are
        restored even if the constructor raises.
        """
        saved_func = conf.hashing.FUNC
        conf.hashing.name = base_name
        conf.hashing.FUNC = "None"  # Weird: "" was treated as None in filename
        try:
            self.lsh = lsh_cls(conf)
        finally:
            conf.hashing.FUNC = saved_func
            conf.hashing.name = self.hashing_name

        self.corpus_fmaps = self._to_device_tensor(self.lsh.corpus_fmaps)

    def _score_based_ground_truth(self, conf, mode):
        """Derive per-query positive corpus ids from ``self.scoring_function``.

        Each query keeps ``int(num_corpus / 2**QA_subset_size)`` ids. The
        scoring inputs are moved with ``.cuda()`` regardless of ``self.device``.

        :return: dict mapping query index to a list of corpus ids.
        """
        num_pos = int(self.corpus_embeds.shape[0] / (2 ** conf.hashcode_training.QA_subset_size))
        aux_info = fetch_all_info_for_scoring(conf, mode, conf.model.name, "", conf.dataset.rel_mode)
        cuda_query_embeds = self.query_embeds[mode].cuda()
        cuda_corpus_embeds = self.corpus_embeds.cuda()

        gt = {}
        for qidx in range(len(self.query_embeds[mode])):
            sc = self.scoring_function(
                cuda_query_embeds[qidx].unsqueeze(0),
                cuda_corpus_embeds,
                aux_info['query_aux_info'][f"{mode}_masked_features_query"][qidx].unsqueeze(0),
                aux_info['corpus_aux_info']['masked_features_corpus'],
                aux_info['corpus_aux_info']['sinkhorn_temp'],
                conf.dataset.rel_mode,
            ).cpu().numpy().squeeze()
            # NOTE(review): ``sc`` is computed for the single query ``qidx`` and
            # then squeezed, yet it is indexed by ``qidx`` again below. If the
            # scoring helper returns one score per corpus item, ``sc[qidx]`` is
            # a scalar and the argsort cannot rank the corpus -- confirm
            # against the helper's return shape before relying on this.
            gt[qidx] = np.argsort(sc[qidx])[::-1][:num_pos].tolist()
        return gt

    def create_fmap_batches(self, mode):
        """Split the concatenated query+corpus feature maps into ``BATCH_SIZE`` chunks.

        Rows are permuted first when ``mode == "train"``.

        :return: number of batches created (also stored in ``self.num_batches``).
        :raises NotImplementedError: for hashing names without feature maps.
        """
        if self.hashing_name not in self._TRAINED_HASHINGS:
            raise NotImplementedError(f"No feature maps available for hashing '{self.hashing_name}'")

        all_fmaps = torch.cat([self.query_fmaps[mode], self.corpus_fmaps])
        if mode == "train":
            all_fmaps = all_fmaps[torch.randperm(all_fmaps.shape[0])]

        self.batches = list(all_fmaps.split(self.BATCH_SIZE))
        self.num_batches = len(self.batches)
        return self.num_batches

    def fetch_fmap_batched_data_by_id(self, i):
        """Return the ``i``-th batch prepared by :meth:`create_fmap_batches`."""
        assert i < self.num_batches, f"batch index {i} out of range ({self.num_batches} batches)"
        return self.batches[i]

    def preprocess_create_per_query_batches(self):
        """Precompute, for every mode, one (query ids, corpus ids, labels) batch per query.

        Populates ``self.per_query_batches[mode]`` with the keys ``'alists'``
        (query index lists), ``'blists'`` (corpus index lists) and ``'scores'``
        (label tensors placed on ``self.device`` so that they live on the same
        device as the feature maps and the model).
        """
        split_len = self.corpus_embeds.shape[0]
        print("In preprocess_create_per_query_batches")
        self.per_query_batches = {}
        for mode in ["train", "val", "test"]:
            whole_list = self.list_total_arranged_per_query[mode]
            alists = []
            blists = []
            scores = []

            # Every consecutive run of ``split_len`` entries belongs to one query.
            for start in range(0, len(whole_list), split_len):
                chunk = whole_list[start:start + split_len]
                alists.append([pair[0] for pair, _ in chunk])
                blists.append([pair[1] for pair, _ in chunk])
                scores.append(torch.tensor([label for _, label in chunk]).to(self.device))

            self.per_query_batches[mode] = {
                'alists': alists,
                'blists': blists,
                'scores': scores,
            }

    def fetch_batched_data_by_id_optimized(self, i):
        """Return ``(corpus_fmaps, query_fmaps, target)`` for the ``i``-th per-query batch.

        Requires a prior call to :meth:`create_per_query_batches`.

        :raises NotImplementedError: for hashing names without feature maps.
        """
        assert i < self.num_batches, f"batch index {i} out of range ({self.num_batches} batches)"
        alist = self.alists[i]
        blist = self.blists[i]
        score = self.scores[i]
        if self.hashing_name not in self._TRAINED_HASHINGS:
            raise NotImplementedError(f"No feature maps available for hashing '{self.hashing_name}'")

        query_tensors = self.query_fmaps[self.mode][alist]
        corpus_tensors = self.corpus_fmaps[blist]

        target = score
        return corpus_tensors, query_tensors, target

    def create_per_query_batches(self, mode):
        """Select the precomputed per-query batches of ``mode`` as the active set.

        The batch order is shuffled when ``mode == "train"``.

        :return: number of batches (also stored in ``self.num_batches``).
        """
        self.alists = self.per_query_batches[mode]['alists']
        self.blists = self.per_query_batches[mode]['blists']
        self.scores = self.per_query_batches[mode]['scores']

        if mode == "train":
            self.alists, self.blists, self.scores = shuffle(self.alists, self.blists, self.scores)

        self.num_batches = len(self.alists)
        self.mode = mode

        return self.num_batches

def evaluate_validation_query_aware(model, sampler, mode):
    """Accumulate the query-aware losses of ``model`` over all batches of ``mode``.

    The model is switched to eval mode and the batches are evaluated under
    ``torch.no_grad()``: only ``.item()`` values are kept, so building
    autograd graphs here would just waste memory.

    :param model: object exposing ``eval()`` and
        ``computeLoss(corpus, query, target)`` returning four scalar tensors.
    :param sampler: object exposing ``create_per_query_batches(mode)`` and
        ``fetch_batched_data_by_id_optimized(i)``.
    :param mode: split name forwarded to the sampler.
    :return: tuple ``(total_loss, total_bit_balance_loss, total_ranking_loss,
        total_fence_sitting_loss)``, each summed (not averaged) over batches.
    """
    model.eval()

    total_loss = 0
    total_bit_balance_loss = 0
    total_ranking_loss = 0
    total_fence_sitting_loss = 0
    n_batches = sampler.create_per_query_batches(mode)
    with torch.no_grad():
        for i in tqdm.tqdm(range(n_batches)):
            batch_corpus_tensors, batch_query_tensors, batch_target = sampler.fetch_batched_data_by_id_optimized(i)
            loss, bit_balance_loss, ranking_loss, fence_sitting_loss = model.computeLoss(
                batch_corpus_tensors, batch_query_tensors, batch_target
            )
            total_loss += loss.item()
            total_bit_balance_loss += bit_balance_loss.item()
            total_ranking_loss += ranking_loss.item()
            total_fence_sitting_loss += fence_sitting_loss.item()

    return total_loss, total_bit_balance_loss, total_ranking_loss, total_fence_sitting_loss


def evaluate_validation_query_agnostic(model, sampler, mode):
    """Accumulate the query-agnostic losses of ``model`` over all batches of ``mode``.

    The model is switched to eval mode and the batches are evaluated under
    ``torch.no_grad()``: only ``.item()`` values are kept, so building
    autograd graphs here would just waste memory.

    :param model: object exposing ``eval()`` and
        ``computeLoss(fmaps, None, None)`` returning four scalar tensors.
    :param sampler: object exposing ``create_fmap_batches(mode)`` and
        ``fetch_fmap_batched_data_by_id(i)``.
    :param mode: split name forwarded to the sampler.
    :return: tuple ``(total_loss, total_bit_balance_loss,
        total_decorrelation_loss, total_fence_sitting_loss)``, each summed
        (not averaged) over batches.
    """
    model.eval()

    total_loss = 0
    total_bit_balance_loss = 0
    total_decorrelation_loss = 0
    total_fence_sitting_loss = 0
    n_batches = sampler.create_fmap_batches(mode)
    with torch.no_grad():
        for i in tqdm.tqdm(range(n_batches)):
            batch_tensors = sampler.fetch_fmap_batched_data_by_id(i)
            loss, bit_balance_loss, decorrelation_loss, fence_sitting_loss = model.computeLoss(batch_tensors, None, None)
            total_loss += loss.item()
            total_bit_balance_loss += bit_balance_loss.item()
            total_decorrelation_loss += decorrelation_loss.item()
            total_fence_sitting_loss += fence_sitting_loss.item()

    return total_loss, total_bit_balance_loss, total_decorrelation_loss, total_fence_sitting_loss


def run_hashcode_gen(conf, curr_task, pickle_fp=None, shortened_filepath=None):
    """Load the best checkpoint for ``curr_task`` and dump corpus/query hashcodes.

    The training loop is currently disabled (kept below as comments); this
    function only restores the best model saved by the early-stopping module,
    runs it over all corpus and query feature maps, and pickles the result as
    ``{'query': {mode: tensor}, 'corpus': tensor}``.

    :param conf: merged configuration object.
    :param curr_task: task name, used to locate the checkpoint and to build the
        default output paths.
    :param pickle_fp: output path; defaults to
        ``conf.base_dir + "allPklDumps/hashcodePickles/" + curr_task + "_hashcode_mat.pkl"``.
    :param shortened_filepath: fallback output path used when writing to
        ``pickle_fp`` raises ``OSError``; defaults to the same directory with
        the ``"_hcmat"`` suffix.
    """
    # These paths used to be read from module globals that only exist when the
    # file runs as a script; deriving them here keeps the function importable.
    if pickle_fp is None:
        pickle_fp = conf.base_dir + "allPklDumps/hashcodePickles/" + curr_task + "_hashcode_mat.pkl"
    if shortened_filepath is None:
        shortened_filepath = conf.base_dir + "allPklDumps/hashcodePickles/" + curr_task + "_hcmat"

    device = conf.hashcode_training.device
    train_data = GhashHashcodeDataLoader(conf)
    if "-L" in conf.hashcode_training.LOSS_TYPE:
        model = GhashHashCodeTrainer_L(conf).to(device)
    else:
        model = GhashHashCodeTrainer(conf).to(device)

    cnt = sum(torch.numel(param) for param in model.parameters())
    logger.info(f"no. of params in model: {cnt}")

    es = EarlyStoppingModule(conf.base_dir, curr_task, patience=conf.training.patience, logger=logger)

    # ------------------------------------------------------------------
    # Disabled training loop, kept for reference.
    # NOTE(review): in the query-aware branch below the unpacking order is
    # (loss, bit_balance, fence_sitting, ranking) while computeLoss returns
    # (loss, bit_balance, ranking, fence_sitting) -- fix before re-enabling.
    # ------------------------------------------------------------------
    # optimizer = torch.optim.Adam(model.parameters(),
    #                             lr=conf.training.learning_rate,
    #                             weight_decay=conf.training.weight_decay)

    # best_neg_val_loss = 0  # weird naming but more straightforward to understand
    # run = 0
    # while conf.training.run_till_early_stopping and run < conf.training.num_epochs:
    #     start_time = time.time()
    #     if conf.hashcode_training.LOSS_TYPE in ["query_agnostic", "query_agnostic-L"]:
    #         n_batches = train_data.create_fmap_batches(mode="train")
    #     elif conf.hashcode_training.LOSS_TYPE in ["query_aware", "query_aware-L"]:
    #         n_batches = train_data.create_per_query_batches(mode="train")
    #     else:
    #         assert False, print(f"Unknown hashcode training loss type {conf.hashcode_training.LOSS_TYPE}")

    #     epoch_loss =0
    #     epoch_bit_balance_loss = 0
    #     epoch_decorrelation_loss = 0
    #     epoch_fence_sitting_loss = 0
    #     epoch_ranking_loss = 0


    #     for i in tqdm.tqdm(range(n_batches)):
    #         optimizer.zero_grad()
    #         if conf.hashcode_training.LOSS_TYPE in ["query_agnostic", "query_agnostic-L"]:
    #             batch_tensors = train_data.fetch_fmap_batched_data_by_id(i)

    #             loss,bit_balance_loss,decorrelation_loss, fence_sitting_loss  = model.computeLoss(batch_tensors, None,None)
    #             epoch_bit_balance_loss += bit_balance_loss.item()
    #             epoch_decorrelation_loss += decorrelation_loss.item()
    #             epoch_fence_sitting_loss += fence_sitting_loss.item()
    #         if conf.hashcode_training.LOSS_TYPE in ["query_aware", "query_aware-L"]:
    #             batch_corpus_tensors, batch_query_tensors, batch_target = train_data.fetch_batched_data_by_id_optimized(i)

    #             loss,bit_balance_loss, fence_sitting_loss, ranking_loss  = model.computeLoss(batch_corpus_tensors, batch_query_tensors, batch_target)
    #             epoch_bit_balance_loss += bit_balance_loss.item()
    #             epoch_fence_sitting_loss += fence_sitting_loss.item()
    #             epoch_ranking_loss += ranking_loss.item()

    #         loss.backward()
    #         optimizer.step()
    #         epoch_loss = epoch_loss + loss.item()

    #     if conf.hashcode_training.LOSS_TYPE in ["query_agnostic", "query_agnostic-L"]:
    #         logger.info(f"Epoch: {run} loss: {epoch_loss} bit_balance_loss: {epoch_bit_balance_loss} decorrelation_loss: {epoch_decorrelation_loss} fence_sitting_loss: {epoch_fence_sitting_loss} time: {time.time()-start_time}")
    #     if conf.hashcode_training.LOSS_TYPE in ["query_aware", "query_aware-L"]:
    #         logger.info(f"Epoch: {run} loss: {epoch_loss} bit_balance_loss: {epoch_bit_balance_loss} ranking_loss: {epoch_ranking_loss} fence_sitting_loss: {epoch_fence_sitting_loss} time: {time.time()-start_time}")

    #     start_time = time.time()
    #     if conf.hashcode_training.LOSS_TYPE in ["query_aware", "query_aware-L"]:
    #         val_loss,total_bit_balance_loss,total_ranking_loss, total_fence_sitting_loss = evaluate_validation_query_aware(model,train_data, mode="val")
    #         logger.info(f"Epoch: {run} VAL loss: {val_loss} bit_balance_loss: {total_bit_balance_loss} ranking_loss: {total_ranking_loss} fence_sitting_loss: {total_fence_sitting_loss} time: {time.time()-start_time}")
    #     if conf.hashcode_training.LOSS_TYPE in ["query_agnostic", "query_agnostic-L"]:
    #         val_loss,total_bit_balance_loss,total_decorrelation_loss, total_fence_sitting_loss = evaluate_validation_query_agnostic(model,train_data, mode="val")
    #         logger.info(f"Epoch: {run} VAL loss: {val_loss} bit_balance_loss: {total_bit_balance_loss} decorrelation_loss: {total_decorrelation_loss} fence_sitting_loss: {total_fence_sitting_loss} time: {time.time()-start_time}")

    #     neg_val_loss = -val_loss

    #     state_dict = {
    #         "model_state_dict": model.state_dict(),
    #         "optim_state_dict": optimizer.state_dict(),
    #         "epoch": run,
    #         "best_neg_val_loss": best_neg_val_loss,
    #         "neg_val_loss": neg_val_loss,
    #         'rng_state': torch.get_rng_state(),
    #         'cuda_rng_state': torch.cuda.get_rng_state(),
    #         'np_rng_state': np.random.get_state(),
    #         'random_state': random.getstate(),
    #         'patience': es.patience,
    #         'best_scores': es.best_scores,
    #         'num_bad_epochs': es.num_bad_epochs,
    #         'should_stop_now': es.should_stop_now,
    #     }

    #     state_dict =  es.check([neg_val_loss], state_dict)
    #     best_neg_val_loss = state_dict["best_neg_val_loss"]

    #     if es.should_stop_now:
    #         break
    #     run+=1

    # Generate and dump hashcode pickles.
    # IMP: load the best validation model here.
    checkpoint = es.load_best_model()
    model.load_state_dict(checkpoint['model_state_dict'])

    num_corpus = train_data.corpus_embeds.shape[0]
    corpus_hashcodes = torch.zeros(
        (num_corpus, train_data.lsh.num_hash_tables, conf.hashing.hcode_dim),
        device=device,
    )
    bsz = 5000  # corpus is hashed in chunks of this many rows
    query_hashcodes = {}
    # Pure inference: no autograd graph is needed for the dumped codes.
    with torch.no_grad():
        for i in tqdm.tqdm(range(0, num_corpus, bsz)):
            corpus_hashcodes[i:i + bsz, :] = model(train_data.corpus_fmaps[i:i + bsz, :]).detach()

        for mode in ["train", "val", "test"]:
            query_hashcodes[mode] = model(train_data.query_fmaps[mode]).detach()

    all_hashcodes = {'query': query_hashcodes, 'corpus': corpus_hashcodes}
    try:
        with open(pickle_fp, 'wb') as f:
            pickle.dump(all_hashcodes, f)
        logger.info(f"Dumping trained hashcode pickle at {pickle_fp}")
    except OSError as e:
        # Fall back to the alternative file name, but record why the primary
        # path failed instead of swallowing the error.
        logger.warning(f"Could not write {pickle_fp} ({e}); falling back to {shortened_filepath}")
        with open(shortened_filepath, 'wb') as f:
            pickle.dump(all_hashcodes, f)
        logger.info(f"Dumping trained hashcode pickle at {shortened_filepath}")
            
            
            
            
if __name__ == "__main__":

    def _flatten_section(section, skipped_keys):
        """Render a config section as comma-joined ``<key><value>`` tokens."""
        return ",".join(
            f"{key}{value}" for key, value in section.items() if key not in skipped_keys
        )

    def _arch_tag(hidden_layers):
        """Encode a hidden-layer list as ``L`` followed by one ``RL_<dim>_`` per layer."""
        return "L" + "".join(f"RL_{dim}_" for dim in hidden_layers)

    # Layered configuration: base file, then model / dataset / hashing files
    # selected by the command line, with command-line values taking precedence.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.rel_mode}/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    hash_conf = OmegaConf.load(f"configs/hash_configs/{cli_conf.hashing.name}.yaml")
    conf = OmegaConf.merge(main_conf, model_conf, data_conf, hash_conf, cli_conf)

    # NOTE: The task-name construction below should stay in sync with the first
    # lines of the "init_hcodes" function in Ghash_Trained.py
    prefix_by_rel_mode = {
        "sub_iso": "",
        "ged": "G" + ",",
        "uneq_ged": "UG" + ",",
    }
    if cli_conf.dataset.rel_mode not in prefix_by_rel_mode:
        raise ValueError(f" rel_mode {cli_conf.dataset.rel_mode} should be either sub_iso or ged or uneq_ged")
    tmp_prefix_str = prefix_by_rel_mode[cli_conf.dataset.rel_mode]

    task_parts = [
        tmp_prefix_str + conf.dataset.name,
        _flatten_section(
            conf.hashing,
            {'device', 'embed_dim', 'subset_size', 'classPath', 'subset_type'},
        ),
        _flatten_section(
            conf.hashcode_training,
            {'model_name', 'classPath', 'device', 'hidden_layers'},
        ),
        _arch_tag(conf.hashcode_training.hidden_layers),
    ]

    if conf.hashing.name in ["Ghash_Trained", "Ghash2_Trained"]:
        # Trained Ghash variants additionally encode the fmap-training setup.
        task_parts.append(
            _flatten_section(
                conf.fmap_training,
                {'model_name', 'classPath', 'device', 'hidden_layers'},
            )
        )
        task_parts.append(_arch_tag(conf.fmap_training.hidden_layers))

    curr_task = ",".join(task_parts)

    logger.info(f"Task name: {curr_task}")
    logger.add(f"{conf.log.dir}/{curr_task}.log")
    logger.info(OmegaConf.to_yaml(conf))

    # Seed every random number generator with its own offset from a common base.
    seed = 4
    random.seed(seed)
    np.random.seed(seed + 1)
    torch.manual_seed(seed + 2)
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    # Module-level output paths (also read by run_hashcode_gen).
    pickle_dir = conf.base_dir + "allPklDumps/hashcodePickles/"
    pickle_fp = pickle_dir + curr_task + "_hashcode_mat.pkl"
    shortened_filepath = pickle_dir + curr_task + "_hcmat"

    # Skip the whole run when the primary pickle already exists.
    # if not (os.path.exists(pickle_fp) or os.path.exists(shortened_filepath)):
    if not os.path.exists(pickle_fp):
        run_hashcode_gen(conf, curr_task)
        


# (CUDA_VISIBLE_DEVICES=0 python -m lsh.train_hashcode dataset.name="msweb294" dataset.rel_mode="fhash" hashing.m_use=10  dataset.embed_dim=294   hashing.hcode_dim=64 hashcode_training.hidden_layers=[]   hashing.FUNC="sighinge" hashing.name="RH_Trained" hashing.num_hash_tables=10 training.patience=50 hashcode_training.LOSS_TYPE="query_agnostic" hashcode_training.QA_subset_size=8  hashcode_training.FENCE=0.1 hashcode_training.DECORR=0.1  hashcode_training.QA_MARGIN=1.0) & 

# msweb294,nameRH_Trained,FUNCsighinge,hcode_dim64,num_hash_tables10,m_use10,LOSS_TYPEquery_agnostic,QA_subset_size8,QA_MARGIN1.0,FENCE0.1,DECORR0.8,C10,TANH_TEMP1.0,L
    
# msweb294,nameRH_Trained,FUNCsighinge,hcode_dim64,num_hash_tables10,LOSS_TYPEquery_aware,QA_subset_size8,QA_MARGIN1.0,FENCE0.1,DECORR0,C10.2,TANH_TEMP1.0,L