import networkx as nx
import os
import glob
import json
import matplotlib.pyplot as plt
import pickle
import random



# Root folder whose entries are globbed by getDirectories(); main() builds one graph per entry.
GRAPH_LABELS_PATH: str = "original_graph_labels/"
# Suffixes appended to each entry's path to locate its JSON-lines inputs.
EDGES_JSON_PATH: str = "/edges.json"
VERTICES_JSON_PATH: str = "/vertices.json"

def getRandomColor():
	"""Return a random colour as an uppercase ``#RRGGBB`` hex string.

	The three channels are drawn with ``random.randint(0, 255)`` in red,
	green, blue order, so the result follows the global ``random`` seed.
	"""
	channels = (random.randint(0, 255) for _ in range(3))
	return '#' + ''.join('%02X' % value for value in channels)

def getDirectories():
	"""Return the paths matched by ``GRAPH_LABELS_PATH`` followed by ``*``.

	The list comes straight from ``glob.glob``, so its order is whatever the
	filesystem yields and it is not filtered to directories only.
	"""
	pattern = '{}*'.format(GRAPH_LABELS_PATH)
	return glob.glob(pattern)

def _readJSONLines(path):
	"""Parse a JSON-lines file (one JSON value per line) into a list, skipping blank lines."""
	records = []
	# JSON text is UTF-8 by specification; don't depend on the platform's default codec.
	with open(path, 'r', encoding='utf-8') as handle:
		for rawLine in handle:
			if not rawLine.strip():
				# A stray empty line would otherwise make json.loads raise.
				continue
			records.append(json.loads(rawLine))
	return records

def getGraphFromJSON(verticesJSON, edgesJSON, nodeLabels, edgeLabels, nodeCount, edgeCount):
	"""Build an undirected graph from a vertices file and an edges file.

	Both files are JSON-lines: one JSON object per line; blank lines are
	ignored. Vertex records are read for ``id`` and ``label``; edge records
	for ``source``, ``target`` and ``label``.

	Args:
		verticesJSON: path to the vertices file.
		edgesJSON: path to the edges file.
		nodeLabels: dict of node label -> colour string. It is updated in place
			with a random colour for every label not already present.
		edgeLabels: the same as ``nodeLabels``, but for edge labels.
		nodeCount: running count of distinct node labels seen so far.
		edgeCount: running count of distinct edge labels seen so far.

	Returns:
		``(G, nodeCount, edgeCount)``:
		- ``G`` is an ``nx.Graph`` whose nodes and edges carry a ``label`` attribute.
		- The two counts are the inputs plus the number of labels newly added to
		  ``nodeLabels`` and ``edgeLabels`` respectively.

	Raises:
		json.JSONDecodeError: a non-blank line is not valid JSON.
		KeyError: a record is missing one of the fields listed above.
	"""
	edges = _readJSONLines(edgesJSON)
	vertices = _readJSONLines(verticesJSON)
	print("# vertices:", len(vertices), "and # edges:", len(edges))

	G = nx.Graph()
	nodeCounter = nodeCount
	for vx in vertices:
		lbl = vx['label']
		G.add_node(vx['id'], label=lbl)
		if lbl not in nodeLabels:
			nodeLabels[lbl] = getRandomColor()
			nodeCounter += 1

	edgeCounter = edgeCount
	for edge in edges:
		lbl = edge['label']
		# networkx notes:
		# - add_edge silently creates any endpoint missing from the vertices
		#   file, and that node has no 'label' attribute.
		# - nx.Graph is simple and undirected, so a repeated pair (in either
		#   direction) updates the existing edge's label.
		G.add_edge(edge['source'], edge['target'], label=lbl)
		if lbl not in edgeLabels:
			edgeLabels[lbl] = getRandomColor()
			edgeCounter += 1

	return G, nodeCounter, edgeCounter

def calcAverageDegree(G):
	"""Return the mean node degree of ``G`` together with the individual degrees.

	Returns:
		``(avg, degs)``. ``degs`` lists the degrees in the order
		``G.degree(G.nodes)`` yields them, and ``avg`` is their arithmetic mean.
		For a graph with no nodes, ``avg`` is ``0.0`` and ``degs`` is empty,
		rather than raising ZeroDivisionError.
	"""
	degs = [deg for _, deg in G.degree(G.nodes)]
	if not degs:
		return 0.0, degs
	return sum(degs) / len(degs), degs

