import multiprocessing
import pickle
import time
from omegaconf import OmegaConf
from src.embeddings_loader import  fetch_graph_ground_truths, fetch_graph_corpus_embeddings
from lsh.ghash_main import run_faiss, compute_scores_from_cids, compute_eval_scores
from utils.utils import *


# Marker a worker puts on the result queue once it is done, whether it
# succeeded or failed. Results are `(topk, list)` tuples, so `None` cannot be
# confused with a real payload.
SENTINEL = None


def inner_foo(conf,topk,d,device,ground_truth):
    """Evaluate one ``top_K`` setting in a worker process and report via a queue.

    Sets ``conf.hashing.top_K`` to ``topk``, calls ``run_faiss`` to obtain the
    union and intersection hashing outputs, scores each of them with
    ``compute_scores_from_cids`` / ``compute_eval_scores`` against
    ``ground_truth``, and puts ``(topk, [union_eval, inter_eval])`` on ``d``.

    ``SENTINEL`` is always put on ``d`` afterwards, even when the evaluation
    raises, so a parent that counts sentinels is not left waiting for a worker
    that has already died. The exception itself still propagates.

    Args:
        conf: Configuration object; ``conf.hashing.top_K`` is overwritten and
            ``conf.dataset.rel_mode`` is read.
        topk: The top-K value to evaluate.
        d: Queue-like object exposing ``put`` (the shared result queue).
        device: Device identifier forwarded to ``run_faiss`` and
            ``compute_scores_from_cids``.
        ground_truth: Ground-truth data forwarded to ``compute_eval_scores``.

    Returns:
        None. Results are delivered through ``d``.
    """
    try:
        conf.hashing.top_K = topk
        hash_op_union, hash_op_inter = run_faiss(conf, device)

        # Union first, then intersection: the parent relies on this order.
        eval_results = []
        for hash_op in (hash_op_union, hash_op_inter):
            hash_score_info = compute_scores_from_cids(
                conf, "test", "NANL", "", conf.dataset.rel_mode, hash_op, device
            )
            eval_results.append(compute_eval_scores(hash_score_info, ground_truth))

        d.put((topk, eval_results))
    finally:
        # Signal completion unconditionally so the parent never blocks on a
        # worker that failed part-way through.
        d.put(SENTINEL)
    return


if __name__ == "__main__":
    # Driver: evaluate every top_K in its own worker process (spread over the
    # listed devices), gather the per-top_K metrics through a queue, and
    # pickle them in top_K order.

    # Imported here because only the parent's collection loop needs it.
    from queue import Empty

    # Seconds to wait on the result queue before re-checking worker liveness.
    POLL_INTERVAL_S = 5

    multiprocessing.set_start_method('spawn')

    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.rel_mode}/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    hash_conf = OmegaConf.load(f"configs/hash_configs/{cli_conf.hashing.name}.yaml")
    # Later configs override earlier ones, so CLI values win.
    conf = OmegaConf.merge(main_conf,model_conf, data_conf, hash_conf, cli_conf)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if conf.hashing.name != "Faiss":
        raise ValueError(f"hashing.name must be 'Faiss', got {conf.hashing.name!r}")
    if conf.hashing.faissmetric not in ("l2", "cosine"):
        raise ValueError(
            f"hashing.faissmetric must be 'l2' or 'cosine', got {conf.hashing.faissmetric!r}"
        )

    s = time.time()

    # top_K grid shared by every rel_mode; sub_iso additionally probes two
    # larger values.
    base_top_Ks = [10, 25, 50, 75, 100, 250, 500, 750, 1000, 2500, 5000, 7500, 10000, 25000, 50000, 75000,
                   100000, 250000, 500000, 750000, 1000000]
    if conf.dataset.rel_mode == "sub_iso":
        top_Ks = base_top_Ks + [1500000, 1800000]
    elif conf.dataset.rel_mode in ("ged", "uneq_ged"):
        top_Ks = list(base_top_Ks)
    else:
        raise ValueError(f" rel_mode {conf.dataset.rel_mode} should be either sub_iso or ged or uneq_ged")

    fp = f"{conf.base_dir}allPklDumps/hashResultPickles/{conf.dataset.name}_{conf.dataset.rel_mode}_{conf.model.name}_{conf.hashing.name}_{conf.hashing.faissmetric}.pkl"

    ground_truth = fetch_graph_ground_truths(conf, "test")

    # NOTE: an earlier revision indexed the corpus once here, before spawning
    # the workers (the same index is shared across all top_Ks). That code was
    # commented out and has been removed; recover it from version control if
    # it needs to be revived.

    queue = multiprocessing.Queue()
    procs = []
    all_dict = {}

    # Devices are assigned round-robin, so this list may be shorter than top_Ks.
    device_list = [0, 1, 2]

    for idx, topk in enumerate(top_Ks):
        device = device_list[idx % len(device_list)]
        p = multiprocessing.Process(target=inner_foo, args=(conf,topk,queue,device,ground_truth))
        procs.append(p)
        p.start()

    # Drain the queue until every started worker has sent its sentinel. The
    # queue is drained *before* joining so workers are not left blocked while
    # flushing their results.
    seen_sentinel_count = 0
    while seen_sentinel_count < len(procs):
        # Sample liveness before waiting: if every worker had already exited
        # and the wait still times out, nothing more can arrive.
        all_workers_exited = not any(p.is_alive() for p in procs)
        try:
            item = queue.get(timeout=POLL_INTERVAL_S)
        except Empty:
            if all_workers_exited:
                break
            continue
        if item is SENTINEL:
            seen_sentinel_count += 1
        else:
            finished_topk, metrics = item
            all_dict[finished_topk] = metrics

    for p in procs:
        p.join()

    # Fail with a clear message instead of a bare KeyError when a worker died
    # without delivering its result.
    missing_top_Ks = [topk for topk in top_Ks if topk not in all_dict]
    if missing_top_Ks:
        exit_codes = {p.name: p.exitcode for p in procs}
        raise RuntimeError(
            f"No results received for top_K values {missing_top_Ks}; "
            f"worker exit codes: {exit_codes}"
        )

    # Flatten in top_K order: [union, inter] for each top_K, back to back.
    all_metric_list_parallel = []
    for topk in top_Ks:
        all_metric_list_parallel.extend(all_dict[topk])

    # Context manager so the file is flushed and closed deterministically.
    with open(fp, "wb") as out_file:
        pickle.dump(all_metric_list_parallel, out_file)
    print(f"Saving hashing results to {fp}")
    print(time.time()-s)



# python -m scripts.ghash_lsh_scripts.hashing_faiss_fox model.name="NANL" dataset.name="ptc_fr" dataset.rel_mode="sub_iso"   hashing.name="Faiss" hashing.faissmetric="cosine" 
# python -m scripts.ghash_lsh_scripts.hashing_faiss_fox model.name="NANL" dataset.name="ptc_fr" dataset.rel_mode="sub_iso"   hashing.name="Faiss" hashing.faissmetric="l2" 
