import argparse
import pickle
# from common import logger, set_log
import time
import networkx as nx
import numpy as np
import os
import sys
import random
import scipy.stats
from pebble import ProcessPool
from concurrent.futures import TimeoutError
import itertools

from torch_geometric.datasets import TUDataset
import torch_geometric.utils as pyg_utils
import networkx.algorithms.isomorphism as iso
from utils.data_utils import get_nx_graph, OnTheFlySubgraphSampler,fetch_tudataset_graphs, random_graph_generator
from dataclasses import dataclass
from omegaconf import OmegaConf
import random
import scipy.stats
import numpy as np
from loguru import logger
from utils.data_utils import get_nx_graph, fetch_tudataset_graphs
from utils.utils import set_seed
import datetime
import tqdm

def check_subiso(data):
    """Report whether the second graph of ``data`` embeds in the first.

    Args:
        data: A 2-item sequence ``(gc, gq)``; both items are handed to
            networkx's ``GraphMatcher`` as ``G1`` and ``G2`` respectively.

    Returns:
        The result of ``GraphMatcher(gc, gq).subgraph_is_isomorphic()``.
    """
    corpus_graph, query_graph = data
    matcher = nx.algorithms.isomorphism.GraphMatcher(corpus_graph, query_graph)
    return matcher.subgraph_is_isomorphic()


def run_parallel_pool(func, input_list, max_workers=100, timeout=30):
    """Apply ``func`` to every item of ``input_list`` in a pebble process pool.

    Args:
        func: Callable taking a single item; it is executed in worker
            processes via ``ProcessPool.map``.
        input_list: Iterable of work items (no longer required to be sized).
        max_workers: Number of worker processes; defaults to the previously
            hard-coded value of 100.
        timeout: Per-task timeout in seconds forwarded to ``pool.map``;
            defaults to the previously hard-coded value of 30.

    Returns:
        A list with one entry per consumed task: the value returned by
        ``func``, or ``None`` when the task timed out, raised
        ``AssertionError``, or its worker process died.
    """
    # Local import: the file's top-level import only brings in ProcessPool.
    from pebble import ProcessExpired

    all_result = []
    with ProcessPool(max_workers=max_workers) as pool:
        iterator = pool.map(func, input_list, timeout=timeout).result()

        # Each next() either yields one task's result or raises that task's
        # failure; the iterator stays usable afterwards, so keep draining
        # until it is exhausted.
        while True:
            try:
                all_result.append(next(iterator))
            except StopIteration:
                break
            except (TimeoutError, AssertionError, ProcessExpired):
                # Keep positional alignment with the inputs by recording a
                # placeholder for the failed task.
                all_result.append(None)
    return all_result

# def check_iso(data):
#     gc,gq = data
#     return iso.GraphMatcher(gc, gq).subgraph_is_isomorphic()

def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file afterwards."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing (and thereby flushing) the file."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)


