from src.embeddings_loader import fetch_graph_corpus_embeddings, fetch_graph_query_embeddings, fetch_graph_ground_truths
from src.embeddings_loader import fetch_all_info_for_scoring
from utils.model_utils import  nanl_fast_inference_without_model, scoring_functions
from utils.utils import *
import torch 
import random
import numpy as np
from loguru import logger
from sklearn.metrics import ndcg_score
import pickle

def set_seed():
    """Seed the Python, NumPy and PyTorch RNGs from one fixed base value.

    The three generators are seeded with 4, 5 and 6 respectively. cuDNN
    determinism is explicitly turned *off* here, so cuDNN-backed GPU kernels
    are not forced to be deterministic by this helper.
    """
    base_seed = 4
    seeders = (
        (random.seed, 0),
        (np.random.seed, 1),
        (torch.manual_seed, 2),
    )
    for seed_fn, offset in seeders:
        seed_fn(base_seed + offset)
    torch.backends.cudnn.deterministic = False


def run_lsh(conf):
    """Index the corpus with the configured LSH class and query it with every test query.

    The hashing class is resolved dynamically from
    ``conf.hashing.classPath`` / ``conf.hashing.name`` and built with ``conf``.

    Returns:
        One entry per test query: whatever ``lsh.retrieve`` returns when called
        with ``conf.K``, ``no_bucket=False`` and ``return_candidate_list=True``.
    """
    corpus_embeds = fetch_graph_corpus_embeddings(conf)
    query_embeds = fetch_graph_query_embeddings(conf, "test")
    ground_truth = fetch_graph_ground_truths(conf, "test")

    # Every test query must have a ground-truth entry.
    assert len(ground_truth.keys()) == query_embeds.shape[0]

    # Seed before constructing the hasher: its constructor initialises the
    # k hash functions, so seeding first keeps runs reproducible.
    set_seed()
    hasher_cls = get_class(f"{conf.hashing.classPath}.{conf.hashing.name}")
    lsh = hasher_cls(conf)

    # Build feature maps and bucket the corpus items.
    lsh.index_corpus(corpus_embeds)

    # qemb[None, :] lifts each query row to a 1 x d batch.
    return [
        lsh.retrieve(
            qemb[None, :],
            conf.K,
            no_bucket=False,
            return_candidate_list=True,
            qid=qid,
        )
        for qid, qemb in enumerate(query_embeds)
    ]


def run_faiss(conf,device):
    """Retrieve top-K corpus candidates for every test query with the Faiss wrapper.

    Only the NANL model is supported. Besides the embeddings, this loads two
    NANL auxiliary pickles. The corpus one supplies ``corpus_mask`` and the
    test one supplies ``test_query_mask``. Both masks are moved to ``device``
    and handed to the Faiss wrapper.

    Args:
        conf: experiment config (reads ``conf.hashing``, ``conf.model``,
            ``conf.dataset`` and ``conf.base_dir``).
        device: target passed to ``.to(device)`` for both masks.

    Returns:
        ``(all_hashing_info_union, all_hashing_info_inter)``: per query, the
        first and second elements of ``retrieve``'s output (presumably the
        union / intersection candidate lists, per the names; confirm in the
        Faiss wrapper).

    Raises:
        ValueError: if ``conf.dataset.rel_mode`` is not sub_iso/ged/uneq_ged,
            or ``conf.hashing.faissmetric`` is not cosine/l2.
    """
    import faiss

    assert conf.hashing.name == "Faiss"
    assert conf.model.name == "NANL" # currently Faiss is only supported for NANL

    corpus_embeds = fetch_graph_corpus_embeddings(conf)
    query_embeds = fetch_graph_query_embeddings(conf, "test")
    ground_truth = fetch_graph_ground_truths(conf, "test")

    num_qitems = query_embeds.shape[0]
    assert(len(ground_truth.keys()) == num_qitems)

    directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed/splits"
    mname, scoring_fn = ("NANL", "")
    rel_mode = conf.dataset.rel_mode

    # The corpus and test aux pickles share one stem. Only the node-count
    # fields (cgraph for sub_iso, graph for the GED variants) and the test
    # suffix depend on rel_mode; ged and uneq_ged use identical names.
    if rel_mode == "sub_iso":
        size_tag = f"min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}"
        test_suffix = f"range_{conf.dataset.MinR}_to_{conf.dataset.MaxR}"
    elif rel_mode in ("ged", "uneq_ged"):
        size_tag = f"min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}"
        test_suffix = f"MaxSkew_{conf.dataset.MaxSkew}"
    else:
        raise ValueError(f" rel_mode {conf.dataset.rel_mode} should be either sub_iso or ged or uneq_ged")

    stem = f"{mname}_{scoring_fn}_{rel_mode}_aux_info_for_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_{size_tag}_relabeled"
    aux_fname = f"{directory}/../{stem}.pkl"
    test_aux_fname = f"{directory}/test/{stem}_{test_suffix}.pkl"

    # Context managers close the handles deterministically.
    # NOTE: pickle.load can execute arbitrary code; these must be trusted,
    # locally generated preprocessing artefacts.
    with open(aux_fname, "rb") as fh:
        aux_info_dict = pickle.load(fh)
    with open(test_aux_fname, "rb") as fh:
        test_aux_info_dict = pickle.load(fh)

    query_mask = test_aux_info_dict['test_query_mask'].to(device)
    corpus_mask = aux_info_dict['corpus_mask'].to(device)

    set_seed()

    faiss_instance = get_class(f"{conf.hashing.classPath}.{conf.hashing.name}")(conf)

    # Generate feature maps and index the corpus items.
    faiss_instance.index_corpus(corpus_embeds, corpus_mask)

    metric = conf.hashing.faissmetric
    if metric not in ("cosine", "l2"):
        raise ValueError(f"Invalid metric {conf.hashing.faissmetric}. Use 'cosine' or 'l2'.")
    H_q = query_embeds
    if metric == "cosine":
        # Normalises in place, so query_embeds is modified as well.
        faiss.normalize_L2(H_q)

    all_hashing_info_union = []
    all_hashing_info_inter = []
    for qid, qemb in enumerate(H_q):
        op = faiss_instance.retrieve(qid, qemb, query_mask[qid], conf.hashing.top_K)
        all_hashing_info_union.append(op[0])
        all_hashing_info_inter.append(op[1])

    return all_hashing_info_union, all_hashing_info_inter



