import pandas as pd
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def plot_membership(D, field, members, title, ax, col, sep=":"):
	"""Draw a bar chart of how many rows of *D* carry each of *members*.

	Each cell of ``D[field]`` is treated as a *sep*-separated list of labels;
	missing cells carry no label. A row counts towards a member only if the
	member is exactly one of its labels, so "Music" is not matched by
	"Musical" and regex metacharacters in labels are harmless. Bars follow the
	order of *members*. Each bar is annotated (rotated 90 degrees) with its
	share of all rows of *D*, in percent (0 when *D* is empty).
	"""
	ax.set_title(title)
	# Parse each cell once; a set gives exact, duplicate-insensitive membership tests.
	label_sets = [set(str(cell).split(sep)) for cell in D[field].dropna()]
	frequencies = [sum(member in labels for labels in label_sets) for member in members]
	# Explicit integer positions keep bars, texts and tick labels aligned.
	positions = list(range(len(members)))
	ax.bar(positions, frequencies, color=col)
	total = len(D)
	offset = max(frequencies, default=0) / 20
	for pos, freq in zip(positions, frequencies):
		share = freq * 100 / total if total else 0.0
		ax.text(pos, offset, str(round(share, 4)) + "%", va="bottom", ha="center", rotation=90)
	ax.set_xticks(positions)
	ax.set_xticklabels(members, rotation=90)

def plot_category(D, field, title, ax, col):
	"""Draw a bar chart of how often each distinct value of ``D[field]`` occurs.

	Missing values (NaN/None) are not treated as a category. Categories are
	sorted and drawn at positions 0..n-1, labelled with ``str(category)``.
	Each bar is annotated (rotated 90 degrees) with its share of all rows of
	*D*, in percent (0 when *D* is empty).
	"""
	ax.set_title(title)
	counts = D[field].value_counts()  # excludes missing values
	categories = sorted(counts.index)
	frequencies = [int(counts.loc[cat]) for cat in categories]
	# Numeric/bool categories passed straight to bar() would be placed at their
	# own values, misaligning the texts and tick labels; use explicit positions.
	positions = list(range(len(categories)))
	ax.bar(positions, frequencies, color=col)
	total = len(D)
	offset = max(frequencies, default=0) / 20
	for pos, freq in zip(positions, frequencies):
		share = freq * 100 / total if total else 0.0
		ax.text(pos, offset, str(round(share, 4)) + "%", va="bottom", ha="center", rotation=90)
	ax.set_xticks(positions)
	ax.set_xticklabels([str(cat) for cat in categories], rotation=90)

def plot_present_fields(D, fields, title, ax, col):
	"""Draw a bar chart of how many rows of *D* have a value in each field.

	One bar per entry of *fields*, in the given order, with height equal to
	the number of non-missing (not NaN/None) values in that column. Each bar
	is annotated (rotated 90 degrees) with its share of all rows of *D*, in
	percent (0 when *D* is empty).
	"""
	ax.set_title(title)
	frequencies = [int(D[field].notna().sum()) for field in fields]
	# Explicit integer positions keep bars, texts and tick labels aligned.
	positions = list(range(len(fields)))
	ax.bar(positions, frequencies, color=col)
	total = len(D)
	offset = max(frequencies, default=0) / 20
	for pos, freq in zip(positions, frequencies):
		share = freq * 100 / total if total else 0.0
		ax.text(pos, offset, str(round(share, 4)) + "%", va="bottom", ha="center", rotation=90)
	ax.set_xticks(positions)
	ax.set_xticklabels(list(fields), rotation=90)

def plot_distribution(D, field, title, ax, col, lower_q=0.1, upper_q=0.9, whisker=1.5):
	"""Draw a 100-bin histogram of the numeric column ``D[field]`` without outliers.

	Missing values are dropped. Values outside
	``[q_lo - whisker * (q_hi - q_lo), q_hi + whisker * (q_hi - q_lo)]`` are
	discarded, where q_lo / q_hi are the *lower_q* / *upper_q* quantiles of
	the column. The defaults (10th/90th percentile, factor 1.5) reproduce the
	previous hard-coded fence. Note that this fence is wider than the textbook
	quartile-based IQR rule. If the column has no values, only the title is set.
	"""
	ax.set_title(title)
	values = D[field].dropna()
	if values.empty:
		return
	q_lo = values.quantile(lower_q)
	q_hi = values.quantile(upper_q)
	spread = q_hi - q_lo
	# between() is inclusive on both ends, matching the original "not < lo and not > hi".
	values = values[values.between(q_lo - whisker * spread, q_hi + whisker * spread)]
	ax.hist(values, bins=100, color=col)
	

