import networkx as nx
import numpy as np
import pickle
import json
#from tensorflow.compat.v1 import gfile
import scipy as sc
import os
import re

from collections import defaultdict
from heuristics.heuristic_subgraph_matching import findSubgraphGT, toGT
from networkx.algorithms.isomorphism import GraphMatcher
import graph_tool as gt


def read_graphfile(datadir, dataname, max_nodes=None, min_feat_dim=100):
    ''' Read a dataset stored in the TU Dortmund graph-kernel benchmark format.

        Format: https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets
        Node ids and graph ids start with 1 in the files. The files read from
        <datadir>/<dataname>/ are:

            <dataname>_graph_indicator.txt   graph id of each node, one per line
            <dataname>_node_labels.txt       integer label of each node (optional)
            <dataname>_graph_labels.txt      integer label of each graph
            <dataname>_A.txt                 one "src, dst" node-id pair per line

        The node attribute file (<dataname>_node_attributes.txt) is not read.
        Blank lines in the input files are ignored.

    Args:
        datadir: directory containing the dataset folder.
        dataname: dataset name, used as folder name and as file prefix.
        max_nodes: if not None, graphs with more nodes than this are skipped.
        min_feat_dim: one-hot node label vectors shorter than this are padded
            with ones up to this length.

    Returns:
        List of networkx objects with graph and node labels. Nodes are
        relabeled to 0..n-1 in each graph. G.graph['feat'] holds the graph
        label remapped to 0..k-1 in order of first appearance. When node
        labels are available, G.nodes[u]['feat'] holds the (padded) one-hot
        node label as a float array.
    '''
    prefix = os.path.join(datadir, dataname, dataname)

    # graph_indic[node id] -> id of the graph the node belongs to (both 1-based)
    graph_indic = {}
    with open(prefix + '_graph_indicator.txt') as f:
        node_id = 1
        for line in f:
            line = line.strip()
            if not line:
                continue
            graph_indic[node_id] = int(line)
            node_id += 1

    node_labels = []
    try:
        with open(prefix + '_node_labels.txt') as f:
            node_labels = [int(line) for line in f if line.strip()]
    except IOError:
        print('No node labels')

    if node_labels:
        # assume that node labels are consecutive; shift them to start at 0
        min_label_val = min(node_labels)
        num_unique_node_labels = max(node_labels) - min_label_val + 1
        node_labels = [l - min_label_val for l in node_labels]
    else:
        # Keep the count defined so the dataset still loads without node labels.
        num_unique_node_labels = 0
    print("NUM NODE TYPES: ", num_unique_node_labels)

    raw_graph_labels = []
    with open(prefix + '_graph_labels.txt') as f:
        for line in f:
            line = line.strip()
            if line:
                raw_graph_labels.append(int(line))

    # Map the raw label values to 0..k-1 in order of first appearance.
    label_map_to_int = {}
    for val in raw_graph_labels:
        label_map_to_int.setdefault(val, len(label_map_to_int))
    graph_labels = np.array([label_map_to_int[val] for val in raw_graph_labels])

    # adj_list[graph id] -> edges of that graph; an edge is assigned to the
    # graph of its first endpoint.
    adj_list = {i: [] for i in range(1, len(graph_labels) + 1)}
    with open(prefix + '_A.txt') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fields = line.split(",")
            e0, e1 = int(fields[0]), int(fields[1])
            adj_list[graph_indic[e0]].append((e0, e1))

    # Padding appended to every one-hot vector; empty when the number of node
    # labels already reaches min_feat_dim.
    padding = np.ones(max(min_feat_dim - num_unique_node_labels, 0), dtype=float)

    graphs = []
    for i in range(1, 1 + len(adj_list)):
        # indexed from 1 here
        G = nx.from_edgelist(adj_list[i])
        if max_nodes is not None and G.number_of_nodes() > max_nodes:
            continue

        # add features and labels
        G.graph['feat'] = graph_labels[i - 1]
        if node_labels:
            for u in G.nodes():
                one_hot = np.zeros(num_unique_node_labels, dtype=float)
                one_hot[node_labels[u - 1]] = 1
                G.nodes[u]['feat'] = np.concatenate([one_hot, padding])

        # relabeling: indexed from 0 in the returned graphs
        mapping = {n: new_id for new_id, n in enumerate(G.nodes())}
        graphs.append(nx.relabel_nodes(G, mapping))
    return graphs

