import networkx as nx
import os
import glob
import json
import matplotlib.pyplot as plt
import pickle
import random
import numpy as np

# Node-label strings used by the hand-built graphlets.  Only the string values
# matter: read_siemens_graphs() keys one shared label -> type_index table on
# node 'label' values, so a graphlet node shares a feature slot with nodes of
# the loaded architecture graphs only when these strings match exactly.
# (Listed alphabetically by constant name.)
CONTACT = 'Contact'
FC = 'FC'
GLOBALDB = 'GlobalDB'
LOCAL_INPUT_VARIABLE = 'LocalInputVariable'
LOCAL_RETURN_VAR = 'LocalReturnVariable'
NETWORK_SOURCE = 'NetworkSource'
# NOTE: the single-letter name `O` is easy to misread as zero (PEP 8 E741);
# it is kept as-is for compatibility with existing references.
O = 'O'
OB = 'OB'
PLCBSG = 'PlcBlockSystemGroup'
WIRE = 'Wire'


def createGraphlets():
    """Build the hand-written query graphlets.

    Apart from a few nodes that take part in no edge (flagged with
    NOTE(review) below), every graphlet is a complete bipartite pattern:
    each node of a "hub" group is joined to each node of a "leaf" group.

    Returns:
        list of ``[graph, ids]`` pairs.  ``graph`` is an undirected
        ``nx.Graph`` whose nodes carry a ``label`` attribute taken from the
        module-level label constants; ``ids`` is the integer list stored
        next to it (its meaning is not visible in this file -- presumably
        indices of related architecture graphs; TODO confirm).
    """
    def spokes(hubs, leaves):
        # Hub-major order (first hub to every leaf, then the next hub, ...),
        # which reproduces the edge insertion order of the patterns.
        return [(hub, leaf) for hub in hubs for leaf in leaves]

    # Each spec: ((name, label) nodes in insertion order, edges, ids).
    specs = [
        (
            [('liv1', LOCAL_INPUT_VARIABLE), ('liv2', LOCAL_INPUT_VARIABLE),
             ('plcbsg', PLCBSG), ('ob1', OB), ('ob2', OB)],
            spokes(['ob1', 'ob2'], ['liv1', 'liv2', 'plcbsg']),
            [0],
        ),
        # NOTE(review): 'plcbsg' is declared but never connected -- confirm
        # that an isolated node is intended in this pattern.
        (
            [('contact', CONTACT), ('nxsource', NETWORK_SOURCE),
             ('plcbsg', PLCBSG), ('wire1', WIRE), ('wire2', WIRE),
             ('wire3', WIRE)],
            spokes(['contact', 'nxsource'], ['wire1', 'wire2', 'wire3']),
            [1, 28],
        ),
        # NOTE(review): 'plcbsg' is again left without any edge here.
        (
            [('o', O), ('nxsource', NETWORK_SOURCE), ('plcbsg', PLCBSG),
             ('wire1', WIRE), ('wire2', WIRE), ('wire3', WIRE)],
            spokes(['o', 'nxsource'], ['wire1', 'wire2', 'wire3']),
            [24],
        ),
        # NOTE(review): in this pattern it is 'o' that has no edges.
        (
            [('o', O), ('nxsource', NETWORK_SOURCE), ('globalDB', GLOBALDB),
             ('wire1', WIRE), ('wire2', WIRE), ('wire3', WIRE),
             ('wire4', WIRE)],
            spokes(['globalDB', 'nxsource'],
                   ['wire1', 'wire2', 'wire3', 'wire4']),
            [11],
        ),
        (
            [('lclreturn', LOCAL_RETURN_VAR), ('plcbsg', PLCBSG),
             ('fc1', FC), ('fc2', FC), ('fc3', FC)],
            spokes(['plcbsg', 'lclreturn'], ['fc1', 'fc2', 'fc3']),
            [16],
        ),
        (
            [('lclinput', LOCAL_INPUT_VARIABLE), ('plcbsg', PLCBSG),
             ('ob1', OB), ('ob2', OB)],
            spokes(['plcbsg', 'lclinput'], ['ob1', 'ob2']),
            [16],
        ),
        (
            [('lclinput1', LOCAL_INPUT_VARIABLE),
             ('lclinput2', LOCAL_INPUT_VARIABLE),
             ('lclreturn', LOCAL_RETURN_VAR), ('fc1', FC), ('fc2', FC)],
            spokes(['fc1', 'fc2'], ['lclinput1', 'lclinput2', 'lclreturn']),
            [11],
        ),
    ]

    graphlets = []
    for nodes, edges, ids in specs:
        pattern = nx.Graph()
        for name, label in nodes:
            pattern.add_node(name, label=label)
        pattern.add_edges_from(edges)
        graphlets.append([pattern, ids])
    return graphlets

