import graphviz  
import argparse
from tqdm import tqdm
import networkx as nx
import numpy as np 
import pandas as pd 
from concurrent.futures import ThreadPoolExecutor

def read_dataset(base_path):
    """Load the test, train and valid splits sharing the prefix *base_path*.

    Each split is stored as ``<base_path>-<split>.json`` and parsed with
    ``pd.read_json``.

    Returns:
        tuple: ``(test, train, valid)`` DataFrames, in that order.
    """
    splits = {}
    for split in ('test', 'train', 'valid'):
        splits[split] = pd.read_json(base_path + '-' + split + '.json')
    return splits['test'], splits['train'], splits['valid']


def draw_mind_map(subgraph, args):
    """Turn a relation-keyed multigraph into a named graphviz mind map.

    Args:
        subgraph: multigraph whose ``edges`` iterate as
            ``(head_id, tail_id, relation_id)`` triples (as a networkx
            ``MultiDiGraph`` keyed by relation id does).
        args: namespace providing ``entid2name`` / ``relid2name`` Series that
            map entity / relation ids to display names.

    Returns:
        graphviz.Digraph: an unrendered PNG-format digraph with one labelled
        edge per subgraph edge. Nodes without edges are not drawn.
    """
    ent_names = args.entid2name
    rel_names = args.relid2name
    fig = graphviz.Digraph(
        'machine-mind-map',
        format='png',
        node_attr={'shape': 'plaintext'},
        graph_attr={'size': "20,20", 'fontname': 'Times-Roman'},
    )
    # Names (not ids) are used as node identifiers, so two entities with the
    # same name end up as a single drawn node.
    # NOTE(review): graphviz parses ``name:port`` in edge endpoints; entity
    # names containing ':' may be split into node/port — confirm.
    for head_id, tail_id, rel_id in subgraph.edges:
        fig.edge(ent_names.loc[head_id], ent_names.loc[tail_id], label=rel_names.loc[rel_id])
    return fig

def generate_subgraph(triplet_id, args):
    """Extract a small context subgraph around a query triple from ``args.kg``.

    Steps:
      1. If both endpoints are in the KG, add every simple edge path from head
         to tail of at most ``args.max_path_length`` edges.
      2. Add the ``nx.ego_graph`` of the head and of the tail with radius
         ``args.max_hop`` (networkx's default for directed graphs follows
         out-edges only).
      3. If more than ``args.max_node_num`` nodes remain, keep the top-ranked
         ones by PageRank personalised to head and tail. The endpoints
         themselves are not guaranteed to survive this pruning.

    Args:
        triplet_id: ``(head_id, relation_id, tail_id)``; the relation id is
            not used.
        args: namespace with ``kg`` (a networkx multigraph) and the
            ``max_path_length``, ``max_hop`` and ``max_node_num`` limits.

    Returns:
        nx.MultiDiGraph: the context subgraph, with edge keys copied from
        ``args.kg``.
    """
    head, _relation, tail = triplet_id
    kg = args.kg
    sub = nx.MultiDiGraph()

    head_known = head in kg
    tail_known = tail in kg

    if head_known and tail_known:
        for edge_path in nx.all_simple_edge_paths(kg, head, tail, cutoff=args.max_path_length):
            sub.add_edges_from(edge_path)

    for center, known in ((head, head_known), (tail, tail_known)):
        if known:
            sub.add_edges_from(nx.ego_graph(kg, center, radius=args.max_hop).edges(keys=True))

    if sub.number_of_nodes() > args.max_node_num:
        scores = pd.Series(nx.pagerank(sub, personalization={head: 1, tail: 1}))
        ranked = scores.sort_values(ascending=False).index.values
        sub.remove_nodes_from(list(ranked[args.max_node_num:]))

    return sub