if __name__ == "__main__":
    # Pipeline overview (every stage is cached on disk and skipped when its
    # pickle already exists):
    #   1. sample pairwise non-isomorphic corpus subgraphs,
    #   2. augment each of them with random_graph_generator,
    #   3. drop isomorphic duplicates from the augmented set,
    #   4. sample query subgraphs and label every corpus graph as
    #      positive/negative via subgraph isomorphism,
    #   5. relabel nodes to consecutive integers and write the final pickles.

    # ---- configuration ---------------------------------------------------
    # Later arguments to OmegaConf.merge override earlier ones, so the CLI
    # wins over the dataset config, which wins over the main config.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.rel_mode}/{cli_conf.dataset.name}.yaml")
    conf = OmegaConf.merge(main_conf, data_conf, cli_conf)

    # ---- logging -----------------------------------------------------------
    log_path = f"{conf.log.dir}/{conf.task.name}.log"
    # Create the log directory up front so a fresh checkout does not fail here.
    os.makedirs(conf.log.dir, exist_ok=True)
    with open(log_path, "w"):
        pass  # truncate any log left over from a previous run
    logger.add(log_path)
    logger.info(OmegaConf.to_yaml(conf))

    set_seed(conf.training.seed, conf)

    graphs = fetch_tudataset_graphs(conf)

    # ---- sizes and file names ---------------------------------------------
    no_of_corpus_subgraphs = 10000
    aug_num_cgraphs = 10 * no_of_corpus_subgraphs
    # Sampling of queries stops once this many fall strictly inside
    # (conf.dataset.MinR, conf.dataset.MaxR).
    min_queries_within_ratio = 500

    directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed"
    os.makedirs(directory, exist_ok=True)

    size_tag = f"min{conf.dataset.min_cgraph_nodes}_max{conf.dataset.max_cgraph_nodes}"
    c_fname = f"{directory}/{no_of_corpus_subgraphs}_corpus_subgraphs_{size_tag}.pkl"
    cfname_aug_with_dup = f"{directory}/{no_of_corpus_subgraphs}_corpus_subgraphs_{size_tag}_aug_with_dup.pkl"
    cfname_aug_without_dup = f"{directory}/{no_of_corpus_subgraphs}_corpus_subgraphs_{size_tag}_aug_without_dup.pkl"
    qgr_fname = f"{directory}/all_queries_for_{no_of_corpus_subgraphs}_corpus_subgraphs_{size_tag}_aug_without_dup.pkl"
    final_qgr_fname = f"{directory}/relabeled_queries_for_{aug_num_cgraphs}_corpus_subgraphs_{size_tag}_relabeled.pkl"
    final_cgr_fname = f"{directory}/{aug_num_cgraphs}_corpus_subgraphs_{size_tag}_relabeled.pkl"
    final_rel_fname = f"{directory}/rel_dict_for_{aug_num_cgraphs}_corpus_subgraphs_{size_tag}.pkl"

    # ---- stage 1: base corpus subgraphs ------------------------------------
    if os.path.isfile(c_fname):
        subgraph_list, anchor_list, subgraph_id_list = _load_pickle(c_fname)
        logger.info(f"Loaded {len(subgraph_list)} corpus subgraphs from {c_fname}")
    else:
        logger.info("Sampling corpus subgraphs")
        subgraph_list, anchor_list, subgraph_id_list = [], [], []
        subgraph_sampler = OnTheFlySubgraphSampler(graphs, conf.dataset.min_cgraph_nodes, conf.dataset.max_cgraph_nodes)
        for i in range(no_of_corpus_subgraphs):
            # Resample until the candidate is not isomorphic to any graph
            # that has already been kept.
            while True:
                sgraph, anchor, sgraph_id = subgraph_sampler.sample_subgraph()
                if any(nx.is_isomorphic(sgraph, c_graph) for c_graph in subgraph_list):
                    continue
                logger.info(f"corpus graph {i} sampled")
                subgraph_list.append(sgraph)
                anchor_list.append(anchor)
                subgraph_id_list.append(sgraph_id)
                break

        _dump_pickle((subgraph_list, anchor_list, subgraph_id_list), c_fname)
        logger.info(f"Saved {len(subgraph_list)} corpus subgraphs to {c_fname}")

    # ---- stage 2: augmentation (duplicates allowed) ------------------------
    if os.path.isfile(cfname_aug_with_dup):
        new_cgrs = _load_pickle(cfname_aug_with_dup)
        logger.info(f"Loaded {len(new_cgrs)} augmented corpus subgraphs with duplicates from {cfname_aug_with_dup}")
    else:
        new_cgrs = []
        # 2 * 3 * 5 = 30 augmented graphs per base corpus subgraph. The
        # loop nesting is kept as-is so the order of generator calls (and
        # therefore of new_cgrs) is unchanged.
        for nn in [1, 2]:
            for ne in [3, 4, 5]:
                for base_graph in tqdm.tqdm(subgraph_list):
                    for _ in range(5):
                        new_cgraph = random_graph_generator(base_graph, num_new_nodes=nn, max_edges_per_node=ne, a=0, b=2)
                        new_cgrs.append(new_cgraph)

        _dump_pickle(new_cgrs, cfname_aug_with_dup)

    # ---- stage 3: de-duplicate the augmented corpus ------------------------
    if os.path.isfile(cfname_aug_without_dup):
        unique_cgrs = _load_pickle(cfname_aug_without_dup)
        logger.info(f"Loaded {len(unique_cgrs)} augmented corpus subgraphs without duplicates from {cfname_aug_without_dup}")
    else:
        unique_cgrs = []
        for g in tqdm.tqdm(new_cgrs):
            if any(nx.is_isomorphic(g, g1) for g1 in unique_cgrs):
                continue
            unique_cgrs.append(g)
            cnt = len(unique_cgrs)
            print(f'\r num_unique = {cnt}', end= ' ')
            # Checkpoint only when the count actually changed; testing this
            # on every iteration would re-pickle the whole list for each
            # duplicate seen while the count sits at a multiple of 1000.
            if cnt % 1000 == 0:
                logger.info(f"num_unique = {cnt} Saving now")
                _dump_pickle(unique_cgrs, cfname_aug_without_dup)
            if cnt == aug_num_cgraphs:
                break

        _dump_pickle(unique_cgrs, cfname_aug_without_dup)

    # The checkpoints above share their file name with the finished cache,
    # so an interrupted run leaves a short list that is loaded as if it were
    # complete. Fail loudly (also under ``python -O``) instead of asserting.
    if len(unique_cgrs) < aug_num_cgraphs:
        raise RuntimeError(
            f"Only {len(unique_cgrs)} unique augmented corpus graphs are available "
            f"but {aug_num_cgraphs} are required; if {cfname_aug_without_dup} is a "
            f"partial checkpoint from an interrupted run, delete it and rerun."
        )

    final_cgraphs = unique_cgrs[:aug_num_cgraphs]
    logger.info(f"retaining {len(final_cgraphs)} corpus subgraphs for query generation")

    # ---- stage 4: queries and ground truth ---------------------------------
    min_r = conf.dataset.MinR
    max_r = conf.dataset.MaxR

    if os.path.isfile(qgr_fname):
        all_qgrs, all_gt_ratios, rel_dict = _load_pickle(qgr_fname)
        n_queries = len(all_gt_ratios)
        n_queries_within_ratio = sum(1 for ratio in all_gt_ratios if min_r < ratio < max_r)
        logger.info(f"Loaded {len(all_qgrs)} query graphs without duplicates from {qgr_fname} with {n_queries_within_ratio} within ratio and {n_queries} total queries")
    else:
        all_qgrs = []
        all_gt_ratios = []
        rel_dict = {}
        n_queries = 0
        n_queries_within_ratio = 0

    if n_queries_within_ratio < min_queries_within_ratio:
        logger.info(f"Sampling more queries because {n_queries_within_ratio} queries withing ratio which is < {min_queries_within_ratio}")
        subgraph_sampler = OnTheFlySubgraphSampler(graphs, conf.dataset.min_qgraph_nodes, conf.dataset.max_qgraph_nodes)

        while n_queries_within_ratio < min_queries_within_ratio:
            s = time.time()
            sgraph, _, _ = subgraph_sampler.sample_subgraph()

            if any(nx.is_isomorphic(sgraph, qgr) for qgr in all_qgrs):
                logger.info("SAD: found duplicate query graph")
                continue

            res = run_parallel_pool(check_subiso, list(zip(final_cgraphs, itertools.repeat(sgraph))))

            # A None result (task timed out or failed in the pool) is
            # counted as a negative, exactly like a False result.
            pos_c = []
            neg_c = []
            for c_i, outcome in enumerate(res):
                if outcome:
                    pos_c.append(c_i)
                else:
                    neg_c.append(c_i)

            # NOTE(review): a query that matches every corpus graph gets
            # ratio 0 here rather than infinity - confirm this is intended.
            r = len(pos_c) / len(neg_c) if neg_c else 0
            logger.info(f"SAD: found non duplicate query graph with ration {r} in {time.time()-s} seconds")

            if not pos_c or r > max_r:
                continue

            logger.info(f"Found #{n_queries} query, with {len(pos_c)} positives, {r} ratio, {len(sgraph)} nodes, time taken {time.time()-s}")
            all_qgrs.append(sgraph)
            all_gt_ratios.append(r)
            rel_dict[n_queries] = {'pos': pos_c, 'neg': neg_c}
            n_queries += 1

            # Only queries strictly inside (MinR, MaxR) count towards the
            # stopping criterion; the others are still stored.
            if min_r < r < max_r:
                n_queries_within_ratio += 1
                logger.info(f"Found #{n_queries_within_ratio} query within ratio")

            # Periodic checkpoint so a long run can be resumed.
            if n_queries % 20 == 0:
                _dump_pickle((all_qgrs, all_gt_ratios, rel_dict), qgr_fname)

        # Final dump for whatever was accepted since the last checkpoint.
        _dump_pickle((all_qgrs, all_gt_ratios, rel_dict), qgr_fname)

    # ---- stage 5: relabel nodes to 0..n-1 and write final artefacts --------
    c_gr_all = [nx.convert_node_labels_to_integers(o) for o in final_cgraphs]
    all_query_graphs = [nx.convert_node_labels_to_integers(o) for o in all_qgrs]

    _dump_pickle(rel_dict, final_rel_fname)
    _dump_pickle(all_query_graphs, final_qgr_fname)
    _dump_pickle(c_gr_all, final_cgr_fname)
    logger.info(f"Saved final relabeled queries to {final_qgr_fname}")
    logger.info(f"Saved final relabeled corpus to {final_cgr_fname}")
    logger.info(f"Saved final rel dict to {final_rel_fname}")

    
        

