
import networkx.algorithms.isomorphism as iso
import pickle
import torch
from graph_tool.all import Graph
import graph_tool.topology as gtt
import signal

def load_graphs_with_features():
    """Load the pickled graphs and graphlets that carry node features.

    Returns:
        tuple ``(graphsDic, subgraphs)``: the unpickled contents of
        ``data/siemens_arch_graphs_w_features.pkl`` and
        ``data/siemens_graphlets_w_features.pkl``, in that order.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only point this at trusted data.
    # `with` guarantees each file is closed even if unpickling raises.
    with open('data/siemens_arch_graphs_w_features.pkl', 'rb') as fh:
        graphsDic = pickle.load(fh)

    with open('data/siemens_graphlets_w_features.pkl', 'rb') as fh:
        subgraphs = pickle.load(fh)

    return graphsDic, subgraphs

# t = time.time()
# GM = nx.algorithms.isomorphism.GraphMatcher(g,sg,node_match=nx.algorithms.isomorphism.categorical_node_match(['label'], []))
# print(time.time() - t)
# print()



def load_graphs_without_features():
    """Load the pickled graphs and graphlets without node features.

    Returns:
        tuple ``(graphsDic, subgraphs)``: the unpickled contents of
        ``data/siemens_graphs_dic.pkl`` and
        ``data/siemens_graphlets_dict.pkl``, in that order.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only point this at trusted data.
    # `with` guarantees each file is closed even if unpickling raises.
    with open('data/siemens_graphs_dic.pkl', 'rb') as fh:
        graphsDic = pickle.load(fh)

    with open('data/siemens_graphlets_dict.pkl', 'rb') as fh:
        subgraphs = pickle.load(fh)

    return graphsDic, subgraphs

def findSubgraphNx(graph, subgraph):
    """Lazily yield node-induced subgraph isomorphisms of ``subgraph`` in ``graph``.

    Matching uses networkx's VF2 ``GraphMatcher`` and only considers
    node-induced subgraphs. Nodes must agree on their ``'label'``
    attribute; a node without one is compared as if its label were the
    string ``'label'``.

    Returns:
        The iterator from ``GraphMatcher.subgraph_isomorphisms_iter()``.
    """
    same_label = iso.categorical_node_match(['label'], ['label'])
    matcher = iso.GraphMatcher(graph, subgraph, node_match=same_label)
    return matcher.subgraph_isomorphisms_iter()

def findSubgraphNxDriver(graphsDic, subgraphs):
    """Search every graph in ``graphsDic`` for each subgraph in ``subgraphs``.

    Args:
        graphsDic: mapping from a graph key to a networkx graph.
        subgraphs: sequence of 2-element ``(subgraph, other)`` pairs; only
            the first element is used here.

    Returns:
        dict mapping each subgraph's position in ``subgraphs`` to a list of
        ``(graph_key, matches)`` tuples, one per graph with at least one
        match. ``matches`` is the list of mappings from ``findSubgraphNx``.
        Subgraphs with no match anywhere map to an empty list.
    """
    sgToGraphIndex = {}
    for idx, (sg, _) in enumerate(subgraphs):
        hits = []
        print("is this ur subgraph", idx)
        print(sg.nodes(data=True))
        for g, graph in graphsDic.items():
            matches = list(findSubgraphNx(graph, sg))
            if matches:
                # Accumulate rather than overwrite so every graph that
                # contains the subgraph is recorded, not just the last one.
                hits.append((g, matches))
            print('found ' + str(len(matches)) + ' in graph ' + str(g))
        sgToGraphIndex[idx] = hits
    return sgToGraphIndex


def findSubgraphGT(graph, subgraph, timeout_s=30):
    """Find at most one (not necessarily induced) embedding of ``subgraph`` in ``graph``.

    graph_tool uses the same VF2 algorithm as networkx but also matches
    non-induced subgraphs. Vertices are matched on the ``'feat'`` vertex
    property of both graphs.

    Args:
        graph: graph_tool Graph to search in; needs a ``'feat'`` vertex property.
        subgraph: graph_tool Graph pattern; needs a ``'feat'`` vertex property.
        timeout_s: limit in whole seconds, enforced with SIGALRM (so this
            only works on platforms with SIGALRM, from the main thread).

    Returns:
        The result of ``graph_tool.topology.subgraph_isomorphism`` with ``max_n=1``.

    Raises:
        TimeoutError: if the search runs longer than ``timeout_s`` seconds.
    """
    def _on_timeout(signum, frame):
        print('timeout')
        raise TimeoutError

    previous_handler = signal.signal(signal.SIGALRM, _on_timeout)
    signal.alarm(timeout_s)
    try:
        return gtt.subgraph_isomorphism(
            subgraph,
            graph,
            vertex_label=(subgraph.vp['feat'], graph.vp['feat']),
            max_n=1,
        )
    finally:
        # Always cancel the pending alarm and restore the caller's handler,
        # even if the search raised, so no stray SIGALRM fires later.
        signal.alarm(0)
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

def toGT(adj, feat):
    """Build a graph_tool Graph from an adjacency matrix and node features.

    Args:
        adj: 2-D numpy array or torch tensor; every nonzero entry
            ``adj[i, j]`` becomes an edge ``(i, j)`` in a default
            ``graph_tool.Graph()``.
        feat: per-node feature matrix indexed as ``feat[i, :]``; each row is
            stored as a ``vector<int>`` in the ``'feat'`` vertex property.

    Returns:
        graph_tool Graph with at least one vertex per row of ``adj``
        (isolated nodes included) and the ``'feat'`` vertex property set.
    """
    assert len(adj.shape) == 2
    if isinstance(adj, torch.Tensor):
        adj = adj.cpu().numpy()
    rows, cols = adj.nonzero()

    G = Graph()
    # Create every vertex up front. add_edge_list alone only creates
    # vertices up to the largest endpoint index, which silently drops
    # isolated/trailing nodes (and their features) from the graph.
    num_nodes = adj.shape[0]
    if num_nodes > 0:
        G.add_vertex(num_nodes)
    G.add_edge_list(list(zip(rows, cols)))

    feats = G.new_vertex_property('vector<int>')
    for idx in G.get_vertices():
        feats[idx] = feat[idx, :].tolist()
    G.vertex_properties['feat'] = feats
    return G

