import pickle
from loguru import logger
import numpy as np

# Maps each relevance mode to the dataset sub-directory its files live in.
# The loaders below assert that conf.dataset.path matches the entry for the
# active rel_mode before building any file path.
dataset_path_map = dict(
    sub_iso="data",
    ged="ged_data",
    uneq_ged="uneq_ged_data",
)

def fetch_ground_truths(conf, mode):
    """Load the positive ground-truth labels of one split from the hashing pickle.

    Args:
        conf: Config object. ``conf.base_dir``, ``conf.dataset.name`` and
            ``conf.hashing.FUNC`` are read to build the pickle path.
        mode: Split prefix; the entry stored under
            ``f"{mode}_positive_labels"`` is returned.

    Returns:
        The object stored under ``f"{mode}_positive_labels"`` in the pickle.

    Raises:
        OSError: If the pickle file cannot be opened.
        KeyError: If the pickle has no entry for the requested split.
    """
    logger.info('Fetching ground truth.')
    fp = f"{conf.base_dir}fhash_data/{conf.dataset.name}_embeds_{conf.hashing.FUNC}.pkl"
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle, which the bare open() never did.
    with open(fp, "rb") as fh:
        all_d = pickle.load(fh)
    logger.info(f"Loading ground truth labels from {fp}")
    return all_d[f'{mode}_positive_labels']


def fetch_query_embeddings(conf, mode):
    """Load the query embeddings of one split from the hashing pickle.

    Args:
        conf: Config object. ``conf.base_dir``, ``conf.dataset.name`` and
            ``conf.hashing.FUNC`` are read to build the pickle path.
        mode: Split prefix; the entry stored under ``f"{mode}_q"`` is returned.

    Returns:
        The object stored under ``f"{mode}_q"`` in the pickle, returned as-is
        (no dtype conversion is applied here).

    Raises:
        OSError: If the pickle file cannot be opened.
        KeyError: If the pickle has no entry for the requested split.
    """
    logger.info('Fetching query embeddings.')
    embed_fp = f"{conf.base_dir}fhash_data/{conf.dataset.name}_embeds_{conf.hashing.FUNC}.pkl"
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle, which the bare open() never did.
    with open(embed_fp, "rb") as fh:
        all_d = pickle.load(fh)
    logger.info(f"From {embed_fp}")
    return all_d[f'{mode}_q']

def fetch_corpus_embeddings(conf):
    """Load the corpus embeddings from the hashing pickle as float32.

    Args:
        conf: Config object. ``conf.base_dir``, ``conf.dataset.name`` and
            ``conf.hashing.FUNC`` are read to build the pickle path.

    Returns:
        The ``'all_c'`` entry of the pickle converted with
        ``.astype(dtype=np.float32)``.

    Raises:
        OSError: If the pickle file cannot be opened.
        KeyError: If the pickle has no ``'all_c'`` entry.
    """
    logger.info('Fetching corpus embeddings.')
    embed_fp = f"{conf.base_dir}fhash_data/{conf.dataset.name}_embeds_{conf.hashing.FUNC}.pkl"
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle, which the bare open() never did.
    with open(embed_fp, "rb") as fh:
        all_d = pickle.load(fh)
    # The message is already fully formatted. Passing the path as an extra
    # positional argument made the logger run str.format on it, which breaks
    # on paths that contain '{' or '}'.
    logger.info(f"From {embed_fp}")
    return all_d['all_c'].astype(dtype=np.float32)