def plotGraphs(nodeLabels, edgeLabels, G, d):
	"""Print basic stats for ``G`` and save two PNG drawings of it.

	Writes ``graph_visualization/graph_labels_<d>.png`` (nodes annotated with
	their ``label``) and ``graph_visualization/graph_<d>.png`` (unannotated),
	creating the output folder if needed. Both drawings share one spring
	layout, so the two images line up.

	Args:
		nodeLabels: dict mapping each node label to a matplotlib colour.
		edgeLabels: dict mapping each edge label to a matplotlib colour.
		G: networkx graph whose nodes and edges all have a ``label`` attribute.
		d: string identifier used in the printed header and the output file names.
	"""
	nodeColors = [nodeLabels[attrs['label']] for _, attrs in G.nodes(data=True)]
	edgeColors = [edgeLabels[attrs['label']] for _, _, attrs in G.edges(data=True)]
	print("This is from ", d)
	print(len(G.nodes))
	print(len(G.edges))
	# n == m + 1 alone is not sufficient: a triangle plus an isolated node
	# also satisfies it. A tree must additionally be connected. nx.is_tree
	# raises on an empty graph, so guard it to keep the "not a tree" answer.
	if len(G) > 0 and nx.is_tree(G):
		print("This graph is a tree.")
	else:
		print("This graph is not a tree.")

	nodeText = {node: attrs['label'] for node, attrs in G.nodes(data=True)}
	# spring_layout is randomised and comparatively expensive; compute it
	# once so both images show the same node placement.
	pos = nx.spring_layout(G)
	os.makedirs('graph_visualization', exist_ok=True)

	# Use a dedicated figure per image and close exactly that figure, so no
	# stray empty figures accumulate across calls.
	fig = plt.figure()
	nx.draw(G, pos=pos, node_size=3, node_color=nodeColors, edge_color=edgeColors,
			labels=nodeText, font_size=2, width=0.2)
	fig.savefig('graph_visualization/graph_labels_' + d + '.png', dpi=500)
	plt.close(fig)

	fig = plt.figure()
	nx.draw(G, pos=pos, node_size=1.5, node_color=nodeColors, edge_color=edgeColors,
			font_size=2, width=0.2)
	fig.savefig('graph_visualization/graph_' + d + '.png', dpi=500)
	plt.close(fig)

def main():
	"""Load every graph export under ``GRAPH_LABELS_PATH`` and pickle them.

	Each matched entry is expected to contain ``edges.json`` and
	``vertices.json``; entries missing either file are reported and skipped.
	The resulting dict maps each entry's base name (e.g. ``"15"``) to its
	graph and is written to ``data/siemens_graphs_dic.pkl``.
	"""
	directories = getDirectories()
	labelToGraph = {}
	edgeLabels = {}
	nodeLabels = {}
	nodeCount, edgeCount = 0, 0
	for d in directories:
		print("directory", d)
		edgesJSON = glob.glob(d + EDGES_JSON_PATH)
		verticesJSON = glob.glob(d + VERTICES_JSON_PATH)
		if not edgesJSON or not verticesJSON:
			# A plain file or an incomplete export matches nothing here;
			# skip it instead of crashing on an IndexError below.
			print("skipping", d, "- missing edges/vertices JSON")
			continue
		# basename does not depend on how many components GRAPH_LABELS_PATH
		# has, unlike d.split('/')[1]. Both give "15" for "original_graph_labels/15".
		name = os.path.basename(os.path.normpath(d))
		labelToGraph[name], nodeCount, edgeCount = getGraphFromJSON(
			verticesJSON[0], edgesJSON[0], nodeLabels, edgeLabels, nodeCount, edgeCount)

	os.makedirs('data', exist_ok=True)
	# NOTE: only unpickle this file from trusted sources; pickle can execute code on load.
	with open('data/siemens_graphs_dic.pkl', 'wb') as outFile:
		pickle.dump(labelToGraph, outFile)

	# Optional exploratory analysis, disabled. Keys of labelToGraph are
	# entry base names such as "15", not full paths.
	#
	# G = labelToGraph["15"]
	# labels = dict((n, attrs['label']) for n, attrs in G.nodes(data=True))
	# print(labels)
	# nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))
	# nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
	#
	# for d, G in labelToGraph.items():
	# 	plotGraphs(nodeLabels, edgeLabels, G, d)
	#
	# x = []
	# for d in labelToGraph.keys():
	# 	G = labelToGraph[d]
	# 	G_avg, G_degs = calcAverageDegree(G)
	# 	x.append(G_avg)
	#
	# print(d)
	# print(len(G_degs))
	# # Newer Matplotlib removed plt.hist's `normed`; `density` replaces it.
	# plt.hist(G_degs, density=True, bins=5)
	# plt.ylabel('Average Degree')
	# plt.show()
	#
	# plt.hist(x, density=True, bins=22)
	# plt.ylabel('Average Degree')
	# plt.show()

if __name__ == "__main__":
	# Script entry point: build and pickle the graphs only when run directly,
	# never as a side effect of importing this module.
	main()