def create_features(graphs, feat_type='onehot'):
    ''' Attach a 'feat' vector to every node of every graph.

        The input graphs are modified in place: nodes without any edge
        reported by G.edges(node) are removed, and the remaining nodes get a
        'type_index' and a 'feat' attribute. The node type is the node's
        'label' attribute when present, otherwise the node id itself; type
        indices are shared across all graphs.

    Args:
        graphs: list of networkx graphs.
        feat_type: 'onehot' for a one-hot encoding of the node type, or
            'const' for a constant all-ones vector of length 100.

    Returns:
        List of new graphs whose nodes are relabeled to consecutive integers.

    Raises:
        NotImplementedError: if feat_type is not 'onehot' or 'const'.
    '''
    # Validate first so that an unsupported type never leaves the inputs
    # partially modified.
    if feat_type not in ('onehot', 'const'):
        raise NotImplementedError('Unknown feat_type: {}'.format(feat_type))

    key_dicts = {}
    remove_nodes = defaultdict(list)
    for i, G in enumerate(graphs):
        for node in G.nodes():
            # DiGraph subclasses Graph and G.edges(node) yields its out-edges,
            # so this single check covers undirected and directed graphs.
            if isinstance(G, nx.classes.graph.Graph) and len(list(G.edges(node))) == 0:
                remove_nodes[i].append(node)
                continue
            key = G.nodes[node]['label'] if 'label' in G.nodes[node] else node
            if key not in key_dicts:
                key_dicts[key] = len(key_dicts)
            G.nodes[node]['type_index'] = key_dicts[key]
    feat_dim = len(key_dicts)

    # Removal is deferred so the node views are not modified while iterating.
    for idx, remove_node_list in remove_nodes.items():
        graphs[idx].remove_nodes_from(remove_node_list)

    new_graphs = []
    for G in graphs:
        for node in G.nodes():
            if feat_type == 'onehot':
                feat = np.zeros(feat_dim, dtype=float)
                feat[G.nodes[node]['type_index']] = 1
            else:
                feat = np.ones(100, dtype=float)
            G.nodes[node]['feat'] = feat
        new_graphs.append(nx.convert_node_labels_to_integers(G))

    return new_graphs

def read_siemens():
    ''' Load the pickled graphs from data/siemens_arch_graphs.pkl.

    Returns:
        The graphs after create_features with its default one-hot node
        features.
    '''
    # NOTE: pickle can execute arbitrary code while loading; only use this
    # with trusted data files.
    with open('data/siemens_arch_graphs.pkl', 'rb') as f:
        graphs = pickle.load(f)

    return create_features(graphs)