def fetch_graph_query_embeddings(conf, mode):
    """Load pre-computed query-graph embeddings for one split as a NumPy array.

    Args:
        conf: Config object. Reads ``conf.model.name``, ``conf.hashing.FUNC``,
            ``conf.base_dir`` and several ``conf.dataset`` fields to build the
            file name. ``conf.hashing.FUNC`` is temporarily set to ``""`` in
            the NANL aggregation path and restored before returning.
        mode: Split name; used as the sub-directory under ``splits``.

    Returns:
        The unpickled object after ``.cpu().numpy()`` (presumably a torch
        tensor on disk -- TODO confirm). For model ``"NANL"`` with a non-empty
        scoring function, the embeddings saved under the empty scoring
        function are loaded instead and averaged over axis 1.

    Raises:
        AssertionError: If ``conf.dataset.path`` does not match the directory
            registered for ``conf.dataset.rel_mode``.
        KeyError: If ``conf.dataset.rel_mode`` is not in ``dataset_path_map``.
        NotImplementedError: If the relevance mode has no file-name scheme.
        OSError: If the embedding file cannot be opened.
    """
    mname = conf.model.name
    scoring_fn = "" if conf.hashing.FUNC is None else conf.hashing.FUNC
    rel_mode = conf.dataset.rel_mode
    assert conf.dataset.path == dataset_path_map[rel_mode]
    directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed/splits"

    # To forcibly apply an existing single-vector hashing method to
    # NANL-trained embeddings, return the aggregated (mean over axis 1)
    # embeddings saved under the empty scoring function.
    # NOTE(review): fetch_graph_corpus_embeddings additionally skips this path
    # when conf.hashing.name == "DiskANN"; confirm the asymmetry is intended.
    if mname == "NANL" and scoring_fn != "":
        conf.hashing.FUNC = ""
        try:
            q_embeds = fetch_graph_query_embeddings(conf, mode).mean(axis=1)
        finally:
            # Restore the caller's setting even if the load fails; otherwise
            # the shared config would be left with FUNC == "".
            conf.hashing.FUNC = scoring_fn
        logger.info(f"Returning aggregated embeddings for {mname} with {scoring_fn} scoring function")
        return q_embeds

    if rel_mode == "sub_iso":
        q_fname = f"{directory}/{mode}/{mname}_{scoring_fn}_{rel_mode}_embeds_queries_for_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}_relabeled_range_{conf.dataset.MinR}_to_{conf.dataset.MaxR}.pkl"
    elif rel_mode in ("ged", "uneq_ged"):
        # Both GED variants share one naming scheme; rel_mode itself is part
        # of the file name, so the resulting paths still differ.
        q_fname = f"{directory}/{mode}/{mname}_{scoring_fn}_{rel_mode}_embeds_queries_for_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}_relabeled_MaxSkew_{conf.dataset.MaxSkew}.pkl"
    else:
        raise NotImplementedError

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(q_fname, "rb") as fh:
        q_embeds = pickle.load(fh)
    return q_embeds.cpu().numpy()

def fetch_graph_corpus_embeddings(conf):
    """Load pre-computed corpus-graph embeddings as a NumPy array.

    Args:
        conf: Config object. Reads ``conf.model.name``, ``conf.hashing.FUNC``,
            ``conf.hashing.name``, ``conf.base_dir`` and several
            ``conf.dataset`` fields to build the file name.
            ``conf.hashing.FUNC`` is temporarily set to ``""`` in the NANL
            aggregation path and restored before returning.

    Returns:
        The unpickled object after ``.cpu().numpy()`` (presumably a torch
        tensor on disk -- TODO confirm). For model ``"NANL"`` with a non-empty
        scoring function and a hashing method other than ``"DiskANN"``, the
        embeddings saved under the empty scoring function are loaded instead
        and averaged over axis 1.

    Raises:
        AssertionError: If ``conf.dataset.path`` does not match the directory
            registered for ``conf.dataset.rel_mode``.
        KeyError: If ``conf.dataset.rel_mode`` is not in ``dataset_path_map``.
        NotImplementedError: If the relevance mode has no file-name scheme.
        OSError: If the embedding file cannot be opened.
    """
    mname = conf.model.name
    scoring_fn = "" if conf.hashing.FUNC is None else conf.hashing.FUNC
    rel_mode = conf.dataset.rel_mode
    assert conf.dataset.path == dataset_path_map[rel_mode]
    directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed/splits"

    # To forcibly apply an existing single-vector hashing method to
    # NANL-trained embeddings, return the aggregated (mean over axis 1)
    # embeddings saved under the empty scoring function.
    if mname == "NANL" and scoring_fn != "" and conf.hashing.name != "DiskANN":
        # Former debug print, routed through the module logger like every
        # other message in this file.
        logger.debug(f"Here because mname is {mname} and scoring_fn is {scoring_fn}")
        conf.hashing.FUNC = ""
        try:
            c_embeds = fetch_graph_corpus_embeddings(conf).mean(axis=1)
        finally:
            # Restore the caller's setting even if the load fails; otherwise
            # the shared config would be left with FUNC == "".
            conf.hashing.FUNC = scoring_fn
        logger.info(f"Returning aggregated embeddings for {mname} with {scoring_fn} scoring function")
        return c_embeds

    if rel_mode == "sub_iso":
        c_fname = f"{directory}/../{mname}_{scoring_fn}_{rel_mode}_embeds_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}_relabeled.pkl"
    elif rel_mode in ("ged", "uneq_ged"):
        # Both GED variants share one naming scheme; rel_mode itself is part
        # of the file name, so the resulting paths still differ.
        c_fname = f"{directory}/../{mname}_{scoring_fn}_{rel_mode}_embeds_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}_relabeled.pkl"
    else:
        raise NotImplementedError

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(c_fname, "rb") as fh:
        c_embeds = pickle.load(fh)
    return c_embeds.cpu().numpy()

