import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import sys
sys.path.insert(0, 'lib')
import pyged
import argparse
import pickle
# from common import logger, set_log
import time
import networkx as nx
import numpy as np
import os
import sys
import random
import scipy.stats
from pebble import ProcessPool
from concurrent.futures import TimeoutError
import itertools

from torch_geometric.datasets import TUDataset
import torch_geometric.utils as pyg_utils
import networkx.algorithms.isomorphism as iso
from utils.data_utils import get_nx_graph, OnTheFlySubgraphSampler,fetch_tudataset_graphs, random_graph_generator
from dataclasses import dataclass
from omegaconf import OmegaConf
import random
import scipy.stats
import numpy as np
from loguru import logger
from utils.data_utils import get_nx_graph, fetch_tudataset_graphs
from utils.utils import set_seed
import datetime
import tqdm
from torch_geometric.data import Data
import torch
from pebble import ProcessPool
from multiprocessing import Pool


def create_pyG_object(graph):
    """Build a PyG ``Data`` object with constant node features from ``graph``.

    Every node gets the single feature value ``1.0`` (``x`` has shape
    ``(num_nodes, 1)``). Each edge ``(u, v)`` is emitted in both directions, with
    all forward copies first and then all reversed copies. The resulting
    ``edge_index`` has shape ``(2, 2 * num_edges)``, which is ``(2, 0)`` for an
    edgeless graph.

    ``edge_index`` stores the raw node labels. They only line up with the rows of
    ``x`` when the nodes of ``graph`` are the integers ``0 .. n-1``; use
    ``nx.convert_node_labels_to_integers`` first otherwise.
    """
    features = torch.ones(graph.number_of_nodes(), 1, dtype=torch.float)
    edges = list(graph.edges)
    doubled_edges = [[u, v] for u, v in edges] + [[v, u] for u, v in edges]
    # Build an (E, 2) tensor and transpose it. reshape(-1, 2) keeps the result
    # 2-D, (2, 0), when there are no edges. The previous np.array(...).T form
    # produced a 1-D tensor of shape (0,) in that case.
    edge_index = (
        torch.tensor(doubled_edges, dtype=torch.int64).reshape(-1, 2).t().contiguous()
    )
    return Data(x=features, edge_index=edge_index)



def to_pyged(graph):
    """Convert ``graph`` into the ``(node_labels, edge_list)`` pair handed to pyged.

    The graph first goes through ``create_pyG_object``. Node labels are the
    argmax over each node's feature row; with the constant one-column features
    built there, every label is ``0``. Each column of the resulting
    ``edge_index`` becomes one ``(src, dst)`` tuple, so both directions of every
    edge appear in the list.
    """
    pyg_data = create_pyG_object(graph)
    node_labels = pyg_data.x.argmax(dim=1).tolist()
    edge_pairs = list(zip(*pyg_data.edge_index.tolist()))
    return node_labels, edge_pairs


def wrapper_so_that_pebble_can_pickle(pair):
    """Run ``ged.calcged`` on one index pair, using the module-global ``ged``.

    ``pair`` is unpacked as positional arguments to ``ged.calcged``. In this
    file it is a ``(query_index, corpus_index)`` tuple, where the query is the
    last graph handed to the calculator. The caller unpacks each return value
    into three fields (presumably lower bound, upper bound, time — confirm
    against pyged).

    This is a plain top-level function so a process pool can pickle it by name.
    ``ged`` is deliberately not a parameter: it must already exist as a global
    in the worker process.
    NOTE(review): this relies on workers inheriting module globals, e.g. the
    'fork' start method.
    """
    calcged = ged.calcged
    return calcged(*pair)

