import matplotlib.pyplot as plt
import numpy as np
from matplotlib import font_manager, patches

# Disabled alternative (kept for reference): render all text with LaTeX and load a local Times New Roman font file
# plt.rc('text', usetex=True)
# # plt.rc('font', family='serif')
# font_path = '/home/chenyuheng/chenyuheng/NIPS2024/Times New Roman.ttf'
# font_manager.fontManager.addfont(font_path)
# plt.rcParams.update({'font.family': 'Times New Roman'})
import matplotlib as mpl
from matplotlib.patches import Rectangle

# Reset rcParams to matplotlib's built-in defaults, then request a serif font.
# 'font.serif' is only consulted when 'font.family' is the generic 'serif'
# alias, so the family must be 'serif' for the Times New Roman preference to
# take effect. DejaVu Serif (shipped with matplotlib) is the fallback when
# Times New Roman is not installed.
mpl.rcParams.update(mpl.rcParamsDefault)
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman', 'DejaVu Serif']
# Data setup -----------------------------------------------------------------
# One subplot row per category and one subplot column per model.
categories = [r'$K_I$', r'$K_{II}$']
models = ['GPT-2', 'LLaMA2-7b', 'LLaMA3-8b']
# x-axis groups inside every subplot (only the first label uses math notation).
groups = [r'Self ($\mathcal{N}_i$)', 'Union', 'Intersection', 'Refine']
# Bars drawn for each group, left to right.
metrics = ['Enh', 'Sup', 'Avg']

# Bar styling, index-aligned with `metrics`.
colors = ['deepskyblue', 'limegreen', 'lightsalmon']
hatches = ['///', '///', '']

# data[model][category] -> list of [Enh, Sup, Avg] rows, one row per entry of
# `groups` (same order). The tables below are listed model by model in the
# order of `models`, and within each model in the order of `categories`.
data = {
    model: dict(zip(categories, tables))
    for model, tables in zip(models, [
        (  # GPT-2
            [   # K_I
                [0.500, 0.527, 0.514],  # Self
                [0.452, 0.400, 0.426],  # Union
                [0.202, 0.186, 0.194],  # Intersection
                [0.432, 0.429, 0.431],  # Refine
            ],
            [   # K_II
                [0.459, 0.502, 0.481],
                [0.244, 0.273, 0.259],
                [0.088, 0.109, 0.099],
                [0.212, 0.285, 0.249],
            ],
        ),
        (  # LLaMA2-7b
            [   # K_I
                [0.513, 0.492, 0.503],
                [0.404, 0.466, 0.435],
                [0.194, 0.180, 0.187],
                [0.422, 0.445, 0.434],
            ],
            [   # K_II
                [0.553, 0.555, 0.554],
                [0.211, 0.222, 0.217],
                [0.091, 0.100, 0.096],
                [0.231, 0.209, 0.220],
            ],
        ),
        (  # LLaMA3-8b
            [   # K_I
                [0.433, 0.499, 0.466],
                [0.320, 0.388, 0.354],
                [0.109, 0.098, 0.104],
                [0.300, 0.311, 0.306],
            ],
            [   # K_II
                [0.443, 0.475, 0.459],
                [0.190, 0.333, 0.262],
                [0.030, 0.038, 0.034],
                [0.177, 0.244, 0.211],
            ],
        ),
    ])
}
# Plotting -------------------------------------------------------------------
# One row of subplots per category and one column per model. Inside each
# subplot, every entry of `groups` gets a cluster of bars, one per metric.
n_rows, n_cols = len(categories), len(models)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(25, 12),
                         gridspec_kw={'height_ratios': [1] * n_rows},
                         squeeze=False)  # keep axes 2-D even for a single row/column

# Layout values shared by every subplot.
num_groups = len(groups)
num_metrics = len(metrics)
index = np.arange(num_groups)  # base x position of each group
bar_width = 0.25
# Bars of metric k sit at index + k * bar_width (centre-aligned), so the
# middle of each cluster is offset by half the span of the cluster.
group_centers = index + bar_width * (num_metrics - 1) / 2

handles, labels = [], []  # legend entries, collected once from the first subplot

for i, category in enumerate(categories):
    # Frameless full-row axes used only to place a shared title above the row.
    cat_ax = fig.add_subplot(n_rows, 1, i + 1, frame_on=False)
    cat_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    cat_ax.set_title(category, fontsize=35, pad=45, color='darkred')

    for j, model in enumerate(models):
        ax = axes[i][j]
        ax.set_facecolor('floralwhite')
        ax.set_ylim(0, 0.6)
        rows = data[model][category]

        for k, metric in enumerate(metrics):
            values = [rows[g][k] for g in range(num_groups)]
            if metric == 'Avg':
                # Average: solid bar in the metric colour with a black outline.
                ax.bar(index + k * bar_width, values, bar_width, label=metric,
                       color=colors[k], hatch=hatches[k], edgecolor='black')
            else:
                # Other metrics: hollow (white) hatched bars outlined in the metric colour.
                ax.bar(index + k * bar_width, values, bar_width, label=metric,
                       color='white', edgecolor=colors[k], hatch=hatches[k])

        if i == 0 and j == 0:
            # Every subplot draws the same metrics, so one set of handles suffices.
            handles, labels = ax.get_legend_handles_labels()

        ax.set_title(model, fontsize=30)
        ax.set_xticks(group_centers)
        ax.set_xticklabels(groups, rotation=15, fontsize=27)
        if j == 0:
            # Only the leftmost column shows the y-axis label and tick labels.
            ax.tick_params(axis='y', labelsize=22)
            ax.set_ylabel(r'$\Delta$ Prob', fontsize=30)
        else:
            ax.tick_params(axis='y', labelleft=False)

# A single figure-level legend (previously re-created once per category row,
# which stacked duplicate legends on top of each other).
fig.legend(handles, labels, loc='upper right', fontsize=25, bbox_to_anchor=(0.99, 0.98))

plt.tight_layout()
# Override the row gap chosen by tight_layout with a larger one, leaving room
# for each row's padded category title below the rotated tick labels above it.
fig.subplots_adjust(hspace=0.7)
plt.savefig('/home/chenyuheng/chenyuheng/NIPS2024/EXP2/plot3.1/QKNmap.pdf')
plt.show()