#     fp = av.DIR_PATH + "/data/graph_data/preprocessed/" + av.DATASET_NAME +"240k_corpus_subgraphs_" + \
#                 str(no_of_corpus_subgraphs)+"_min_"+str(av.MIN_CORPUS_SUBGRAPH_SIZE) + "_max_" + \
#                 str(av.MAX_CORPUS_SUBGRAPH_SIZE)+".pkl"
#     logger.info("Sampling corpus subgraphs")
#     subgraph_list, anchor_list, subgraph_id_list = [], [], []
#     av.MIN_SUBGRAPH_SIZE = av.MIN_CORPUS_SUBGRAPH_SIZE
#     av.MAX_SUBGRAPH_SIZE = av.MAX_CORPUS_SUBGRAPH_SIZE
#     subgraph_sampler = OnTheFlySubgraphSampler(av,graphs)
#     for i in range(no_of_corpus_subgraphs):
#         while True:
#             dup_flag = False
#             sgraph,anchor,sgraph_id = subgraph_sampler.sample_subgraph()
#             for c_graph in subgraph_list:
#                 if nx.is_isomorphic(sgraph,c_graph):
#                     dup_flag = True
#                     break
#             if dup_flag:
#                 logger.info("corpus graph Discarded due to duplicacy")
#                 continue
#             subgraph_list.append(sgraph)
#             anchor_list.append(anchor)
#             subgraph_id_list.append(sgraph_id)
#             break
#     corpus_subgraph_list,corpus_anchor_list, corpus_subgraph_id_list =subgraph_list,anchor_list,subgraph_id_list