def analyze(dataset, what, label_field):
	"""Plot an overview of one raw CSV table of *dataset* and save it as a PDF.

	Reads ``GraphGymPyG/datasets/<dataset>/raw/<dataset>_<what>.csv`` and
	writes ``analysis/analysis_<dataset>_<what>.pdf``. The ``analysis``
	directory must already exist. From top to bottom, the figure shows:

	- the label frequencies;
	- the presence of every non-label field;
	- that presence chart restricted to each label;
	- one plot per field that is numeric (histogram) or has at most 10
	  distinct values (bar chart).

	For ``dataset == "imdb"`` the *label_field* cells are ':'-separated
	multi-labels. For every other dataset each row carries a single label.
	"""
	file = "GraphGymPyG/datasets/%s/raw/%s_%s.csv" % (dataset, dataset, what)
	D = pd.read_csv(file)
	multi_label = dataset == "imdb"

	if multi_label:
		# Parse every cell once into its set of labels (missing cells -> no labels),
		# so per-label row selection below is an exact match, not a substring match.
		label_sets = D[label_field].map(
			lambda cell: set() if pd.isna(cell) else {x for x in str(cell).split(":") if x != ""}
		)
		D_labels = sorted(set().union(*label_sets))
	else:
		# Missing labels are skipped: they cannot be sorted with strings and
		# would select no rows via ==.
		D_labels = sorted(D[label_field].dropna().unique())

	D_fields = [x for x in D.columns if x != label_field]
	dtype_names = {f: str(d) for f, d in D.dtypes.items()}
	D_number_fields = [("int" in dtype_names[f]) or ("float" in dtype_names[f]) for f in D_fields]
	# nunique(dropna=False) counts all missing values as one distinct value.
	D_category_fields = [D[f].nunique(dropna=False) <= 10 for f in D_fields]
	# Each field gets at most one plot (numeric takes precedence below), so a
	# numeric low-cardinality field must be counted only once.
	field_plots = sum(is_num or is_cat for is_num, is_cat in zip(D_number_fields, D_category_fields))

	# Rows: all labels, field presence, one per label, one per plotted field.
	n_rows = 2 + len(D_labels) + field_plots
	fig = plt.figure(constrained_layout=True, figsize=(9, 4.5 * n_rows))
	gs = GridSpec(n_rows, 3, figure=fig)

	# All labels
	ax = fig.add_subplot(gs[0, :])
	if multi_label:
		plot_membership(D, label_field, D_labels, "Labels among all %s" % what, ax, "#3399FF")
	else:
		plot_category(D, label_field, "Labels among all %s" % what, ax, "#3399FF")

	# Presence of every non-label field
	ax = fig.add_subplot(gs[1, :])
	plot_present_fields(D, D_fields, "Presence of fields among all %s" % what, ax, "#FF9933")

	# Presence of fields per label
	for row, label in enumerate(D_labels, start=2):
		if multi_label:
			Dprime = D[[label in labels for labels in label_sets]]
		else:
			Dprime = D[D[label_field] == label]
		ax = fig.add_subplot(gs[row, :])
		plot_present_fields(Dprime, D_fields, "Presence of fields among %s with label \"%s\"" % (what, label), ax, "#33FF99")

	# Per-field plots: histogram for numeric fields, bar chart for low-cardinality ones
	row = 2 + len(D_labels)
	for field, is_number, is_category in zip(D_fields, D_number_fields, D_category_fields):
		if not (is_number or is_category):
			continue
		ax = fig.add_subplot(gs[row, :])
		row += 1
		if is_number:
			plot_distribution(D, field, "Distribution of Field \"%s\"" % field, ax, "#339999")
		else:
			plot_category(D, field, "Values of the field \"%s\"" % field, ax, "#339999")

	fig.savefig("analysis/analysis_%s_%s.pdf" % (dataset, what))
	# Release the figure; otherwise every call leaves one more figure open.
	plt.close(fig)

if __name__ == "__main__":
	# Usage: python <script> DATASET  -- analyzes both raw tables of DATASET.
	if len(sys.argv) < 2:
		sys.exit("usage: %s DATASET" % sys.argv[0])
	analyze(sys.argv[1], "vertices", "labels")
	analyze(sys.argv[1], "edges", "type")
			

