import multiprocessing
import pickle
import time
from omegaconf import OmegaConf
from src.embeddings_loader import  fetch_graph_ground_truths
from lsh.ghash_main import run_lsh, compute_scores_from_cids, compute_eval_scores
import tqdm


# Maps a `separate_eval` option to the (model name, scoring function) pair that is
# passed to compute_scores_from_cids instead of conf.model.name / conf.hashing.FUNC.
separate_eval_specs = {
    "SA": ("NANL", ""),
    "hinge": ("GEN", "hinge"),
}

if __name__ == "__main__":
    # Layered configuration: base config, then the model / dataset / hashing YAML
    # files selected on the command line, then the command-line overrides themselves.
    # Later arguments to OmegaConf.merge take precedence over earlier ones.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.rel_mode}/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    hash_conf = OmegaConf.load(f"configs/hash_configs/{cli_conf.hashing.name}.yaml")
    conf = OmegaConf.merge(main_conf, model_conf, data_conf, hash_conf, cli_conf)

    # Marker a worker puts on the result queue to signal that it is done.
    SENTINEL = None

    # Wall-clock start; the elapsed time is printed at the very end of the script.
    s = time.time()

    variations = [(0.05, 10, "Asym", "AsymFmapCos")]  # TODO check that always this. Hardcoded in conf.fmap_training

    # Hyperparameter values swept in parallel, one worker process per value.
    # (For a quick debugging run, shrink this list, e.g. to [0.05] or [0.05, 0.1].)
    c1_val = [0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 , 0.55,
              0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85]

    # TODO: dataset.name=ptc_fr dataset.rel_mode=ged

    # Optional suffixes that distinguish non-default settings in the output file name.
    additional_str = ""
    if conf.separate_eval is not None:
        additional_str += f"_SepEval_{conf.separate_eval}"
    if conf.hashing.name in ["Ghash2_Trained", "Ghash2"]:
        if conf.hashing.subset_type != "same":
            additional_str += f"_SubsetType_{conf.hashing.subset_type}"
        if conf.hashing.Sm == "none":
            additional_str += "_SmNone"
    if conf.hashing.name not in ["RH_Trained"]:
        if conf.fmap_training.tr_fmap_loss != "BCE3":
            additional_str += f"_FmapLoss_{conf.fmap_training.tr_fmap_loss}"

    # The part of the file name that depends on the hashing family. Trained variants
    # encode their hashcode-training settings; untrained ones add nothing.
    if conf.hashing.name in ["Ghash2_Trained", "Fhash_Trained"]:
        variant_str = (
            f"_{conf.hashcode_training.LOSS_TYPE}{conf.hashcode_training.QA_subset_size}"
            f"{conf.fmap_training.tr_fmap_loss}muse{conf.hashing.m_use}"
        )
    elif conf.hashing.name in ["RH_Trained"]:
        variant_str = f"_{conf.hashcode_training.LOSS_TYPE}{conf.hashcode_training.QA_subset_size}"
    elif conf.hashing.name in ["Ghash2", "Fhash"]:
        variant_str = ""
    else:
        # Fail before any expensive work: without this branch `fp` would stay
        # undefined and the script would only crash when writing the results.
        raise ValueError(
            f"Unsupported hashing.name {conf.hashing.name!r}; expected one of "
            "Ghash2_Trained, Fhash_Trained, RH_Trained, Ghash2, Fhash"
        )

    # Destination pickle for the collected evaluation results.
    fp = (
        f"{conf.base_dir}allPklDumps/hashResultPickles/"
        f"{conf.dataset.name}_{conf.dataset.rel_mode}_{conf.model.name}_"
        f"{conf.hashing.FUNC}_{conf.hashing.name}{variant_str}{additional_str}_parallel3_med.pkl"
    )

    # Loaded once here and read by the worker function when scoring results.
    ground_truth = fetch_graph_ground_truths(conf, "test")
    
    def inner_foo(conf, dval, d, device):
        """Run the subset-size sweep for one hyperparameter value and report it on a queue.

        Used in this script as the target of a ``multiprocessing.Process``.

        Args:
            conf: Merged configuration. It is modified here: ``hashing.subset_size``
                is set on every iteration, and ``hashcode_training.C1`` or
                ``hashcode_training.DECORR`` is set to ``dval`` for trained hashing variants.
            dval: Hyperparameter value handled by this worker.
            d: Queue-like object with ``put``. On success it receives ``(dval, results)``;
                ``SENTINEL`` is always put last.
            device: Device identifier forwarded to ``compute_scores_from_cids``.

        Returns:
            None. Results are delivered through ``d``.
        """
        try:
            # Untrained variants ("Ghash2", "Fhash") have no hashcode-training knob to set.
            if conf.hashing.name not in ["Ghash2", "Fhash"]:
                if "query_aware" in conf.hashcode_training.LOSS_TYPE:
                    conf.hashcode_training.C1 = dval
                else:
                    conf.hashcode_training.DECORR = dval

            tmp_list = []
            # Every subset size from 0 to 20, then steps of 5 up to 60.
            for subset_size in [*range(21), *range(25, 61, 5)]:
                conf.hashing.subset_size = subset_size
                hash_op = run_lsh(conf)
                if conf.separate_eval is None:
                    # Default setup: evaluate the corpus ids returned by the hashing
                    # function with the same scoring function the hashing was built for.
                    eval_model, eval_func = conf.model.name, conf.hashing.FUNC
                else:
                    # Otherwise score with the model / function named by `separate_eval`.
                    eval_model, eval_func = separate_eval_specs[conf.separate_eval]
                all_hash_score_info = compute_scores_from_cids(
                    conf, "test", eval_model, eval_func, conf.dataset.rel_mode, hash_op, device
                )
                eval_score_info = compute_eval_scores(all_hash_score_info, ground_truth)
                tmp_list.append(eval_score_info)

            d.put((dval, tmp_list))
        finally:
            # Always signal completion, even when the sweep raised, so a consumer
            # counting sentinels does not block forever on a failed worker.
            d.put(SENTINEL)
    

    # Fan out one worker process per value in c1_val, gather their results from a
    # shared queue, and pickle the concatenated results to `fp`.
    queue = multiprocessing.Queue()
    all_metric_list_parallel = []

    # Device ids handed to the workers in round-robin order
    # (presumably GPU indices -- TODO confirm).
    device_list = [1, 2, 3, 4]

    for _margin, _v1, _v2, _v3 in variations:  # unnecessary, but OK
        procs = []
        all_dict = {}
        for idx, dval in enumerate(c1_val):
            device = device_list[idx % len(device_list)]
            p = multiprocessing.Process(target=inner_foo, args=(conf, dval, queue, device))
            procs.append(p)
            p.start()

        # Drain the queue BEFORE joining: a process that has put items on a
        # multiprocessing.Queue may not terminate until those items are consumed.
        # Each worker marks its end with one SENTINEL, so wait for one per started worker.
        seen_sentinel_count = 0
        while seen_sentinel_count < len(procs):
            a = queue.get()
            if a is SENTINEL:
                seen_sentinel_count += 1
            else:
                all_dict[a[0]] = a[1]

        for p in procs:
            p.join()

        # A worker that signalled completion without delivering a result has failed;
        # report which values are missing instead of raising a bare KeyError.
        missing = [dval for dval in c1_val if dval not in all_dict]
        if missing:
            raise RuntimeError(f"No results received for values: {missing}")

        # Concatenate in the order of c1_val, not in worker completion order.
        for dval in c1_val:
            all_metric_list_parallel.extend(all_dict[dval])

    # The context manager guarantees the file is flushed and closed.
    with open(fp, "wb") as out_file:
        pickle.dump(all_metric_list_parallel, out_file)
    print(f"Saving hashing results to {fp}")
    print(time.time()-s)




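# GHASH
# NOTE(review): the commands in this section pass hashing.name="Ghash_Trained", but the
# output-path selection in this script only handles Ghash2_Trained / Fhash_Trained /
# RH_Trained / Ghash2 / Fhash -- confirm the intended hashing.name before running.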
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="ptc_fr" dataset.rel_mode="sub_iso"  hashing.FUNC="" hashing.name="Ghash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware" hashing.subset_type="different"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="ptc_fm" dataset.rel_mode="sub_iso"  hashing.FUNC="" hashing.name="Ghash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware" hashing.subset_type="different"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="ptc_mr" dataset.rel_mode="sub_iso"  hashing.FUNC="" hashing.name="Ghash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware" hashing.subset_type="different"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="cox2" dataset.rel_mode="sub_iso"  hashing.FUNC="" hashing.name="Ghash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware" hashing.subset_type="different"


# NANL-AGGR FHASH baselines 
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="ptc_fr" dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"  separate_eval="SA"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="ptc_fm" dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"  separate_eval="SA"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="ptc_mr" dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"  separate_eval="SA"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="cox2"   dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"  separate_eval="SA"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL" dataset.name="mutag"   dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware" separate_eval="SA"

# NANL-AGGR RH baselines 
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL"   dataset.name="ptc_fr" dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware" separate_eval="SA"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL"   dataset.name="ptc_fm" dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware" separate_eval="SA"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL"   dataset.name="ptc_mr" dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware" separate_eval="SA"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL"   dataset.name="cox2"   dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware" separate_eval="SA"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="NANL"   dataset.name="mutag"  dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware" separate_eval="SA"



# FHASH baselines

# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="ptc_fr" dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="ptc_fm" dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="ptc_mr" dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained" hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="cox2" dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained"   hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="mutag" dataset.rel_mode="sub_iso"  hashing.FUNC="hinge" hashing.name="Fhash_Trained"  hashcode_training.QA_subset_size=8 fmap_training.tr_fmap_loss="BCE3" hashcode_training.LOSS_TYPE="query_aware"




# COS baselines
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="ptc_fr" dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="ptc_fm" dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="ptc_mr" dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="cox2" dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware"
# python -m scripts.ghash_lsh_scripts.hashing_ghash3_big model.name="GEN"   dataset.name="mutag" dataset.rel_mode="sub_iso"  hashing.FUNC="cos" hashing.name="RH_Trained" hashcode_training.QA_subset_size=8  hashcode_training.LOSS_TYPE="query_aware"