def read_siemens_graphs():
    ''' Load the pickled graphs and query graphlets and add node features.

        Reads data/siemens_arch_graphs.pkl and data/siemens_graphlets_dict.
        Every node gets a 'type_index' derived from its 'label' attribute and
        a 'feat' vector: a one-hot encoding of the type index followed by 20
        constant ones. The label-to-index mapping is built from the graphs
        only and reused for the queries.

    Returns:
        Tuple (graphs, queries) of the loaded objects, annotated in place.
        Each item of queries is unpacked as a (graph, _) pair.

    Raises:
        RuntimeError: if a query node carries a label that does not occur in
            any of the graphs.
    '''
    # NOTE: pickle can execute arbitrary code while loading; only use this
    # with trusted data files.
    with open('data/siemens_arch_graphs.pkl', 'rb') as f:
        graphs = pickle.load(f)
    with open('data/siemens_graphlets_dict', 'rb') as f:
        queries = pickle.load(f)

    key_dicts = {}
    for G in graphs:
        for node in G.nodes():
            key = G.nodes[node]['label']
            if key not in key_dicts:
                key_dicts[key] = len(key_dicts)
            G.nodes[node]['type_index'] = key_dicts[key]
    feat_dim = len(key_dicts)
    const_dim = 20

    def make_feat(type_index):
        # one-hot node type followed by const_dim ones
        feat = np.concatenate([np.zeros(feat_dim, dtype=float), np.ones(const_dim, dtype=float)])
        feat[type_index] = 1
        return feat

    # Features need the final feat_dim, hence the second pass over the graphs.
    for G in graphs:
        for node in G.nodes():
            G.nodes[node]['feat'] = make_feat(G.nodes[node]['type_index'])

    for sg, _ in queries:
        for node in sg.nodes():
            key = sg.nodes[node]['label']
            if key not in key_dicts:
                raise RuntimeError('Unknown node label in graphlets.')
            sg.nodes[node]['type_index'] = key_dicts[key]
            sg.nodes[node]['feat'] = make_feat(key_dicts[key])

    return graphs, queries

def graph_stats(graphs):
    ''' Print summary statistics of a list of networkx graphs.

        Reports mean and standard deviation of the number of nodes ("size")
        and of the number of edges (printed under the label "deg"), the mean
        average clustering coefficient and the mean number of connected
        components. If the first graph is directed, the undirected versions
        are used for the connected components and the mean number of strongly
        connected components is printed as well.

    Args:
        graphs: list of networkx graphs; an empty list only prints a notice.

    Returns:
        None
    '''
    num_graphs = len(graphs)
    if num_graphs == 0:
        # Nothing to average over; avoid dividing by zero below.
        print('No graphs to compute statistics for')
        return

    sizes = [g.number_of_nodes() for g in graphs]
    edge_counts = [g.number_of_edges() for g in graphs]
    directed = nx.is_directed(graphs[0])

    avg_size = sum(sizes) / num_graphs
    avg_size_std = np.std(sizes)
    avg_deg = sum(edge_counts) / num_graphs
    # The spread must describe the same quantity as the mean (edge counts).
    avg_deg_std = np.std(edge_counts)
    avg_clustering = sum([nx.average_clustering(g) for g in graphs]) / num_graphs
    if directed:
        avg_scc = sum([nx.number_strongly_connected_components(g) for g in graphs]) / num_graphs
        und_graphs = [g.to_undirected() for g in graphs]
        avg_comp = sum([nx.number_connected_components(g) for g in und_graphs]) / num_graphs
    else:
        avg_comp = sum([nx.number_connected_components(g) for g in graphs]) / num_graphs

    print('Avg size: ', avg_size, ' std: ', avg_size_std)
    print('Avg deg: ', avg_deg, ' std: ', avg_deg_std)
    print('Avg clustering: ', avg_clustering)
    print('Avg num components: ', avg_comp)
    if directed:
        print('Avg num SCC: ', avg_scc)

def read_arqui(path='data/nn_code_sample'):
    ''' Load pickled graphs and query graphs from a directory.

        Files are visited in sorted name order. Files ending in ".gpickle"
        are loaded as target graphs; files ending in ".query" are loaded as
        query graphs, of which only the largest connected component is kept.
        All graphs are converted to undirected graphs. Statistics of both
        sets are printed and one-hot node features are created jointly for
        both sets so that they share the same node type indices.

    Args:
        path: directory containing the .gpickle and .query files.

    Returns:
        Tuple (graphs, test_query_graphs) of lists of featurized graphs.
    '''
    graphs = []
    test_query_graphs = []
    for filename in sorted(os.listdir(path)):
        if not filename.endswith(('.gpickle', '.query')):
            continue
        fpath = os.path.join(path, filename)
        # Plain pickle is used because nx.read_gpickle was removed in
        # NetworkX 3.0. NOTE: pickle can execute arbitrary code while
        # loading; only use this with trusted data files.
        with open(fpath, 'rb') as f:
            graph = pickle.load(f).to_undirected()
        if filename.endswith('.gpickle'):
            graphs.append(graph)
        else:
            # nx.connected_component_subgraphs was removed in NetworkX 2.4;
            # build the subgraph of the largest component explicitly.
            largest = max(nx.connected_components(graph), key=len)
            test_query_graphs.append(graph.subgraph(largest).copy())

    # Statistics are averages and are skipped for an empty set.
    if graphs:
        graph_stats(graphs)
    if test_query_graphs:
        graph_stats(test_query_graphs)
    num_graphs = len(graphs)
    concat = create_features(graphs + test_query_graphs, feat_type='onehot')

    return concat[:num_graphs], concat[num_graphs:]