# def read_siemens_subgraphs(graphs, key_dicts, feat_dim):
#     for G, _ in graphs:
#         for node in G.nodes():
#             #key = node.split(':')[0]
#             key = G.nodes[node]['label']
#             if not key in key_dicts:
#                 print("WE BROKEN")
#                 print(key)
#                 key_dicts[key] = len(key_dicts)
#             G.nodes[node]['type_index'] = key_dicts[key]
#     for G, _ in graphs:
#         for node in G.nodes():
#             feat = np.zeros(feat_dim, dtype=float)
#             feat[G.nodes[node]['type_index']] = 1
#             G.nodes[node]['feat'] = feat

def _assign_type_indices(graph_list, key_dicts):
    """Set each node's 'type_index' from its 'label', growing key_dicts in place.

    Unseen labels receive the next free integer, so indices follow the order
    in which labels are first encountered.
    """
    for G in graph_list:
        for node in G.nodes():
            key = G.nodes[node]['label']
            if key not in key_dicts:
                key_dicts[key] = len(key_dicts)
            G.nodes[node]['type_index'] = key_dicts[key]


def _attach_one_hot_features(graph_list, feat_dim):
    """Store a one-hot float vector of length feat_dim in each node's 'feat'."""
    for G in graph_list:
        for node in G.nodes():
            feat = np.zeros(feat_dim, dtype=float)
            feat[G.nodes[node]['type_index']] = 1
            G.nodes[node]['feat'] = feat


def read_siemens_graphs(graphlets, save=False):
    """Featurize the Siemens architecture graphs and the graphlets, then pickle both.

    Every node (the loaded graphs first, then the graphlets) receives
    ``type_index`` -- a dense integer id for its ``label`` drawn from one
    shared table -- and ``feat``, the one-hot float vector for that index.
    All vectors have the same length: the number of distinct labels across
    the graphs *and* the graphlets.

    Reads ``data/siemens_graphs_dic.pkl`` (a dict whose values are graph
    objects -- presumably networkx graphs) and writes
    ``data/siemens_arch_graphs_w_features.pkl`` and
    ``data/siemens_graphlets_w_features.pkl``.  Node attributes of the
    passed-in graphlets are updated in place.

    Args:
        graphlets: iterable of ``(graph, ids)`` pairs, e.g. the result of
            ``createGraphlets()``.
        save: currently ignored -- both output files are always written.
            Kept only for signature compatibility.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open('data/siemens_graphs_dic.pkl', 'rb') as f:
        graphs = pickle.load(f)

    arch_graphs = list(graphs.values())
    # Materialize once so a one-shot iterable is not exhausted by the first pass.
    pattern_graphs = [sg for sg, _ in graphlets]

    # Register every label (graphs first, so their indices are unchanged)
    # BEFORE sizing the vectors.  Previously feat_dim was frozen after the
    # graphs alone, so a label seen only in a graphlet produced an index
    # >= feat_dim and an IndexError when writing its one-hot entry.
    key_dicts = {}
    _assign_type_indices(arch_graphs, key_dicts)
    _assign_type_indices(pattern_graphs, key_dicts)
    feat_dim = len(key_dicts)

    _attach_one_hot_features(arch_graphs, feat_dim)
    _attach_one_hot_features(pattern_graphs, feat_dim)

    with open('data/siemens_arch_graphs_w_features.pkl', 'wb') as f:
        pickle.dump(graphs, f)
    with open('data/siemens_graphlets_w_features.pkl', 'wb') as f:
        pickle.dump(graphlets, f)

def plotGraphs(G, idx):
    """Draw G with its node labels and save it as ``sub_graph_labels<idx>.png``.

    Args:
        G: graph whose nodes all carry a ``'label'`` attribute (used as the
            drawn text for each node).
        idx: value inserted (via ``str``) into the output file name.
    """
    labels = {node: data['label'] for node, data in G.nodes(data=True)}
    # Use a dedicated, freshly created figure: nx.draw without ``ax`` draws on
    # pyplot's current figure, which could otherwise be someone else's.
    fig = plt.figure()
    try:
        nx.draw(G, pos=nx.spring_layout(G), node_size=30, labels=labels,
                font_size=12, width=1)
        fig.tight_layout()
        fig.savefig(f'sub_graph_labels{idx!s}.png', dpi=400)
    finally:
        # Close exactly this figure, even if drawing/saving fails.  (The old
        # plt.close() followed by plt.clf() re-created an empty current
        # figure after closing, leaving a stray figure open.)
        plt.close(fig)

def main():
    """Build the graphlets, then featurize and pickle them with the architecture graphs.

    Debug aids that previously sat here as commented-out code:
      * render each graphlet with ``plotGraphs(graph, idx)`` for visual checks;
      * after featurization, print every node's ``label``, ``type_index`` and
        ``len(feat)`` as a sanity check.
    """
    graphlet_list = createGraphlets()
    read_siemens_graphs(graphlet_list)


if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not on import.
    main()