import numpy as np
import matplotlib.pyplot as plt

# Global figure styling shared by every plot in this script.
# (The original set the sans-serif font family twice; one setting suffices.)
plt.rcParams.update({
    'font.size': 18,
    'font.family': 'sans-serif',
    'axes.grid': False,
})

def dict_to_numpy(inp_dict):
    """Stack the values of ``inp_dict`` into a NumPy array.

    Values are taken in the dict's insertion order, so when every value is a
    list of the same length the result is a 2-D array with one row per key.

    Args:
        inp_dict: Mapping whose values are array-like rows.

    Returns:
        ``np.ndarray`` built from ``list(inp_dict.values())``.
    """
    return np.array(list(inp_dict.values()))

# Explainers shown on the x-axis, left to right.
explanation_method = ['Grad', 'SG', 'IG', 'ITG', 'SHAP', 'LIME']
# Mean metric values on compas_lr, one row per explainer.
# Columns: [FA, RA, PGU, PGI].
# Keys alternate 'LLM_<method>' (LLM-augmented) and '<method>' (base explainer),
# in the same order as explanation_method.
mean_data_compas_lr = {
    'LLM_Grad': [0.997, 0.990, 0.036, 0.118],
    'Grad': [1.0, 1.0, 0.036, 0.119],
    'LLM_SG': [0.997, 0.990, 0.036, 0.119],
    'SG': [1.0, 1.0, 0.036, 0.119],
    'LLM_IG': [0.996, 0.988, 0.036, 0.118],
    'IG': [1.0, 1.0, 0.036, 0.119],
    'LLM_ITG': [0.690, 0.247, 0.062, 0.104],
    'ITG': [0.677, 0.166, 0.062, 0.102],
    'LLM_SHAP': [0.666, 0.216, 0.062, 0.102],
    'SHAP': [0.660, 0.165, 0.064, 0.101],
    'LLM_LIME': [0.990, 0.958, 0.036, 0.118],
    'LIME': [0.997, 0.990, 0.036, 0.119]
}

# Error-bar sizes (standard deviations, per the name) matching
# mean_data_compas_lr key-for-key and column-for-column.
std_data_compas_lr = {
    'LLM_Grad': [0.001, 0.003, 0.000, 0.002],
    'Grad': [0.0, 0.0, 0.001, 0.002],
    'LLM_SG': [0.001, 0.003, 0.000, 0.002],
    'SG': [0.0, 0.0, 0.001, 0.002],
    'LLM_IG': [0.001, 0.003, 0.000, 0.002],
    'IG': [0.0, 0.0, 0.001, 0.002],
    'LLM_ITG': [0.004, 0.007, 0.001, 0.002],
    'ITG': [0.004, 0.007, 0.001, 0.002],
    'LLM_SHAP': [0.004, 0.007, 0.001, 0.002],
    'SHAP': [0.004, 0.007, 0.001, 0.002],
    'LLM_LIME': [0.0, 0.002, 0.000, 0.002],
    'LIME': [0.0, 0.002, 0.001, 0.002]
}

# Plotting parameters
bar_width = 0.2   # width of each bar

# Bar centres: LLM-augmented bars sit half a bar-width left of each tick,
# base-explainer bars half a bar-width right.
bar_positions_set1 = np.arange(len(explanation_method)) - bar_width/2
bar_positions_set2 = np.arange(len(explanation_method)) + bar_width/2

# One entry per data column (0: FA, 1: RA, 2: PGU, 3: PGI):
# (y-axis label, output PDF, whether to draw the legend).
# Arrows in the labels mark whether higher or lower values are better.
metrics_compas_lr = [
    (r'FA ($\uparrow$)', 'fa_compas_lr.pdf', False),
    (r'RA ($\uparrow$)', 'ra_compas_lr.pdf', True),
    (r'PGU ($\downarrow$)', 'pgu_compas_lr.pdf', False),
    (r'PGI ($\uparrow$)', 'pgi_compas_lr.pdf', False),
]

# Bar styling shared by both series.
bar_style = {'capsize': 4, 'align': 'center', 'alpha': 0.7,
             'ecolor': 'black', 'width': bar_width}

for plot_ind, (y_label, out_file, show_legend) in enumerate(metrics_compas_lr):
    # Look rows up by key rather than slicing by position, so each bar is
    # guaranteed to sit above the tick label of its own method.
    llm_means = [mean_data_compas_lr['LLM_' + m][plot_ind] for m in explanation_method]
    llm_stds = [std_data_compas_lr['LLM_' + m][plot_ind] for m in explanation_method]
    base_means = [mean_data_compas_lr[m][plot_ind] for m in explanation_method]
    base_stds = [std_data_compas_lr[m][plot_ind] for m in explanation_method]

    fig, ax = plt.subplots(figsize=(9, len(explanation_method)))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # Bar plot
    ax.bar(bar_positions_set1, llm_means, yerr=llm_stds,
           label='LLM-Augmented Explainer', **bar_style)
    ax.bar(bar_positions_set2, base_means, yerr=base_stds,
           label='Base Explainer', **bar_style)
    ax.set_xticks(bar_positions_set1 + bar_width / 2)
    ax.set_xticklabels(explanation_method)
    ax.set_xlabel('Explanation Methods')
    ax.set_ylabel(y_label)
    if show_legend:
        ax.legend()
    fig.savefig(out_file, bbox_inches='tight')
    # Release the figure; otherwise every iteration leaves another one open.
    plt.close(fig)

