import multiprocessing
import pickle
import time
from omegaconf import OmegaConf
from src.embeddings_loader import  fetch_graph_ground_truths
from lsh.ghash_main import run_random, compute_scores_from_cids, compute_eval_scores
from utils.utils import *


# End-of-work marker that each worker puts on the queue after its result.
# Plain None (rather than a fresh object()) is deliberate: items sent through a
# multiprocessing queue are pickled, and only singletons such as None keep their
# identity across that round trip, so the parent's `is SENTINEL` check matches.
SENTINEL = None 

def inner_foo(conf, topk, d, device, ground_truth):
    """Run one hashing + evaluation job for a single top-K value (worker body).

    Sets ``conf.hashing.top_K`` to ``topk``, builds the hash output with
    ``run_random``, scores the "test" split with ``compute_scores_from_cids``
    on ``device``, evaluates the scores against ``ground_truth`` and reports
    back through the queue ``d``.

    Queue protocol:
        * on success, puts ``(topk, [eval_score_info])``;
        * always (success or failure) puts ``SENTINEL`` afterwards, so a parent
          that counts sentinels never blocks forever on a crashed worker.
          On failure the exception still propagates after the sentinel is sent.

    Args:
        conf: merged OmegaConf config. Mutated in place (``hashing.top_K``);
            when launched via ``multiprocessing.Process`` with the 'spawn'
            start method, each worker receives its own pickled copy.
        topk: top-K value to evaluate; also used as the result key.
        d: queue used to send results and the completion sentinel.
        device: device identifier forwarded to ``compute_scores_from_cids``
            (presumably a GPU index — confirm against that helper).
        ground_truth: ground-truth data passed to ``compute_eval_scores``.
    """
    try:
        conf.hashing.top_K = topk
        hash_op = run_random(conf)

        # NOTE(review): "NANL" is hard-coded here even though the model name is
        # configurable (conf.model.name) — confirm this argument is intended.
        all_hash_score_info = compute_scores_from_cids(
            conf, "test", "NANL", "", conf.dataset.rel_mode, hash_op, device
        )
        eval_score_info = compute_eval_scores(all_hash_score_info, ground_truth)

        # Wrapped in a list so the parent can concatenate per-top-K results
        # with list.extend / flattening.
        d.put((topk, [eval_score_info]))
    finally:
        # Without this, a failure above would leave the parent waiting forever
        # for a sentinel that never arrives.
        d.put(SENTINEL)


if __name__ == "__main__":
    # Local imports keep this script entry point self-contained.
    # (`from queue import Empty` avoids clashing with the local `queue` variable.)
    import itertools
    from queue import Empty

    # 'spawn' starts every worker in a fresh interpreter instead of forking
    # this one (presumably to avoid inheriting per-device state — confirm).
    multiprocessing.set_start_method('spawn')

    # Layered config: later arguments to merge win, so command-line overrides
    # take precedence over the model/dataset/hashing files and the base config.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.rel_mode}/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    hash_conf = OmegaConf.load(f"configs/hash_configs/{cli_conf.hashing.name}.yaml")
    conf = OmegaConf.merge(main_conf, model_conf, data_conf, hash_conf, cli_conf)

    start = time.time()

    # One worker process is launched per top-K value in this sweep.
    top_Ks = [10, 100, 1000, 2000] + list(range(5000, 100000, 5000))

    fp = f"{conf.base_dir}allPklDumps/hashResultPickles/{conf.dataset.name}_{conf.dataset.rel_mode}_{conf.model.name}_{conf.hashing.name}.pkl"

    ground_truth = fetch_graph_ground_truths(conf, "test")

    queue = multiprocessing.Queue()
    procs = []
    all_dict = {}

    # Devices are assigned round-robin. Cycling (instead of a fixed-length
    # repeated list) guarantees every top-K value gets a worker, however long
    # top_Ks grows.
    devices = [0, 1, 2]
    for topk, device in zip(top_Ks, itertools.cycle(devices)):
        p = multiprocessing.Process(target=inner_foo, args=(conf, topk, queue, device, ground_truth))
        procs.append(p)
        p.start()

    # Drain the queue BEFORE joining: a child blocked flushing queued data
    # cannot exit, so joining first could deadlock. Each worker sends one
    # sentinel when done. Poll with a timeout so a worker that died without
    # sending its sentinel cannot hang the parent forever.
    poll_seconds = 10
    seen_sentinel_count = 0
    workers_exited = False
    while seen_sentinel_count < len(procs):
        try:
            item = queue.get(timeout=poll_seconds)
        except Empty:
            if workers_exited:
                # All workers had exited before this wait started (so all of
                # their data was already in the pipe) and nothing arrived:
                # some worker died without reporting.
                break
            workers_exited = not any(proc.is_alive() for proc in procs)
            continue
        if item is SENTINEL:
            seen_sentinel_count += 1
        else:
            all_dict[item[0]] = item[1]

    for p in procs:
        p.join()

    missing = [topk for topk in top_Ks if topk not in all_dict]
    if missing:
        raise RuntimeError(
            f"No hashing results received for top_K values {missing}; "
            "see the worker tracebacks above."
        )

    # Concatenate per-top-K result lists in sweep order.
    all_metric_list_parallel = [metric for topk in top_Ks for metric in all_dict[topk]]

    print(f"Saving hashing results to {fp}")
    with open(fp, "wb") as out_file:
        pickle.dump(all_metric_list_parallel, out_file)
    print(time.time() - start)



# python -m scripts.ghash_lsh_scripts.hashing_faiss_fox model.name="NANL" dataset.name="ptc_fr" dataset.rel_mode="sub_iso"   hashing.name="Faiss" hashing.faissmetric="cosine" 
# python -m scripts.ghash_lsh_scripts.hashing_faiss_fox model.name="NANL" dataset.name="ptc_fr" dataset.rel_mode="sub_iso"   hashing.name="Faiss" hashing.faissmetric="l2" 