def run_random(conf):
    """Random-retrieval baseline: sample ``top_K`` distinct corpus ids per test query.

    Ids are drawn without replacement from ``range(0, conf.dataset.aug_num_cgraphs)``
    using Python's ``random`` module, after the global seeds are reset.

    Returns:
        A list with one list of sampled corpus ids per test query.
    """
    assert conf.hashing.name == "Random"

    query_embeds = fetch_graph_query_embeddings(conf, "test")
    ground_truth = fetch_graph_ground_truths(conf, "test")

    num_qitems = query_embeds.shape[0]
    assert len(ground_truth.keys()) == num_qitems

    set_seed()

    # One independent draw per query row; random.sample already returns a new list.
    return [
        random.sample(range(0, conf.dataset.aug_num_cgraphs), conf.hashing.top_K)
        for _ in query_embeds
    ]




def run_diskann(conf,device):
    """Retrieve top-K corpus candidates for every test query with the DiskANN wrapper.

    Only the NANL model is supported. Besides the embeddings, this loads two
    NANL auxiliary pickles. The corpus one supplies ``corpus_mask`` and the
    test one supplies ``test_query_mask``. Both masks are moved to ``device``
    and handed to the DiskANN wrapper.

    Args:
        conf: experiment config (reads ``conf.hashing``, ``conf.model``,
            ``conf.dataset`` and ``conf.base_dir``).
        device: target passed to ``.to(device)`` for both masks.

    Returns:
        ``(all_hashing_info_union, all_hashing_info_inter)``: per query, the
        first and second elements of ``retrieve``'s output (presumably the
        union / intersection candidate lists, per the names; confirm in the
        DiskANN wrapper).

    Raises:
        ValueError: if ``conf.dataset.rel_mode`` is not sub_iso/ged/uneq_ged.
    """
    assert conf.hashing.name == "DiskANN"
    assert conf.model.name == "NANL" # currently DISKANN is only supported for NANL

    corpus_embeds = fetch_graph_corpus_embeddings(conf)
    query_embeds = fetch_graph_query_embeddings(conf, "test")
    ground_truth = fetch_graph_ground_truths(conf, "test")

    num_qitems = query_embeds.shape[0]
    assert(len(ground_truth.keys()) == num_qitems)

    directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed/splits"
    mname, scoring_fn = ("NANL", "")
    rel_mode = conf.dataset.rel_mode

    # The corpus and test aux pickles share one stem. Only the node-count
    # fields (cgraph for sub_iso, graph for the GED variants) and the test
    # suffix depend on rel_mode; ged and uneq_ged use identical names.
    if rel_mode == "sub_iso":
        size_tag = f"min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}"
        test_suffix = f"range_{conf.dataset.MinR}_to_{conf.dataset.MaxR}"
    elif rel_mode in ("ged", "uneq_ged"):
        size_tag = f"min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}"
        test_suffix = f"MaxSkew_{conf.dataset.MaxSkew}"
    else:
        raise ValueError(f" rel_mode {conf.dataset.rel_mode} should be either sub_iso or ged or uneq_ged")

    stem = f"{mname}_{scoring_fn}_{rel_mode}_aux_info_for_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_{size_tag}_relabeled"
    aux_fname = f"{directory}/../{stem}.pkl"
    test_aux_fname = f"{directory}/test/{stem}_{test_suffix}.pkl"

    # Context managers close the handles deterministically.
    # NOTE: pickle.load can execute arbitrary code; these must be trusted,
    # locally generated preprocessing artefacts.
    with open(aux_fname, "rb") as fh:
        aux_info_dict = pickle.load(fh)
    with open(test_aux_fname, "rb") as fh:
        test_aux_info_dict = pickle.load(fh)

    query_mask = test_aux_info_dict['test_query_mask'].to(device)
    corpus_mask = aux_info_dict['corpus_mask'].to(device)

    set_seed()

    diskann_instance = get_class(f"{conf.hashing.classPath}.{conf.hashing.name}")(conf)

    # Generate feature maps and index the corpus items.
    diskann_instance.index_corpus(corpus_embeds, corpus_mask)

    all_hashing_info_union = []
    all_hashing_info_inter = []
    for qid, qemb in enumerate(query_embeds):
        op = diskann_instance.retrieve(qid, qemb, query_mask[qid], conf.hashing.top_K)
        all_hashing_info_union.append(op[0])
        all_hashing_info_inter.append(op[1])

    return all_hashing_info_union, all_hashing_info_inter



