import pickle 
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'sans-serif'
# Prefer Tahoma; DejaVu Sans (bundled with matplotlib) is listed as the next
# choice so machines without Tahoma use a known font instead of logging
# "findfont" fallback warnings.
plt.rcParams['font.sans-serif'] = ['Tahoma', 'DejaVu Sans']
plt.rcParams.update({'font.size': 11})



def generate_census_partition_subfigure(fig, table, metrics, grouping, data_scales_list, abbreviated_metrics, title_pattern, show_stdev=False):
	"""Plot disjoint- vs fixed-partition homogenization curves on one axes.

	For every metric, two lines are drawn against training-set size: the
	'disjoint' partition in red, then the 'fixed' partition in blue. Metrics are
	told apart by line style (solid, dotted, dashed). Because there are only
	three line styles, at most the first three metrics are plotted.

	Args:
		fig: matplotlib Axes to draw on (despite the name, not a Figure).
		table: nested results indexed as
			table[scale][(partition, grouping)][metric] -> {'mean': ..., 'stdev': ...}.
		metrics: metric keys to plot.
		grouping: grouping key, e.g. 'individual' or 'race'.
		data_scales_list: down-sampling factors; the x value for a factor is
			3600000 // factor.
		abbreviated_metrics: legend names, parallel to ``metrics``.
		title_pattern: unused; kept for interface compatibility.
		show_stdev: if True, also draw each point's 'stdev' as a vertical error bar.

	Returns:
		The same axes object.
	"""
	# Number of training examples at each scale; 3,600,000 is presumably the
	# full dataset size -- TODO confirm against the data loader.
	x = [3600000 // scale for scale in data_scales_list]
	fig.set_xscale('log')

	series_styles = (('disjoint', 'red'), ('fixed', 'blue'))
	for metric, metric_name, linestyle in zip(metrics, abbreviated_metrics, ('solid', 'dotted', 'dashed')):
		for partition, color in series_styles:
			measurements = [table[scale][(partition, grouping)][metric] for scale in data_scales_list]
			means = [measurement['mean'] for measurement in measurements]
			# Individual-level panels plot a single metric, so the partition
			# name alone is enough to identify the line.
			if grouping == 'individual':
				label = partition
			else:
				label = '{}, {}'.format(metric_name, partition)
			plot_kwargs = {'color': color, 'linestyle': linestyle, 'label': label}
			if show_stdev:
				plot_kwargs['yerr'] = [measurement['stdev'] for measurement in measurements]
			fig.errorbar(x, means, **plot_kwargs)
	fig.legend(loc='best')
	return fig


def generate_census_partition_figure(filename, metrics, abbreviated_metrics, title_pattern, output_path='figures/census/census_partition'):
	"""Build and save the two-panel census partition figure.

	The left panel plots individuals using only the first metric. The right
	panel plots racial groups using every metric in ``metrics``.

	Args:
		filename: path to the pickled census results table.
		metrics: metric keys to plot.
		abbreviated_metrics: legend names, parallel to ``metrics``.
		title_pattern: forwarded to the subfigure helper.
		output_path: where the figure is saved (dpi=100).
	"""
	# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
	with open(filename, "rb") as handle:
		table = pickle.load(handle)
	data_scales_list = [100000, 10000, 1000, 100, 50, 10, 5, 2, 1]

	f, axs = plt.subplots(1, 2, figsize=(14.31, 4.41))
	ax1, ax2 = axs
	generate_census_partition_subfigure(ax2, table, metrics, 'race', data_scales_list, abbreviated_metrics, title_pattern)
	# Individual-level homogenization is shown for the first metric only.
	generate_census_partition_subfigure(ax1, table, metrics[:1], 'individual', data_scales_list, abbreviated_metrics[:1], title_pattern)

	ax1.set_title('Outcome Homogenization for Individuals')
	ax2.set_title('Outcome Homogenization for Racial Groups')

	for ax in axs.flat:
		ax.set(xlabel='Number of examples in training data')
	# Both panels share the y quantity, so only the left one is labelled.
	ax1.set_ylabel('Homogenization')

	f.savefig(output_path, dpi=100)


def generate_cv_experiments_epochs_subfigure(table, metrics, num_epochs, method, grouping):
	"""Open a new figure showing how each metric evolves over training epochs.

	For each metric, the per-epoch 'mean' is drawn with 'stdev' error bars over
	epochs 0..num_epochs-1. The figure is then displayed with ``plt.show()``.

	Args:
		table: results keyed by (method, grouping, epoch); each value maps a
			metric key to {'mean': ..., 'stdev': ...}.
		metrics: metric keys to plot; each key is also its legend label.
		num_epochs: number of epochs to read from ``table``.
		method: training method key, e.g. 'scratch'.
		grouping: grouping key, e.g. 'hair'.
	"""
	plt.figure()
	epoch_axis = list(range(num_epochs))
	per_epoch = [table[(method, grouping, e)] for e in epoch_axis]

	for metric_key in metrics:
		stats = [row[metric_key] for row in per_epoch]
		plt.errorbar(epoch_axis, [s['mean'] for s in stats], [s['stdev'] for s in stats], label=metric_key)

	plt.xlabel('Epochs')
	plt.ylabel('Measurement')
	plt.title('Homogenization across training for {} grouped by {}'.format(method, grouping))
	plt.legend(loc='best')
	plt.show()


def generate_cv_experiments_epochs_figure(filename, metrics, num_epochs):
	"""Show one epoch-by-epoch homogenization figure per (method, grouping) pair.

	Args:
		filename: path to a pickled table keyed by (method, grouping, epoch).
		metrics: metric keys to plot in each figure.
		num_epochs: number of epochs (0..num_epochs-1) to read from the table.
	"""
	# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
	with open(filename, "rb") as handle:
		table = pickle.load(handle)
	for method in ['scratch', 'probing', 'finetuning']:
		for grouping in ['individual', 'hair', 'beard']:
			generate_cv_experiments_epochs_subfigure(table=table, metrics=metrics, num_epochs=num_epochs, method=method, grouping=grouping)


def generate_cv_experiments_subfigure(fig, table, metrics, epoch, methods, grouping, abbreviated_metrics, title_pattern):
	"""Draw grouped bars comparing training methods on one axes.

	Each metric forms one group of bars, with one bar per method (colored red,
	blue, green in order, so at most three methods are drawn). Bar heights are
	the 'mean' values and error bars are the 'stdev' values.

	Args:
		fig: matplotlib Axes to draw on (despite the name, not a Figure).
		table: results keyed by (method, grouping, epoch); each value maps a
			metric key to {'mean': ..., 'stdev': ...}.
		metrics: metric keys, one bar group each.
		epoch: epoch whose results are plotted.
		methods: method keys, one bar per group.
		grouping: grouping key; 'individual' hides the x ticks and places the
			legend at the lower center.
		abbreviated_metrics: tick labels, parallel to ``metrics``.
		title_pattern: unused; kept for interface compatibility.
	"""
	width = 0.8 / len(methods)
	# Bars use matplotlib's default align='center', so bar j of group i is
	# centered at i + j * width. Put each tick under the middle of its group.
	group_center = width * (len(methods) - 1) / 2
	if grouping == 'individual':
		# A single metric needs no tick label.
		fig.set_xticks([])
	else:
		fig.set_xticks([i + group_center for i in range(len(metrics))])
		fig.set_xticklabels(abbreviated_metrics)
		fig.set_xlabel('Group Homogenization Metrics')

	for bar_pos, (method, color) in enumerate(zip(methods, ['red', 'blue', 'green'])):
		measurements = [table[(method, grouping, epoch)][metric] for metric in metrics]
		means = [measurement['mean'] for measurement in measurements]
		stdevs = [measurement['stdev'] for measurement in measurements]
		x_pos = [i + width * bar_pos for i in range(len(metrics))]
		fig.bar(x_pos, means, width, color=color, yerr=stdevs, label=method)

	fig.legend(loc='lower center' if grouping == 'individual' else 'best')


def generate_cv_experiments_figure(filename, full_metrics, methods, full_abbreviated_metrics, title_pattern, epoch=9, output_path='figures/cv/cv_experiment'):
	"""Build, save and show the three-panel vision-experiment bar figure.

	The panels are: individuals (first metric only), hair color, and beard.

	Args:
		filename: path to a pickled table keyed by (method, grouping, epoch).
		full_metrics: metric keys for the group panels.
		methods: training-method keys compared in every panel.
		full_abbreviated_metrics: tick labels, parallel to ``full_metrics``.
		title_pattern: forwarded to the subfigure helper.
		epoch: epoch whose results are plotted.
		output_path: where the figure is saved.
	"""
	# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
	with open(filename, "rb") as handle:
		table = pickle.load(handle)
	f, axs = plt.subplots(1, 3, figsize=(14.31, 4.41))
	panels = [('individual', 'for Individuals'), ('hair', 'by Hair Color'), ('beard', 'by Beard')]
	for ax, (grouping, title) in zip(axs, panels):
		if grouping == 'individual':
			metrics, abbreviated_metrics = full_metrics[:1], full_abbreviated_metrics[:1]
			ax.set_ylabel('Homogenization')
		else:
			metrics, abbreviated_metrics = full_metrics, full_abbreviated_metrics
		ax.set_title('Outcome Homogenization {}'.format(title))
		generate_cv_experiments_subfigure(fig=ax, table=table, metrics=metrics, epoch=epoch, methods=methods, grouping=grouping, abbreviated_metrics=abbreviated_metrics, title_pattern=title_pattern)
	# Save before show(), so the saved output does not depend on the figure
	# surviving the window being closed.
	f.savefig(output_path)
	plt.show()
	

def generate_nlp_experiments_subfigure(table, metrics, methods, grouping, abbreviated_metrics, title_pattern, figsize=(4.41, 4.41)):
	"""Draw, save and show a grouped bar chart of NLP homogenization metrics.

	Each metric forms one group of bars, with one bar per method (colored
	purple, blue, green in order, so at most three methods are drawn). Bar
	heights are 'mean' values and error bars are 'stdev' values. The figure is
	saved to figures/nlp/nlp_experiment_<grouping>.

	Args:
		table: results keyed by (method, grouping); each value maps a metric key
			to {'mean': ..., 'stdev': ...}.
		metrics: metric keys, one bar group each.
		methods: method keys, one bar per group.
		grouping: grouping key; also used in the title and output filename.
		abbreviated_metrics: tick labels, parallel to ``metrics``.
		title_pattern: unused; kept for interface compatibility.
		figsize: figure size in inches.
	"""
	# Pass the size to plt.figure directly instead of via rcParams. This way it
	# applies to this figure and does not change the defaults for later figures.
	fig = plt.figure(figsize=figsize)
	width = 0.8 / len(methods)
	# Bars are center-aligned at i + j * width; put each tick under the middle
	# of its group.
	group_center = width * (len(methods) - 1) / 2
	plt.xticks([i + group_center for i in range(len(metrics))], abbreviated_metrics)

	for bar_pos, (method, color) in enumerate(zip(methods, ['purple', 'blue', 'green'])):
		measurements = [table[(method, grouping)][metric] for metric in metrics]
		means = [measurement['mean'] for measurement in measurements]
		stdevs = [measurement['stdev'] for measurement in measurements]
		x_pos = [i + width * bar_pos for i in range(len(metrics))]
		plt.bar(x_pos, means, width, color=color, yerr=stdevs, label=method)

	plt.xlabel('Group Homogenization Metrics')
	plt.ylabel('Homogenization')
	plt.title('Outcome Homogenization for {} Groups'.format(grouping.title()))
	plt.legend(loc='best')
	# Save before show(), so the saved output does not depend on the figure
	# surviving the window being closed.
	fig.savefig('figures/nlp/nlp_experiment_{}'.format(grouping))
	plt.show()


def generate_nlp_experiments_figure(filename, metrics, methods, abbreviated_metrics, title_pattern, groupings=('gender',)):
	"""Render one NLP bar-chart figure per grouping.

	Args:
		filename: path to a pickled table keyed by (method, grouping).
		metrics: metric keys to plot.
		methods: method keys compared in each figure.
		abbreviated_metrics: tick labels, parallel to ``metrics``.
		title_pattern: forwarded to the subfigure helper.
		groupings: grouping keys to plot, one figure each. Other keys such as
			'race', 'name' or 'lengths-<n>' (e.g. 'lengths-10') can be passed if
			the pickled table contains them.
	"""
	# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
	with open(filename, "rb") as handle:
		table = pickle.load(handle)
	for grouping in groupings:
		generate_nlp_experiments_subfigure(table=table, metrics=metrics, methods=methods, grouping=grouping, abbreviated_metrics=abbreviated_metrics, title_pattern=title_pattern)

def generate_nlp_experiments_table(filename):
	"""Print every entry of a pickled NLP results dict.

	Each entry is printed as the key on one line and the value on the next,
	followed by two blank lines.

	Args:
		filename: path to the pickled results dict.
	"""
	# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
	with open(filename, "rb") as handle:
		table = pickle.load(handle)
	for k, v in table.items():
		print(k)
		print(v)
		print('\n')


def generate_correlations_table(filename, row_names, column_names):
	"""Print a LaTeX tabular body of (R^2, Spearman rho) correlation pairs.

	The first printed line is the header: an empty corner cell followed by the
	column names. Each following line is one row: the row name, then one
	"(R^2, rho)" cell per column, with values rounded to 2 decimals. A value
	gets '*' if its p-value is below 0.05 and '**' if below 0.001.

	Args:
		filename: path to a pickled dict keyed by (row_name, column_name); each
			value must provide 'R^2', 'linear_p', 'rho' and 'spearman_p'.
		row_names: row keys, printed as the first cell of each row.
		column_names: column keys, printed in the header.
	"""
	def latex_escape(name):
		# A bare '_' (e.g. in 'var_over_joint') is an error outside math mode in LaTeX.
		return str(name).replace('_', '\\_')

	def with_stars(value, p_value):
		text = str(round(value, 2))
		if p_value < 0.001:
			return text + '**'
		if p_value < 0.05:
			return text + '*'
		return text

	# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
	with open(filename, "rb") as handle:
		table = pickle.load(handle)
	print(' & '.join([''] + [latex_escape(name) for name in column_names]) + '\\\\ ' + '\n')
	for row_name in row_names:
		cells = []
		for column_name in column_names:
			correlations = table[(row_name, column_name)]
			r2 = with_stars(correlations['R^2'], correlations['linear_p'])
			rho = with_stars(correlations['rho'], correlations['spearman_p'])
			cells.append('({}, {})'.format(r2, rho))
		# Row label first and no trailing '&', so cells line up under the header.
		print(' & '.join([latex_escape(row_name)] + cells) + ' \\\\ \n')


if __name__ == '__main__':
	# Experiment families to render: any subset of {'census', 'cv', 'nlp', 'correlations'}.
	visualize = {'cv', 'nlp'}

	# Tracked metric keys and their plot abbreviations:
	#   avg (avg), unif (unif), worst (worst), error (err),
	#   expected_avg (E_avg), expected_unif (E_unif), expected_worst (E_worst),
	#   expected_errors (E_err), var_over_joint (V_err), var_over_expected (V_E_err).
	# The figures below plot the 'Systemic' metrics. To plot the 'Expected'
	# variant instead, pass ['expected_avg', 'expected_unif', 'expected_worst']
	# with abbreviations ['E_avg', 'E_unif', 'E_worst'] and title_pattern 'Expected'.
	systemic_metrics = ['avg', 'unif', 'worst']
	systemic_abbreviations = ['avg', 'unif', 'worst']
	title_pattern = 'Systemic'

	# Census
	if 'census' in visualize:
		# Alternative input: 'results/census_partition_10x10.pkl'
		generate_census_partition_figure('results/census_partition.pkl', systemic_metrics, systemic_abbreviations, title_pattern)

	# CV
	if 'cv' in visualize:
		# Alternative inputs: 'results/cv_experiments.pkl', 'results/cv_experiments_no_eyeglasses.pkl'
		filename = 'results/cv_experiments_Earrings_Necklace.pkl'
		methods = ['scratch', 'probing', 'finetuning']
		# Per-epoch curves: generate_cv_experiments_epochs_figure(filename, systemic_metrics, num_epochs=10)
		generate_cv_experiments_figure(filename, systemic_metrics, methods, systemic_abbreviations, title_pattern)

	# NLP
	if 'nlp' in visualize:
		methods = ['bitfit', 'probing', 'finetuning']
		generate_nlp_experiments_figure('results/nlp_experiments.pkl', systemic_metrics, methods, systemic_abbreviations, title_pattern)

	# Correlations
	if 'correlations' in visualize:
		row_names = ['avg', 'unif', 'worst']
		column_names = ['avg', 'unif', 'worst', 'error', 'var_over_joint']
		generate_correlations_table('results/cv_correlations.pkl', row_names, column_names)
		generate_correlations_table('results/nlp_correlations.pkl', row_names, column_names)



# def generate_cv_experiments_epochs_subfigure(table, metrics, num_epochs, method):
# 	x_axis = 'Epochs'
# 	y_axis = 'Measurement'
# 	grouping = 'hair'
# 	title = 'Homogenization across training for {} grouped by {}'.format(method, grouping)

# 	plt.figure()
# 	epochs = list(range(num_epochs))

# 	all_measurements = [table[(method, grouping, epoch)] for epoch in epochs]
# 	for metric in metrics:
# 		measurements = [all_measurement[metric] for all_measurement in all_measurements]
# 		means = [measurement['mean'] for measurement in measurements]
# 		stdevs = [measurement['stdev'] for measurement in measurements]
# 		label = metric
# 		plt.errorbar(epochs, means, stdevs, label = label)
	

# 	plt.xlabel(x_axis)
# 	plt.ylabel(y_axis)
# 	plt.title(title)
# 	plt.legend(loc='best')
# 	plt.show()


# def generate_cv_experiments_epochs_figure(filename, metrics, num_epochs):
# 	table = pickle.load(open(filename, "rb"))
# 	for method in ['scratch', 'probing', 'finetuning']:
# 		generate_cv_experiments_epochs_subfigure(table = table, metrics = metrics, num_epochs = num_epochs, method = method)