def generate_mm_context(triplet_ids, args):
    """Build the context subgraph and mind-map figure for every query triple.

    Subgraph extraction (``generate_subgraph``) and drawing
    (``draw_mind_map``) run as two consecutive stages on one pool of 16
    worker threads.

    Args:
        triplet_ids: iterable of ``(head_id, relation_id, tail_id)`` triples.
        args: namespace forwarded unchanged to both stage functions.

    Returns:
        tuple[list, list]: ``(subgraphs, figs)``, aligned index-by-index with
        ``triplet_ids``.
    """
    # NOTE(review): much of the work is pure-Python graph code, so threads may
    # give limited speedup under the GIL.
    with ThreadPoolExecutor(max_workers=16) as pool:
        # Stage 1: futures are resolved in submission order so results stay
        # aligned with the input triples.
        subgraph_jobs = [
            pool.submit(generate_subgraph, triplet, args)
            for triplet in tqdm(triplet_ids)
        ]
        subgraphs = [job.result() for job in tqdm(subgraph_jobs)]

        # Stage 2: draw one figure per subgraph, again in input order.
        figure_jobs = [pool.submit(draw_mind_map, sg, args) for sg in subgraphs]
        figs = [job.result() for job in tqdm(figure_jobs)]

    return subgraphs, figs


def constrcut_mllm_dataset(df, args, filename='test'):
    """Render a mind-map image per row of *df* and write the multimodal split.

    Args:
        df: split DataFrame with at least ``embedding_ids`` and ``input``
            columns.
        args: namespace read by the subgraph/drawing helpers, plus
            ``dataset_name`` and ``base_data_path``.
        filename: split name used in image file names and in the output path
            ``<base_data_path>-<filename>mm.json``.

    Returns:
        A copy of *df* with ``image_file``, ``img_source`` and ``instruction``
        columns added and ``input`` suffixed with an ``<image>`` placeholder.
        The caller's DataFrame is left unmodified.

    Side effects:
        Writes ``media_dir/seeKG/<dataset_name>-<filename><i>.png`` for each
        row and the JSON records file.
    """
    # Work on a copy. Assigning columns on the caller's frame mutated it in
    # place, so processing the same frame twice appended the image
    # placeholder to ``input`` twice.
    df = df.copy()
    _, figs = generate_mm_context(df.embedding_ids.values, args)

    def save_fig(i, fig):
        # Render the i-th figure to PNG; cleanup=True removes the DOT source file.
        fname = '%s-%s%i' % (args.dataset_name, filename, i)
        fig.render(filename=fname, directory='media_dir/seeKG', cleanup=True)
        # One-element list per row: the column stores a list of image files.
        return [fname + '.png']

    with ThreadPoolExecutor(max_workers=16) as executor:
        # enumerate(tqdm(...)) lets tqdm know the total length.
        futures = [executor.submit(save_fig, i, fig) for i, fig in enumerate(tqdm(figs))]
        fig_names = [future.result() for future in tqdm(futures)]

    # Keep only the edge statements of each DOT source by dropping the first
    # three lines (digraph header, graph attrs, node attrs) and the closing
    # brace. NOTE(review): relies on graphviz's current source layout.
    fig_sources = [
        '\nThe triplet information in the image:\n%s\n'
        % fig.source.split('\n', 3)[-1].rsplit('}', 1)[0]
        for fig in tqdm(figs)
    ]

    df['image_file'] = fig_names
    df['img_source'] = fig_sources

    df['instruction'] = "Given the input knowledge graph triple in the form of (head, relation, tail), please first identify the key entities/relationships in the KG image, then cross-reference with your internal knowledge to determine the validity of the triple. Note: the image doesn't contain the input triple, and relation definitions may be looser than apparent. Return True if the probability ≥ 0.5, otherwise False. Respond strictly with 'True' or 'False' only. Do not include any additional text.\n"

    df['input'] = df['input'] + 'The input knowledge graph image: <image>\n'

    df.to_json(args.base_data_path + '-%smm.json' % filename, orient="records", indent=4)

    return df

def _build_kg(triples_path):
    """Load space-separated ``head tail relation`` id triples into a relation-keyed MultiDiGraph."""
    # NOTE(review): every line must hold exactly three ids. OpenKE-style
    # train2id.txt files often begin with a triple-count line, which would
    # parse with NaN ids and make int() raise. Confirm the file has no such
    # header.
    triplets = pd.read_csv(triples_path, header=None, names=['hid', 'tid', 'rid'], sep=' ')
    kg = nx.MultiDiGraph()
    for hid, tid, rid in triplets.itertuples(index=False, name=None):
        kg.add_edge(int(hid), int(tid), key=int(rid))
    return kg