def compute_scores_from_cids(conf, mode, model_name, scoring_fn, rel_mode, all_hashing_info,device=None):
    """Re-score each query's retrieved candidates and rank them by descending score.

    Use conf only to refer to dataset config.

    Args:
        conf: experiment config (dataset section only).
        mode: split name; for NANL it selects the
            ``f"{mode}_masked_features_query"`` entry of the query aux info.
        model_name: "NANL" or "GEN"; anything else raises NotImplementedError.
        scoring_fn: "", "sigmoid", "asym", "dot", "hinge", "wjac", "cos".
            For GEN it keys ``scoring_functions``; "sighinge" raises
            NotImplementedError.
        rel_mode: sub_iso/ged, forwarded to the NANL scorer.
        all_hashing_info: per query index, the retrieved corpus ids.
        device: forwarded to ``fetch_all_info_for_scoring``.

    Returns:
        One tuple per query: (number of retrieved ids, scores sorted
        descending, corpus ids in that same order).
    """
    all_info = fetch_all_info_for_scoring(conf, mode, model_name, scoring_fn, rel_mode,device)
    ranked_per_query = []

    for qidx, query_emb in enumerate(all_info['query_embeds']):
        cids = all_hashing_info[qidx]

        if model_name == "NANL":
            query_feats = all_info['query_aux_info'][f"{mode}_masked_features_query"]
            corpus_aux = all_info['corpus_aux_info']
            scores = nanl_fast_inference_without_model(
                query_emb.unsqueeze(0),
                all_info['corpus_embeds'][cids],
                query_feats[qidx].unsqueeze(0),
                corpus_aux['masked_features_corpus'][cids],
                corpus_aux['sinkhorn_temp'],
                rel_mode,
            )
        elif model_name == "GEN":
            if scoring_fn == "sighinge":
                raise NotImplementedError
            score_fn = scoring_functions[scoring_fn]
            scores = score_fn(query_emb.unsqueeze(0), all_info['corpus_embeds'][cids])
        else:
            raise NotImplementedError

        # Scorers may return a leading batch dim of size 1; drop it.
        if len(scores.shape) > 1:
            scores = scores.squeeze(0)
        score_values = np.array(scores.cpu())

        order = np.argsort(-score_values)  # descending by score
        ranked_cids = np.array(cids)[order]
        ranked_per_query.append((len(cids), list(score_values[order]), list(ranked_cids)))

    return ranked_per_query

    