#     print(fp)
#     with open(fp, 'wb') as f:
#         pickle.dump((corpus_subgraph_list,corpus_anchor_list,corpus_subgraph_id_list),f)
    
#     fp_query = av.DIR_PATH + "/data/graph_data/preprocessed/" + av.DATASET_NAME +"240k_query_subgraphs_" + \
#               str(no_of_query_subgraphs)+"_min_"+str(av.MIN_QUERY_SUBGRAPH_SIZE) + "_max_" + \
#               str(av.MAX_QUERY_SUBGRAPH_SIZE)+".pkl"
    
#     subgraph_rel_type = "nx_is_subgraph_iso"
#     fp_rel_dict = av.DIR_PATH + "/data/graph_data/preprocessed/" + av.DATASET_NAME + "240k_query_" + \
#           str(no_of_query_subgraphs)+ "_corpus_" + str(no_of_corpus_subgraphs) + \
#           "_minq_"+str(av.MIN_QUERY_SUBGRAPH_SIZE) + "_maxq_" + \
#           str(av.MAX_QUERY_SUBGRAPH_SIZE)+"_minc_"+str(av.MIN_CORPUS_SUBGRAPH_SIZE) + "_maxc_" + \
#           str(av.MAX_CORPUS_SUBGRAPH_SIZE)+"_rel_" + subgraph_rel_type +".pkl"
    