def run_parallel_pool(func, input_list):
    """Map ``func`` over ``input_list`` in a pebble ``ProcessPool``, printing each result.

    NOTE: a second ``def run_parallel_pool`` appears later in this module. It
    rebinds the name when the module is executed, so this version is shadowed
    and cannot be reached as ``run_parallel_pool`` once the module has loaded.

    Uses up to 100 worker processes and a per-item timeout of 30000 seconds.
    Items that time out, or whose call raises ``AssertionError``, yield
    ``None``; any other exception raised by ``func`` propagates. A tqdm bar
    tracks progress, and every successful result is printed.

    Returns:
        A list with one entry per retrieved item, in input order.
    """
    all_result = []
    with ProcessPool(max_workers=100) as pool:
        iterator = pool.map(func, input_list, timeout=30000).result()
        for c_i in tqdm.tqdm(range(len(input_list))):
            # Keep the try narrow: only retrieving the next result can time out
            # or re-raise the worker's exception.
            try:
                result = next(iterator)
            except StopIteration:
                break
            except (TimeoutError, AssertionError):
                all_result.append(None)
                continue
            print(f'Result for {c_i}th pair: {result}')
            all_result.append(result)
    return all_result

def run_parallel_pool2(func, input_list):
    """Map ``func`` over ``input_list`` with a ``multiprocessing.Pool``, in order.

    Results are gathered with ``Pool.imap``, so they come back in input order.
    A tqdm progress bar shows completion; it gets a known total when
    ``input_list`` supports ``len``.

    A fresh pool (one worker per CPU by default) is created on every call and
    shut down before returning. Creating it per call lets the workers see any
    module-level state the caller set up just before the call, such as the
    ``ged`` global read by ``wrapper_so_that_pebble_can_pickle``.
    NOTE(review): this assumes the 'fork' start method.

    Args:
        func: A picklable (top-level) callable applied to each item.
        input_list: The items to process.

    Returns:
        ``[func(item) for item in input_list]``, computed in parallel.
    """
    total = len(input_list) if hasattr(input_list, "__len__") else None
    # The pool used to be left open, leaking worker processes on every call.
    # Exiting the `with` terminates the pool; that is safe here because list()
    # has already consumed every result.
    with Pool() as pool:
        return list(tqdm.tqdm(pool.imap(func, input_list), total=total))



def check_subiso(data):
    """Return whether the second graph in ``data`` is subgraph-isomorphic to the first.

    ``data`` is a ``(gc, gq)`` pair. The check uses NetworkX's VF2
    ``GraphMatcher``. Its ``subgraph_is_isomorphic`` looks for a node-induced
    subgraph of ``gc`` that is isomorphic to ``gq``; node and edge attributes
    are ignored.
    """
    corpus_graph, query_graph = data
    matcher = nx.algorithms.isomorphism.GraphMatcher(corpus_graph, query_graph)
    return matcher.subgraph_is_isomorphic()


def run_parallel_pool(func, input_list, *, max_workers=100, timeout=30):
    """Map ``func`` over ``input_list`` in a pebble ``ProcessPool``.

    Each call to ``func`` gets at most ``timeout`` seconds. An item that times
    out, or whose call raises ``AssertionError``, yields ``None`` in the
    output. Any other exception raised by ``func`` propagates out of this
    function. Results are returned in input order.

    Args:
        func: A picklable (top-level) callable applied to each item.
        input_list: A sized sequence of items to process.
        max_workers: Maximum number of worker processes (default 100, as before).
        timeout: Per-item timeout in seconds (default 30, as before).

    Returns:
        A list with one entry per retrieved item: the result, or ``None``.
    """
    all_result = []
    with ProcessPool(max_workers=max_workers) as pool:
        iterator = pool.map(func, input_list, timeout=timeout).result()
        for _ in range(len(input_list)):
            # Only retrieving the next result can time out or re-raise the
            # worker's exception. After such an error, pebble's iterator moves
            # on to the next item.
            try:
                result = next(iterator)
            except StopIteration:
                break
            except (TimeoutError, AssertionError):
                result = None
            all_result.append(result)
    return all_result