def fetch_graph_ground_truths(conf, mode):
    """Load the positive relevance entries for every key of one split.

    Args:
        conf: Config object. Reads ``conf.base_dir`` and several
            ``conf.dataset`` fields to build the relevance-dict file name.
        mode: Split name; used both as the sub-directory under ``splits`` and
            as the file-name prefix.

    Returns:
        A dict mapping each key of the pickled relevance dict to that entry's
        ``'pos'`` value.

    Raises:
        AssertionError: If ``conf.dataset.path`` does not match the directory
            registered for ``conf.dataset.rel_mode``.
        KeyError: If ``conf.dataset.rel_mode`` is not in ``dataset_path_map``,
            or an entry of the relevance dict has no ``'pos'`` field.
        NotImplementedError: If the relevance mode has no file-name scheme.
        OSError: If the relevance file cannot be opened.
    """
    rel_mode = conf.dataset.rel_mode
    assert conf.dataset.path == dataset_path_map[rel_mode]
    directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed/splits"
    if rel_mode == "sub_iso":
        rel_fname = f"{directory}/{mode}/{mode}_rel_dict_for_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}_range_{conf.dataset.MinR}_to_{conf.dataset.MaxR}.pkl"
    elif rel_mode in ("ged", "uneq_ged"):
        # Both GED variants use the same file-name scheme (the name does not
        # contain rel_mode; the directory differs via conf.dataset.path).
        rel_fname = f"{directory}/{mode}/{mode}_rel_dict_for_{conf.dataset.aug_num_cgraphs}_corpus_subgraphs_min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}_relabeled_MaxSkew_{conf.dataset.MaxSkew}.pkl"
    else:
        raise NotImplementedError()

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(rel_fname, "rb") as fh:
        gt_dict = pickle.load(fh)
    # Keep only the positive entries for each key.
    return {key: entry['pos'] for key, entry in gt_dict.items()}

def dict_to_device(d, device):
    """Move every value of ``d`` that exposes a ``.to`` method onto ``device``.

    The dict is updated in place and also returned. Values without a ``to``
    attribute are left untouched.
    """
    # Iterate over a snapshot of the keys; only existing keys are reassigned.
    for key in list(d):
        value = d[key]
        if not hasattr(value, 'to'):
            continue
        # Anything with `.to` is treated as device-movable (e.g. a torch tensor).
        d[key] = value.to(device)
    return d