#     print("Sampling query subgraphs")

#     av.MIN_SUBGRAPH_SIZE = av.MIN_QUERY_SUBGRAPH_SIZE
#     av.MAX_SUBGRAPH_SIZE = av.MAX_QUERY_SUBGRAPH_SIZE
#     subgraph_sampler = OnTheFlySubgraphSampler(av,graphs)
    
#     if os.path.isfile(fp_query):
#         logger.info("Loading query subgraphs from %s", fp_query)
#         query_subgraph_list,query_anchor_list, query_subgraph_id_list = pickle.load(open(fp_query,"rb"))
#         logger.info("Loading rels from %s", fp_rel_dict)
#         rel_dict = pickle.load(open(fp_rel_dict,"rb"))
#         n_queries = len(query_subgraph_list)
#     else:
#         query_subgraph_list,query_anchor_list, query_subgraph_id_list = [],[],[]
#         rel_dict = {}
#         n_queries =0
        
#     sstart = time.time()
    
    

#     with ProcessPool(max_workers=30) as pool:
#         while n_queries <no_of_query_subgraphs:
#             start = time.time()

#             sgraph,anchor,sgraph_id = subgraph_sampler.sample_subgraph()
#             dup_flag = False
#             for q_graph in query_subgraph_list:
#                 if nx.is_isomorphic(sgraph,q_graph):
#                     dup_flag = True
#                     break
#             if dup_flag:
#                 logger.info("Discarded due to duplicacy")
#                 continue
#             pos_c,neg_c = [],[]


#             future = pool.map(check_iso, zip(corpus_subgraph_list,itertools.repeat(sgraph)), timeout=30)
#             iterator = future.result()

#             for c_i in range(no_of_corpus_subgraphs): 
#                 try:
#                     result = next(iterator)
#                     if result ==1:
#                         pos_c.append(c_i)
#                     else:
#                         neg_c.append(c_i)

#                 except StopIteration:
#                     break
#                 except TimeoutError as error:  
#                     neg_c.append(c_i)
# #                     logger.info(str(c_i)+" " + "Timeout")


#             if len(neg_c) ==0:
#                 r = 0
#             else:  
#                 r = len(pos_c)/len(neg_c)
#             if (r>=0.1 and r<=0.4):
#                 logger.info("q: {}, r ratio : {}".format(n_queries, r))
#                 query_subgraph_list.append(sgraph)
#                 query_anchor_list.append(anchor)
#                 query_subgraph_id_list.append(sgraph_id)
#                 rel_dict[n_queries] = {}
#                 rel_dict[n_queries]['pos'] = pos_c
#                 rel_dict[n_queries]['neg'] = neg_c
                
