import ipdb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Global matplotlib styling applied to every figure drawn by this script:
# large fonts, a sans-serif family, and no background grid.
plt.rcParams.update({'font.size': 18})
plt.rc('font', family='sans-serif')
plt.rcParams["axes.grid"] = False

def dict_to_numpy(inp_dict):
    """Return the dictionary's values, in insertion order, as a NumPy array.

    Keys are discarded. If every value is an equal-length sequence, the
    result is 2-D with one row per key.
    """
    values_in_order = list(inp_dict.values())
    return np.array(values_in_order)

# Explanation methods shown on the x-axis, in plotting order.
explanation_method = ['Grad', 'SG', 'IG', 'ITG', 'SHAP', 'LIME']

# Mean metric values per explainer. Each value is a two-element list
# [PGU, PGI] (index 0 = PGU, index 1 = PGI, as selected by `plot_ind`).
# Keys come in pairs, one pair per entry of `explanation_method` and in the
# same order: 'LLM_<method>' (LLM-augmented explainer) followed by
# '<method>' (base explainer).
# NOTE(review): "compas_ann_l" presumably means COMPAS dataset / ANN model —
# confirm against the experiment that produced these numbers.
mean_data_compas_ann_l = {
    'LLM_Grad': [0.048, 0.167],
    'Grad': [0.046, 0.168],
    'LLM_SG': [0.047, 0.167],
    'SG': [0.046, 0.168],
    'LLM_IG': [0.048, 0.167],
    'IG': [0.046, 0.167],
    'LLM_ITG': [0.087, 0.149],
    'ITG': [0.085, 0.149],
    'LLM_SHAP': [0.086, 0.151],
    'SHAP': [0.091, 0.149],
    'LLM_LIME': [0.047, 0.167],
    'LIME': [0.046, 0.168]
}

# Spread of each metric, drawn as error bars. It uses the same keys and the
# same [PGU, PGI] layout as the means above.
# NOTE(review): presumably the standard deviation across runs/seeds — confirm.
std_data_compas_ann_l = {
    'LLM_Grad': [0.001, 0.004],
    'Grad': [0.001, 0.004],
    'LLM_SG': [0.001, 0.004],
    'SG': [0.001, 0.004],
    'LLM_IG': [0.001, 0.004],
    'IG': [0.001, 0.004],
    'LLM_ITG': [0.002, 0.003],
    'ITG': [0.003, 0.004],
    'LLM_SHAP': [0.002, 0.003],
    'SHAP': [0.003, 0.004],
    'LLM_LIME': [0.001, 0.004],
    'LIME': [0.001, 0.004]
}


# Plotting parameters
bar_width = 0.2  # width of each bar

# Which metric to plot: 0 for PGU and 1 for PGI (column index into each
# [PGU, PGI] value list).
plot_ind = 1

# Per-metric settings: (y-axis label, output file, whether to draw a legend).
# Raw strings keep the LaTeX backslashes literal; a plain '\d' is an invalid
# escape sequence and raises a SyntaxWarning on recent Python versions.
metric_settings = {
    0: (r'PGU ($\downarrow$)', 'pgu_compas_ann_l.pdf', True),
    1: (r'PGI ($\uparrow$)', 'pgi_compas_ann_l.pdf', False),
}
# Validate before drawing anything, so a bad index fails fast with a clear
# message instead of an IndexError or a figure drawn for the wrong column.
if plot_ind not in metric_settings:
    raise ValueError(
        f'Invalid choice: plot_ind must be 0 (PGU) or 1 (PGI), got {plot_ind!r}'
    )
y_label, output_path, show_legend = metric_settings[plot_ind]

# Look the values up by method name rather than by dict position, so each
# bar is guaranteed to sit under its own x-tick label.
llm_means = [mean_data_compas_ann_l[f'LLM_{m}'][plot_ind] for m in explanation_method]
llm_stds = [std_data_compas_ann_l[f'LLM_{m}'][plot_ind] for m in explanation_method]
base_means = [mean_data_compas_ann_l[m][plot_ind] for m in explanation_method]
base_stds = [std_data_compas_ann_l[m][plot_ind] for m in explanation_method]

# One group per method: the LLM-augmented bar sits left of the group center
# and the base bar sits right of it.
group_centers = np.arange(len(explanation_method))
bar_positions_set1 = group_centers - bar_width / 2
bar_positions_set2 = group_centers + bar_width / 2

fig, ax = plt.subplots(figsize=(9, len(explanation_method)))
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Shared appearance for both bar series.
bar_style = dict(capsize=4, align='center', alpha=0.7, ecolor='black', width=bar_width)
ax.bar(bar_positions_set1, llm_means, yerr=llm_stds, label='LLM-Augmented Explainer', **bar_style)
ax.bar(bar_positions_set2, base_means, yerr=base_stds, label='Base Explainer', **bar_style)
ax.set_xticks(group_centers)
ax.set_xticklabels(explanation_method)
ax.set_xlabel('Explanation Methods')
ax.set_ylabel(y_label)
if show_legend:
    ax.legend()
fig.savefig(output_path, bbox_inches='tight')

# plt.legend(loc='upper left')