def create_linegraph_features(orig_graph, new_graph, num_classes, const_dim=10):
    ''' Turn the edge labels of orig_graph into node features of new_graph.

        Every node of new_graph is used as an edge key into orig_graph. Nodes
        whose edge label ('e_label') is greater than 3 are removed from
        new_graph; the others get a 'feat' vector made of the one-hot edge
        label (length num_classes) followed by const_dim ones. The number of
        removed nodes is printed. new_graph is modified in place.

    Returns:
        A copy of the filtered new_graph with nodes relabeled to integers.
    '''
    dropped = []
    for lg_node in new_graph.nodes():
        edge_label = orig_graph.edges[lg_node]['e_label'].item()
        if edge_label > 3:
            dropped.append(lg_node)
            continue
        one_hot = np.zeros(num_classes, dtype=float)
        one_hot[edge_label] = 1
        const_part = np.ones(const_dim, dtype=float)
        new_graph.nodes[lg_node]['feat'] = np.concatenate([one_hot, const_part])
    print(len(dropped))
    new_graph.remove_nodes_from(dropped)
    return nx.convert_node_labels_to_integers(new_graph)

def read_WN(path='data/WN18.gpickle'):
    ''' Load a pickled graph and return its featurized line graph.

        The graph is converted to an undirected graph and its line graph is
        built, so every edge of the loaded graph becomes a node. The number
        of line-graph nodes is printed before create_linegraph_features is
        applied with 18 classes.

    Args:
        path: path of the pickled networkx graph.

    Returns:
        List containing the single featurized line graph.
    '''
    # Plain pickle is used because nx.read_gpickle was removed in
    # NetworkX 3.0. NOTE: pickle can execute arbitrary code while loading;
    # only use this with trusted data files.
    with open(path, 'rb') as f:
        graph = pickle.load(f).to_undirected()
    new_graph = nx.line_graph(graph)
    print(nx.number_of_nodes(new_graph))
    return [create_linegraph_features(graph, new_graph, 18)]

def read_ppi(dataset_path='data', dataset_str='ppi'):
    ''' Load a graph stored as node-link JSON and add constant node features.

        Reads <dataset_path>/<dataset_str>/<dataset_str>-G.json and prints
        the number of nodes of the loaded graph after create_features has
        processed it.

    Args:
        dataset_path: directory containing the dataset folder.
        dataset_str: dataset name, used as folder name and as file prefix.

    Returns:
        List with the single graph returned by create_features ('const').
    '''
    json_path = '{}/{}/{}-G.json'.format(dataset_path, dataset_str, dataset_str)
    # Close the file as soon as the JSON document has been parsed.
    with open(json_path) as f:
        graph_json = json.load(f)
    graph_nx = nx.readwrite.json_graph.node_link_graph(graph_json)
    graphs = create_features([graph_nx], 'const')
    print(nx.number_of_nodes(graph_nx))
    return graphs

if __name__ == '__main__':
    # Script entry point: only runs when the file is executed directly.
    # NOTE(review): os is already imported at the top of the file and plt is
    # not used in this block - confirm whether these imports are still needed.
    import os
    import matplotlib.pyplot as plt

    # NOTE(review): presumably works around the "duplicate OpenMP runtime"
    # abort seen with some library builds - confirm. It is set before the
    # dataset is loaded.
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    # Load the PPI dataset with the default arguments; the result is discarded.
    read_ppi()