# def check_iso(data):
#     gc,gq = data
#     return iso.GraphMatcher(gc, gq).subgraph_is_isomorphic()

def _load_pickle(path):
    """Load and return the object pickled at ``path``, closing the file afterwards.

    NOTE(review): unpickling can execute arbitrary code. Only point this at
    cache files this script produced itself.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file handle deterministically."""
    with open(path, "wb") as fh:
        pickle.dump(obj, fh)


def _add_if_non_isomorphic(graph, buckets):
    """Record ``graph`` in ``buckets`` unless an isomorphic graph is already recorded.

    ``buckets`` maps a Weisfeiler-Lehman graph hash to the graphs recorded under
    that hash. Isomorphic graphs always share a WL hash, so only the bucket for
    ``graph``'s hash can hold an isomorphic match. The accept/reject decisions
    are therefore identical to comparing against every recorded graph, without
    a quadratic number of ``nx.is_isomorphic`` calls.

    Returns:
        True if ``graph`` was new and has been recorded, False if it duplicates
        (up to isomorphism) a graph already in ``buckets``.
    """
    same_hash = buckets.setdefault(nx.weisfeiler_lehman_graph_hash(graph), [])
    if any(nx.is_isomorphic(graph, seen) for seen in same_hash):
        return False
    same_hash.append(graph)
    return True