#                 logger.info("Adding sampled query subgraph id: {} to query graph pickle file".format(n_queries))
#                 with open(fp_query, 'wb') as f:
#                     pickle.dump((query_subgraph_list,query_anchor_list,query_subgraph_id_list),f)
                    
#                 logger.info("Adding sampled query subgraph id: {} alignment info to rels pickle file".format(n_queries))
#                 with open(fp_rel_dict, "wb") as f:
#                     pickle.dump(rel_dict,f)
#                 n_queries = n_queries+1
#             else:
#                 logger.info("Discarded due to r ratio : {}".format(r))
# #             logger.info("time to decide: {}".format(time.time()-start))

# #     logger.info("Total time: %s", time.time()-sstart)
    
    
#     q_graph_ids = list(range(no_of_query_subgraphs))
#     end1 = int(0.6*no_of_query_subgraphs)
#     end2 = end1 + int(0.15*no_of_query_subgraphs)
    
#     train_q_ids = q_graph_ids[0:end1]
#     val_q_ids = q_graph_ids[end1:end2]
#     test_q_ids = q_graph_ids[end2:]
#     c_graph_ids = list(range(no_of_corpus_subgraphs))
    
#     train_q_gr,test_q_gr,val_q_gr = [],[],[]
#     train_q_rels,test_q_rels,val_q_rels = {},{},{}
#     corpus_gr = []
    
#     for idx in train_q_ids:
#         g = query_subgraph_list[idx]
#         g1 = nx.convert_node_labels_to_integers(g)
#         train_q_gr.append(g1)
#         train_q_rels[idx-train_q_ids[0]] = rel_dict[idx]
        
#     for idx in test_q_ids:
#         g = query_subgraph_list[idx]
#         g1 = nx.convert_node_labels_to_integers(g)
#         test_q_gr.append(g1)
#         test_q_rels[idx-test_q_ids[0]] = rel_dict[idx]
        
#     for idx in val_q_ids:
#         g = query_subgraph_list[idx]
#         g1 = nx.convert_node_labels_to_integers(g)
#         val_q_gr.append(g1)
#         val_q_rels[idx-val_q_ids[0]] = rel_dict[idx]
        
#     for idx in c_graph_ids:
#         g = corpus_subgraph_list[idx]
#         g1 = nx.convert_node_labels_to_integers(g)
#         corpus_gr.append(g1)    
        
#     fp = av.DIR_PATH + "/data/graph_data/splits/train/train_" + av.DATASET_NAME +"240k_query_subgraphs.pkl" 
#     pickle.dump(train_q_gr,open(fp,"wb"))
#     fp = av.DIR_PATH + "/data/graph_data/splits/train/train_" + av.DATASET_NAME + "240k_rel_" + subgraph_rel_type +".pkl"
#     pickle.dump(train_q_rels,open(fp,"wb"))

#     fp = av.DIR_PATH + "/data/graph_data/splits/test/test_" + av.DATASET_NAME +"240k_query_subgraphs.pkl" 
#     pickle.dump(test_q_gr,open(fp,"wb"))
#     fp = av.DIR_PATH + "/data/graph_data/splits/test/test_" + av.DATASET_NAME + "240k_rel_" + subgraph_rel_type +".pkl"
#     pickle.dump(test_q_rels,open(fp,"wb"))

#     fp = av.DIR_PATH + "/data/graph_data/splits/val/val_" + av.DATASET_NAME +"240k_query_subgraphs.pkl" 
#     pickle.dump(val_q_gr,open(fp,"wb"))
#     fp = av.DIR_PATH + "/data/graph_data/splits/val/val_" + av.DATASET_NAME + "240k_rel_" + subgraph_rel_type +".pkl"
#     pickle.dump(val_q_rels,open(fp,"wb"))

#     fp = av.DIR_PATH + "/data/graph_data/splits/" + av.DATASET_NAME +"240k_corpus_subgraphs.pkl" 
#     pickle.dump(corpus_gr,open(fp,"wb"))