def main():
    """Load the feature-less graph collections and run the networkx subgraph search.

    The result of ``findSubgraphNxDriver`` is discarded; progress is
    reported through its prints.
    """
    graphsDic, subgraphs = load_graphs_without_features()
    # The previous `nx2gt(...)` call referenced a name that is not defined
    # or imported anywhere, so the script crashed with NameError before
    # reaching the search below; its result was unused anyway.
    findSubgraphNxDriver(graphsDic, subgraphs)


# def match(G1, G2):
#     '''
#     This function compares two graphs of size 3 (number of nodes)
#     and checks if they are isomorphic.
#     It returns a boolean indicating whether or not they are isomorphic
#     '''
#     if G1.GetEdges() > G2.GetEdges():
#         G = G1
#         H = G2
#     else:
#         G = G2
#         H = G1
#     # Only checks 6 permutations, since k = 3
#     for p in permutations(range(3)):
#         edge = G.BegEI()
#         matches = True
#         while edge < G.EndEI():
#             if not H.IsEdge(p[edge.GetSrcNId()], p[edge.GetDstNId()]):
#                 matches = False
#                 break
#             edge.Next()
#         if matches:
#             break
#     return matches

# def countIso(G, sg, verbose=False):
#     '''
#     Given a set of 3 node indices in sg, obtains the subgraph from the
#     original graph and renumbers the nodes from 0 to 2.
#     It then matches this graph with one of the 13 graphs in
#     directed_3.
#     When it finds a match, it increments the motif_counts by 1 in the relevant
#     index

#     IMPORTANT: counts are stored in global motif_counts variable.
#     It is reset in the enumerate_subgraph method.
#     '''
#     if verbose:
#         print(sg)
#     nodes = snap.TIntV()
#     for NId in sg:
#         nodes.Add(NId)
#     # This call requires latest version of snap
#     SG = snap.GetSubGraphRenumber(G, nodes)
#     for i in range(len(directed_3)):
#         if match(directed_3[i], SG):
#             motif_counts[i] += 1

def enumerate_subgraph(G, k=3, verbose=False):
    '''
    Run the ESU (Enumerate SUbgraphs) algorithm over every node of G.

    For each node v the search is seeded with the subgraph {v} and the
    extension set of v's neighbours whose ID is greater than v, then
    handed to extend_subgraph, which prints each k-node subgraph it reaches.
    A progress line is printed roughly every 5% of the nodes.

    Args:
        G: graph exposing ``nodes()`` and ``G[node]`` iterating over the
            neighbours of ``node`` (e.g. a networkx graph). Node IDs must be
            mutually comparable with ``>``.
        k: size of the subgraphs to enumerate.
        verbose: forwarded unchanged to extend_subgraph.
    '''
    nodes = list(G.nodes())
    # Must be a count: dividing by the NodeView itself raises TypeError.
    num_nodes = len(nodes)
    print_thresh = 0.05
    for done, node in enumerate(nodes, start=1):
        # Only larger-ID neighbours, so each subgraph is rooted at its
        # smallest node and found from exactly one seed.
        v_ext = {nbr for nbr in G[node] if nbr > node}
        extend_subgraph(G, k, {node}, v_ext, node, verbose)
        if done / num_nodes > print_thresh:
            print("Progress: {}%".format(100 * done / num_nodes))
            print_thresh += 0.05


def extend_subgraph(G, k, sg, v_ext, node_id, verbose=False):
    '''
    Recursive step of the ESU algorithm.

    Repeatedly moves one node w from the extension set into the current
    subgraph ``sg`` and recurses; when ``sg`` reaches ``k`` nodes it is
    printed. A neighbour of w joins the new extension set only if its ID is
    greater than ``node_id`` and it lies in w's exclusive neighbourhood,
    i.e. it is neither in ``sg`` nor adjacent to any node of ``sg``. That
    exclusivity is what keeps the same subgraph from being printed twice.

    Args:
        G: graph where ``G[u]`` iterates over the neighbours of ``u``.
        k: target subgraph size.
        sg: set of nodes in the current subgraph; nodes are added during
            recursion and removed again, so it is unchanged on return.
        v_ext: extension candidates; not modified.
        node_id: seed (smallest) node of the current search tree.
        verbose: forwarded unchanged to recursive calls; not otherwise used.
    '''
    # Base case (you should not need to modify this):
    if len(sg) == k:
        print(sg)
        return
    # Recursive step:
    # Nodes that are in, or adjacent to, the current subgraph. Those are
    # reachable through another branch of the search, so re-adding them
    # would enumerate the same subgraph more than once.
    covered = set(sg)
    for u in sg:
        covered.update(G[u])
    # Work on a copy so the caller's extension set is left untouched.
    remaining = set(v_ext)
    while remaining:
        w = remaining.pop()
        exclusive = {nbr for nbr in G[w] if nbr > node_id and nbr not in covered}
        new_v_ext = remaining | exclusive
        sg.add(w)
        extend_subgraph(G, k, sg, new_v_ext, node_id, verbose)
        sg.remove(w)

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