def compute_eval_scores(hash_info,ground_truth):
    """Average per-query retrieval metrics over all queries.

    Args:
        hash_info: sequence with one entry per query, laid out like the tuples
            built by ``compute_scores_from_cids``:
            ``(num_candidates, sorted_scores, sorted_corpus_ids)``.
        ground_truth: indexable by query index; each value is the collection
            of relevant corpus ids for that query.

    Returns:
        A 10-tuple of means: sum of the top-10 scores, number of scored
        candidates, custom AP, then NDCG at 5, 10, 100, 1000, 5000, 10000
        and 20000.
    """
    ndcg_cutoffs = (5, 10, 100, 1000, 5000, 10000, 20000)
    top10_totals = []
    candidate_counts = []
    ap_values = []
    ndcg_by_cutoff = {cutoff: [] for cutoff in ndcg_cutoffs}

    for qidx, entry in enumerate(hash_info):
        num_candidates, scores, cids = entry[0], entry[1], entry[2]
        relevant = ground_truth[qidx]
        relevant_set = set(relevant)

        top10_totals.append(compute_topK_score(scores, 10))
        candidate_counts.append(num_candidates)
        # K=-1 disables custom_ap's prediction-count check.
        ap_values.append(custom_ap(relevant_set, cids, scores, -1, len(relevant_set)))
        for cutoff in ndcg_cutoffs:
            ndcg_by_cutoff[cutoff].append(compute_ndcg(relevant, cids, scores, K=cutoff))

    return (
        np.mean(np.array(top10_totals)),
        np.mean(candidate_counts),
        np.mean(ap_values),
        *(np.mean(ndcg_by_cutoff[cutoff]) for cutoff in ndcg_cutoffs),
    )

def compute_ndcg(ground_truth, pred_cids, pred_scores, K):
    """NDCG@K of a scored candidate list against binary relevance labels.

    Args:
        ground_truth: iterable of relevant corpus ids.
        pred_cids: retrieved corpus ids.
        pred_scores: one predicted score per entry of ``pred_cids``; entries
            whose score is NaN are dropped before ranking.
        K: rank cutoff passed to ``sklearn.metrics.ndcg_score``.

    Returns:
        0 when there is nothing to rank. When exactly one candidate is left,
        returns that candidate's binary label, because ``ndcg_score`` needs at
        least two documents. Otherwise returns ``ndcg_score`` of the remaining
        candidates.
    """
    if len(pred_cids) == 0:
        return 0

    # Binary relevance per candidate. np.isin replaces the former fixed
    # 100000-slot lookup table, which raised IndexError for ids >= 100000.
    relevance = np.isin(np.asarray(pred_cids), np.asarray(list(ground_truth))).astype(float)
    if len(pred_cids) == 1:
        return relevance[0]  # single item: its binary label is the NDCG score

    # Drop candidates whose score is NaN.
    scores = np.asarray(pred_scores)
    valid_pos = ~np.isnan(scores)
    relevance = relevance[valid_pos]
    scores = scores[valid_pos]

    # Filtering can leave fewer than the two documents ndcg_score requires.
    if len(relevance) == 0:
        return 0
    if len(relevance) == 1:
        return relevance[0]
    return ndcg_score([relevance.tolist()], [scores], k=K)


def custom_ap(ground_truth, pred_cids, pred_scores, K, len_gt = None):
    """Average precision of a scored candidate list.

    Candidates are ranked by descending score. The (score, index) pairs are
    sorted in reverse, so on equal scores the larger index comes first.

    Args:
        ground_truth: set of relevant corpus ids.
        pred_cids: list of predicted corpus ids.
        pred_scores: list of predicted scores for the pred_cids.
        K: expected number of predictions. When ``K >= 0``, a mismatch is
            logged, not raised. Pass a negative value to skip the check.
        len_gt: AP denominator (number of relevant items); defaults to
            ``len(ground_truth)``.

    Returns:
        Sum of precision@rank over the relevant hits divided by the
        denominator, or 0 when the denominator is 0.
    """
    # Explicit check rather than try/assert: an assert is stripped under `python -O`.
    if K >= 0 and len(pred_cids) != K:
        logger.error(f"expected {K} predictions, got {len(pred_cids)}")
        logger.info(f"# ground truth={len(ground_truth)}, # preds={len(pred_cids)}")

    ranked = sorted(((score, idx) for idx, score in enumerate(pred_scores)), reverse=True)
    sum_precision = 0
    positive_count = 0
    for position, (_, idx) in enumerate(ranked, start=1):
        if pred_cids[idx] in ground_truth:
            positive_count += 1
            sum_precision += positive_count / position

    denominator = len(ground_truth) if len_gt is None else len_gt
    if denominator == 0:
        # No relevant items: AP is defined as 0. This applies whether the count
        # came from len_gt or from ground_truth; an empty ground_truth used to
        # raise ZeroDivisionError.
        return 0
    return sum_precision / denominator

def compute_topK_score(hash_scores, K, M=0):
    """Sum of the first ``K`` scores, each shifted by the offset ``M``.

    Callers in this module pass scores already sorted in descending order, so
    this is the top-K score mass.

    Returns:
        0 for an empty ``hash_scores``, otherwise the NumPy sum of
        ``hash_scores[:K] + M``.
    """
    if not len(hash_scores):
        # TODO: Discuss -- empty candidate lists score 0 (alternative considered: min(nohash_scores)).
        return 0
    head = np.array(hash_scores[:K])
    return (head + M).sum()