def fetch_all_info_for_scoring(conf, mode, model_name, scoring_fn, rel_mode, device=None):
    """Load corpus/query embeddings (and NANL auxiliary info) needed for scoring.

    ``conf`` is used only for the dataset configuration and ``base_dir``; the
    model, scoring function and relevance mode are passed explicitly.

    Args:
        conf: Config object; ``conf.base_dir`` and ``conf.dataset`` fields are
            read to build the file names.
        mode: Split name; used as the sub-directory for query-side files.
        model_name: Model identifier, e.g. ``"NANL"`` or ``"GEN"``. Auxiliary
            info is loaded only when it equals ``"NANL"``.
        scoring_fn: One of ``""``, ``"sigmoid"``, ``"asym"``, ``"dot"``,
            ``"hinge"``, ``"wjac"``, ``"cos"``; ``None`` is treated as ``""``.
        rel_mode: ``"sub_iso"``, ``"ged"`` or ``"uneq_ged"``.
        device: If not ``None``, loaded embeddings are moved with
            ``.to(device)`` and auxiliary dicts with ``dict_to_device``.

    Returns:
        A dict with keys ``'corpus_embeds'`` and ``'query_embeds'``, plus
        ``'corpus_aux_info'`` and ``'query_aux_info'`` for ``"NANL"``.

    Raises:
        AssertionError: If ``conf.dataset.path`` does not match the directory
            registered for ``rel_mode``.
        KeyError: If ``rel_mode`` is not in ``dataset_path_map``.
        NotImplementedError: If the relevance mode has no file-name scheme.
        OSError: If any of the files cannot be opened.
    """
    def _load(path):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        # The context manager closes the handle, which bare open() never did.
        with open(path, "rb") as fh:
            return pickle.load(fh)

    all_info = {}
    assert conf.dataset.path == dataset_path_map[rel_mode]
    if scoring_fn is None:
        scoring_fn = ""

    ds = conf.dataset
    # The three relevance modes differ only in the node-count fields used and
    # in the suffix of query-side file names; everything else is shared.
    if rel_mode == "sub_iso":
        size_tag = f"min{ds.min_cgraph_nodes}_max{ds.max_cgraph_nodes}"
        query_suffix = f"relabeled_range_{ds.MinR}_to_{ds.MaxR}"
    elif rel_mode in ("ged", "uneq_ged"):
        size_tag = f"min{ds.min_graph_nodes}_max{ds.max_graph_nodes}"
        query_suffix = f"relabeled_MaxSkew_{ds.MaxSkew}"
    else:
        raise NotImplementedError

    directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed/splits"
    prefix = f"{model_name}_{scoring_fn}_{rel_mode}"
    corpus_tag = f"{ds.aug_num_cgraphs}_corpus_subgraphs_{size_tag}"

    # Corpus-side files live one level above the per-split directories.
    c_fname = f"{directory}/../{prefix}_embeds_{corpus_tag}_relabeled.pkl"
    c_embeds = _load(c_fname)
    if device is not None:
        c_embeds = c_embeds.to(device)
    all_info['corpus_embeds'] = c_embeds

    q_fname = f"{directory}/{mode}/{prefix}_embeds_queries_for_{corpus_tag}_{query_suffix}.pkl"
    q_embeds = _load(q_fname)
    if device is not None:
        q_embeds = q_embeds.to(device)
    all_info['query_embeds'] = q_embeds

    if model_name == "NANL":
        c_aux_fname = f"{directory}/../{prefix}_aux_info_for_{corpus_tag}_relabeled.pkl"
        c_aux_info_dict = _load(c_aux_fname)
        if device is not None:
            c_aux_info_dict = dict_to_device(c_aux_info_dict, device)
        all_info['corpus_aux_info'] = c_aux_info_dict

        q_aux_fname = f"{directory}/{mode}/{prefix}_aux_info_for_{corpus_tag}_{query_suffix}.pkl"
        q_aux_info_dict = _load(q_aux_fname)
        if device is not None:
            q_aux_info_dict = dict_to_device(q_aux_info_dict, device)
        all_info['query_aux_info'] = q_aux_info_dict

    return all_info
                