if __name__ == "__main__":
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    logger.info(f"cli_conf.dataset.rel_mode = {cli_conf.dataset.rel_mode}")
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.rel_mode}/{cli_conf.dataset.name}.yaml")
    conf = OmegaConf.merge(main_conf, data_conf, cli_conf)
    conf.dataset.path = "ged_data"

    log_fname = f"{conf.log.dir}/{conf.task.name}.log"
    open(log_fname, "w").close()  # Clear log file
    logger.add(log_fname)
    logger.info(OmegaConf.to_yaml(conf))

    set_seed(conf.training.seed, conf)

    graphs = fetch_tudataset_graphs(conf)

    no_of_corpus_subgraphs = 10000
    aug_num_cgraphs = 10 * no_of_corpus_subgraphs
    # Number of query graphs whose ground truth is computed below.
    num_queries_needed = 500

    directory = f"{conf.base_dir}{conf.dataset.path}/{conf.dataset.name}/preprocessed"
    os.makedirs(directory, exist_ok=True)
    # The "min<a>_max<b>" fragment shared by every cache file name below.
    size_tag = f"min{conf.dataset.min_graph_nodes}_max{conf.dataset.max_graph_nodes}"

    # Stage 1: sample pairwise non-isomorphic corpus subgraphs (cached).
    c_fname = f"{directory}/{no_of_corpus_subgraphs}_corpus_subgraphs_{size_tag}.pkl"
    if os.path.isfile(c_fname):
        subgraph_list, anchor_list, subgraph_id_list = _load_pickle(c_fname)
        logger.info(f"Loaded {len(subgraph_list)} corpus subgraphs from {c_fname}")
    else:
        logger.info("Sampling corpus subgraphs")
        subgraph_list, anchor_list, subgraph_id_list = [], [], []
        subgraph_sampler = OnTheFlySubgraphSampler(graphs, conf.dataset.min_graph_nodes, conf.dataset.max_graph_nodes)
        seen_corpus = {}
        for i in tqdm.tqdm(range(no_of_corpus_subgraphs)):
            # Resample until the draw is not isomorphic to any kept subgraph.
            while True:
                sgraph, anchor, sgraph_id = subgraph_sampler.sample_subgraph()
                if _add_if_non_isomorphic(sgraph, seen_corpus):
                    break
            logger.info(f"corpus graph {i} sampled")
            subgraph_list.append(sgraph)
            anchor_list.append(anchor)
            subgraph_id_list.append(sgraph_id)

        _dump_pickle((subgraph_list, anchor_list, subgraph_id_list), c_fname)
        logger.info(f"Saved {len(subgraph_list)} corpus subgraphs to {c_fname}")

    # Stage 2: augment every corpus subgraph with random extra nodes/edges (cached).
    cfname_aug_dup = f"{directory}/{no_of_corpus_subgraphs}_corpus_subgraphs_{size_tag}_aug_with_dup.pkl"
    if os.path.isfile(cfname_aug_dup):
        new_cgrs = _load_pickle(cfname_aug_dup)
        logger.info(f"Loaded {len(new_cgrs)} augmented corpus subgraphs with duplicates from {cfname_aug_dup}")
    else:
        new_cgrs = []
        for nn in [1, 2]:
            for ne in [3, 4, 5]:
                for base_graph in tqdm.tqdm(subgraph_list):
                    for _ in range(5):
                        new_cgrs.append(
                            random_graph_generator(base_graph, num_new_nodes=nn, max_edges_per_node=ne, a=0, b=2)
                        )
        _dump_pickle(new_cgrs, cfname_aug_dup)

    # Stage 3: drop isomorphic duplicates among the augmented graphs, keeping
    # first occurrences and stopping after aug_num_cgraphs unique graphs (cached).
    cfname_aug_unique = f"{directory}/{no_of_corpus_subgraphs}_corpus_subgraphs_{size_tag}_aug_without_dup.pkl"
    if os.path.isfile(cfname_aug_unique):
        unique_cgrs = _load_pickle(cfname_aug_unique)
        logger.info(f"Loaded {len(unique_cgrs)} augmented corpus subgraphs without duplicates from {cfname_aug_unique}")
    else:
        unique_cgrs = []
        seen_aug = {}
        cnt = 0
        for g in tqdm.tqdm(new_cgrs):
            if not _add_if_non_isomorphic(g, seen_aug):
                continue
            unique_cgrs.append(g)
            cnt += 1
            print(f'\r num_unique = {cnt}', end=' ')
            # Checkpoint only when the count advances to a multiple of 1000.
            # The check used to run on every duplicate too, re-dumping the same
            # list repeatedly while cnt sat at a multiple of 1000.
            if cnt % 1000 == 0:
                logger.info(f"num_unique = {cnt} Saving now")
                _dump_pickle(unique_cgrs, cfname_aug_unique)
            if cnt == aug_num_cgraphs:
                break

        _dump_pickle(unique_cgrs, cfname_aug_unique)

    # The message must be the string itself: print(...) returns None, which
    # used to leave the AssertionError without a message.
    assert len(unique_cgrs) >= aug_num_cgraphs, (
        f"dataset: {cli_conf.dataset.name} len(unique_cgrs) = {len(unique_cgrs)} < aug_num_cgraphs = {aug_num_cgraphs}"
    )

    # Stage 4: keep the first aug_num_cgraphs unique graphs and relabel them (cached).
    final_cgraphs = unique_cgrs[:aug_num_cgraphs]
    logger.info(f"retaining {len(final_cgraphs)} corpus subgraphs for query generation")

    final_cgr_fname = f"{directory}/{aug_num_cgraphs}_corpus_subgraphs_{size_tag}_relabeled.pkl"
    if os.path.isfile(final_cgr_fname):
        c_gr_all = _load_pickle(final_cgr_fname)
        logger.info(f"Loaded {len(c_gr_all)}  final corpus relabeled subgraphs without duplicates from {final_cgr_fname}")
    else:
        # Relabel nodes to 0 .. n-1: create_pyG_object uses node labels
        # directly as edge indices for the pyged conversion.
        c_gr_all = [nx.convert_node_labels_to_integers(o) for o in final_cgraphs]
        _dump_pickle(c_gr_all, final_cgr_fname)
        logger.info(f"Saved final relabeled corpus to {final_cgr_fname}")

    # Stage 5: load the pre-generated query graphs; this script cannot sample more.
    qgr_fname = f"{directory}/all_queries_for_{no_of_corpus_subgraphs}_corpus_subgraphs_{size_tag}_aug_without_dup.pkl"
    if os.path.isfile(qgr_fname):
        # The second element is a gt dict that this script does not reuse;
        # ground truth is recomputed with unequal costs below.
        all_qgrs, _ = _load_pickle(qgr_fname)
        n_queries = len(all_qgrs)
        logger.info(f"Loaded {len(all_qgrs)} query graphs without duplicates from {qgr_fname} ")
    else:
        all_qgrs = []
        n_queries = 0

    if n_queries < num_queries_needed:
        logger.info(f"Sampling more queries because {n_queries} queries which is < {num_queries_needed}")
    assert n_queries >= num_queries_needed, "Sampling more queries needed."

    # Stage 6: unequal-cost GED ground truth (query vs every corpus graph),
    # resumable from a partially written gt dict keyed by 0-based query index.
    final_uneq_cost_gt_fname = f"{directory}/uneq_cost_gt_dict_for_{aug_num_cgraphs}_corpus_subgraphs_{size_tag}.pkl"
    gt_dict = {}
    if os.path.isfile(final_uneq_cost_gt_fname):
        gt_dict = _load_pickle(final_uneq_cost_gt_fname)
        logger.info(f"Found {len(gt_dict)} queries in the final unequal cost gt dict. Starting from {len(gt_dict)+1}th query")

    # Skip indices that are already stored. Resuming used to start at
    # len(gt_dict) + 1 and so skipped query index len(gt_dict). Filling from the
    # set of missing keys also repairs gaps left by that bug.
    pending_queries = [q for q in range(num_queries_needed) if q not in gt_dict]

    node_ins_cost = 0.
    node_del_cost = 0.
    edge_ins_cost = 1.
    edge_del_cost = 2.
    # NOTE(review): the two 0. slots are interpreted by pyged (presumably
    # substitution costs) — confirm against pyged's cost-vector layout.
    cost = [node_ins_cost, node_del_cost, 0., edge_ins_cost, edge_del_cost, 0.]

    # The corpus conversion does not depend on the query: do it once, not per query.
    gedlib_cgraphs = [to_pyged(g) for g in c_gr_all] if pending_queries else []
    num_graphs = len(gedlib_cgraphs)
    # The query graph is appended last, so it sits at index num_graphs.
    index_combinations = [(num_graphs, i) for i in range(num_graphs)]

    for q_idx in pending_queries:
        s = time.time()
        # NOTE(review): unlike the corpus, query graphs are not relabeled here;
        # this assumes the cached queries already use nodes 0 .. n-1.
        gedlib_graphs = gedlib_cgraphs + [to_pyged(all_qgrs[q_idx])]

        # `ged` must remain a module-level global:
        # wrapper_so_that_pebble_can_pickle reads it inside the workers that
        # run_parallel_pool2 creates after this assignment.
        ged = pyged.GraphEditDistanceCalculator(gedlib_graphs, 'f2', ['--threads 1 --time-limit 1000'], cost)
        result = run_parallel_pool2(wrapper_so_that_pebble_can_pickle, index_combinations)
        lb, ub, _times = zip(*result)
        lb, ub = torch.tensor(lb), torch.tensor(ub)
        num_not_equal = int(torch.sum(lb != ub))

        if num_not_equal != 0:
            logger.info(f'SAD: Found {num_not_equal} Number of discrepancies')

        gt_dict[q_idx] = ub
        logger.info(f"Precessed #{q_idx + 1} query, time taken {time.time()-s}")
        _dump_pickle(gt_dict, final_uneq_cost_gt_fname)

    # Final dump also guarantees the file exists when nothing was pending.
    _dump_pickle(gt_dict, final_uneq_cost_gt_fname)
    logger.info(f"Saved final unequal cost gt dict to {final_uneq_cost_gt_fname}")

    
        