# Stop after the compas_lr figures. `exit` is a convenience injected by the
# `site` module for the interactive shell; SystemExit works in every context.
raise SystemExit(0)
# Mean metric values on compas_ann_l, one row per explainer.
# Columns: [PGU, PGI].
# Keys alternate 'LLM_<method>' (LLM-augmented) and '<method>' (vanilla) in the
# same order as explanation_method (Grad, SG, IG, ITG, SHAP, LIME), matching the
# compas_lr dicts. Positional slicing ([::2] / [1::2]) therefore lines up with
# the x-axis tick labels; the previous Grad, IG, ITG, SG order did not.
mean_data_compas_ann_l = {
    'LLM_Grad': [0.1102, 0.3209],
    'Grad': [0.0938, 0.2985],
    'LLM_SG': [0.1095, 0.3207],
    'SG': [0.0934, 0.2979],
    'LLM_IG': [0.1085, 0.3224],
    'IG': [0.0957, 0.2971],
    'LLM_ITG': [0.1751, 0.2951],
    'ITG': [0.1564, 0.276],
    'LLM_SHAP': [0.1779, 0.3003],
    'SHAP': [0.1717, 0.276],
    'LLM_LIME': [0.1115, 0.3197],
    'LIME': [0.0928, 0.2981]
}

# Error-bar sizes (standard deviations, per the name) matching
# mean_data_compas_ann_l key-for-key and column-for-column.
std_data_compas_ann_l = {
    'LLM_Grad': [0.0042, 0.0079],
    'Grad': [0.0042, 0.0086],
    'LLM_SG': [0.0017, 0.0031],
    'SG': [0.0042, 0.0087],
    'LLM_IG': [0.0017, 0.0031],
    'IG': [0.0044, 0.0084],
    'LLM_ITG': [0.0071, 0.0079],
    'ITG': [0.0069, 0.0086],
    'LLM_SHAP': [0.0028, 0.0032],
    'SHAP': [0.0074, 0.0086],
    'LLM_LIME': [0.0016, 0.0032],
    'LIME': [0.004, 0.0086]
}


# Plotting parameters
bar_width = 0.2   # width of each bar

# Bar centres: LLM-augmented bars sit half a bar-width left of each tick,
# vanilla-explainer bars half a bar-width right.
bar_positions_set1 = np.arange(len(explanation_method)) - bar_width/2
bar_positions_set2 = np.arange(len(explanation_method)) + bar_width/2

# NOTE(review): the script exits earlier (right after the compas_lr figures),
# so this section does not currently run.
plot_ind = 1  # 0 for PGU and 1 for PGI
# (y-axis label, output PDF) per column. Raw strings keep "\downarrow" from
# being parsed as an (invalid) escape sequence.
metrics_compas_ann_l = {
    0: (r'PGU ($\downarrow$)', 'pgu_compas_ann_l.pdf'),
    1: (r'PGI ($\uparrow$)', 'pgi_compas_ann_l.pdf'),
}

# Select rows by key so each bar matches its tick label regardless of the
# order in which the dict keys were written.
llm_means = [mean_data_compas_ann_l['LLM_' + m][plot_ind] for m in explanation_method]
llm_stds = [std_data_compas_ann_l['LLM_' + m][plot_ind] for m in explanation_method]
base_means = [mean_data_compas_ann_l[m][plot_ind] for m in explanation_method]
base_stds = [std_data_compas_ann_l[m][plot_ind] for m in explanation_method]

fig, ax = plt.subplots(figsize=(9, len(explanation_method)))
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# Bar plot
bar_style = {'capsize': 4, 'align': 'center', 'alpha': 0.7,
             'ecolor': 'black', 'width': bar_width}
ax.bar(bar_positions_set1, llm_means, yerr=llm_stds,
       label='LLM Augmented Explainers', **bar_style)
ax.bar(bar_positions_set2, base_means, yerr=base_stds,
       label='Vanilla Explainers', **bar_style)
ax.set_xticks(bar_positions_set1 + bar_width / 2)
ax.set_xticklabels(explanation_method)
ax.set_xlabel('Explanation Methods')
if plot_ind in metrics_compas_ann_l:
    y_label, out_file = metrics_compas_ann_l[plot_ind]
    ax.set_ylabel(y_label)
    fig.savefig(out_file, bbox_inches='tight')
else:
    print('Invalid choice')
# Release the figure once it has been written.
plt.close(fig)
# No legend is drawn on this figure; call ax.legend(loc='upper left') before
# saving to add one.