def _load_id2name(path):
    """Read a two-column, tab-separated (id, name) file into an id -> name Series."""
    table = pd.read_csv(path, sep='\t', index_col=False, header=None)
    # Pair column 0 with column 1 row by row. pd.Series(table[1], index=table[0])
    # instead *reindexes* column 1 by label, which gives the right mapping
    # only when the ids happen to equal the row positions 0..N-1.
    return table.set_index(0)[1]


def _sample_train_indices(train, max_samples):
    """Return train row labels: one positive per relation plus random positives, each with the two rows after it."""
    positives = train[train.output == 'True']
    labels = positives.index.values

    # Pass 1: keep the first positive seen for each relation so that every
    # relation occurs in the sampled training data.
    seen_relations = set()
    anchors = []
    for label, (_, rel, _) in zip(labels, positives.embedding_ids.values):
        if rel not in seen_relations:
            seen_relations.add(rel)
            anchors.append(label)

    # Each positive at label c is taken together with rows c+1 and c+2.
    # NOTE(review): presumably those are its two negative samples; confirm the
    # data layout.
    sampled = [i for c in anchors for i in (c, c + 1, c + 2)]

    # Pass 2: add random positives (with their two following rows) until the
    # sample reaches max_samples rows. The original `>` check always added one
    # more group after reaching the limit.
    anchor_set = set(anchors)
    for c in np.random.permutation(labels):
        if len(sampled) >= max_samples:
            break
        if c not in anchor_set:
            sampled += [c, c + 1, c + 2]
    return sampled


def construct_seeKG_dataset(args):
    """Build the train/valid/test multimodal datasets for one KG.

    Reads the KG triples and id -> name tables under ``args.base_path`` and the
    JSON splits under ``args.base_data_path``, then samples a training subset.
    Sampled positive triples are removed from the KG so training mimics the
    test setting, where the queried triple is absent from the graph. Finally,
    each split is rendered and written via ``constrcut_mllm_dataset``.

    Side effects:
        Sets ``args.kg``, ``args.entid2name`` and ``args.relid2name`` (read by
        the subgraph/drawing helpers), consumes NumPy's global RNG, and writes
        image and JSON files.
    """
    base_path = args.base_path
    base_data_path = args.base_data_path

    kg = _build_kg(base_path + 'train2id.txt')

    test, train, valid = read_dataset(base_data_path)

    args.relid2name = _load_id2name(base_path + 'relid2name.txt')
    args.entid2name = _load_id2name(base_path + 'entid2name.txt')

    sampled_train = _sample_train_indices(train, args.max_samples)

    # Remove the sampled positive triples so their images cannot reveal the
    # answer. NOTE(review): remove_edge raises if a sampled positive is not (or
    # no longer) an edge of the KG, e.g. a duplicated positive row.
    sampled_rows = train.loc[sampled_train]
    for ids, label in zip(sampled_rows.embedding_ids, sampled_rows.output):
        if label == 'True':
            h, r, t = ids
            kg.remove_edge(h, t, key=r)

    args.kg = kg

    constrcut_mllm_dataset(train.loc[np.random.permutation(sampled_train)], args, filename='train')
    constrcut_mllm_dataset(valid, args, filename='valid')
    constrcut_mllm_dataset(test, args, filename='test')





def load_args():
    """Build the command-line parser for the dataset-construction script.

    Returns:
        argparse.ArgumentParser: defines ``--dataset_name`` plus the sampling
        and subgraph-size limits; call ``parse_args`` on it.

    NOTE(review): ``--max_edge_num`` is parsed but not read by any function in
    this file.
    """
    options = (
        ('--dataset_name', str, 'CoDeX-S'),
        ('--max_samples', int, 10000),
        ('--max_path_length', int, 3),
        ('--max_hop', int, 1),
        ('--max_node_num', int, 6),
        ('--max_edge_num', int, 8),
    )
    parser = argparse.ArgumentParser()
    for flag, kind, default in options:
        parser.add_argument(flag, type=kind, default=default)
    return parser

if __name__ == "__main__":
    # Parse CLI options, derive the KG and data locations from the dataset
    # name, then build all three multimodal splits.
    parser = load_args()
    args = parser.parse_args()
    dataset = args.dataset_name
    # Directory holding train2id.txt, entid2name.txt and relid2name.txt.
    args.base_path = f'./kg/{dataset}/'
    # Prefix for the -<split>.json inputs and -<split>mm.json outputs.
    args.base_data_path = f'./data/{dataset}'

    construct_seeKG_dataset(args)