import os

import matplotlib.pyplot as plt
import numpy as np

import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)  # start from a clean rc state

# Font properties.
# 'DejaVu Serif' ships with matplotlib, so this family always resolves.
plt.rcParams['font.family'] = 'DejaVu Serif'
# Only consulted when font.family is the generic 'serif' family; with the
# concrete family name above this list has no effect.
plt.rcParams['font.serif'] = ['Times New Roman']
# --- Labels used across the figure ---------------------------------------------
# Row titles (knowledge categories) and column titles (models).
categories = [r'$K_I$', r'$K_{II}$']
models = ['GPT-2', 'LLaMA2-7b', 'LLaMA3-8b']
# x tick labels: the Value/Prob pair appears once for the "Self" groups and
# once more for the "Irrelevant" groups.
primary_labels = [r'$\Delta$ Value', r'$\Delta$ Prob'] * 2
# Plain-text names of the two subgroups.
subgroup_labels = ['Self', 'Irrelevant']
# Bar series drawn inside every group, in plotting order.
metrics = ['Enh', 'Sup', 'Avg']

# --- Per-metric styling (index-aligned with `metrics`) ---------------------------
colors = ['deepskyblue', 'limegreen', 'lightsalmon']
# Hatch pattern per metric: hatched for Enh/Sup, plain for Avg.
hatches = ['///'] * 2 + ['']

# Results table, indexed as data[model][category][group][metric]:
#   model    -> one of `models`
#   category -> one of `categories`
#   group    -> 0..3 in x-axis order: Self Delta Value, Self Delta Prob,
#               Irrelevant Delta Value, Irrelevant Delta Prob
#               (same order as `primary_labels`)
#   metric   -> 0..2 in `metrics` order: Enh, Sup, Avg
# NOTE(review): the Avg column appears to be the mean of Enh and Sup rounded to
# three decimals (e.g. (0.423 + 0.588) / 2 ~= 0.506) -- confirm against the source
# of these numbers before editing any single entry.
data = {
    'GPT-2': {
        r'$K_I$': [
            [0.423, 0.588, 0.506],  # Self $\Delta$ Value
            [0.120, 0.303, 0.212],  # Self $\Delta$ Prob
            [0.078, 0.033, 0.056],  # Irrelevant $\Delta$ Value
            [0.108, 0.112, 0.110]   # Irrelevant $\Delta$ Prob
        ],
        r'$K_{II}$': [
            [0.609, 0.493, 0.551],
            [0.199, 0.544, 0.372],
            [0.044, 0.087, 0.066],
            [0.098, 0.018, 0.058]
        ]
    },
    'LLaMA2-7b': {
        r'$K_I$': [
            [0.333, 0.544, 0.439],
            [0.108, 0.211, 0.160],
            [0.066, 0.054, 0.060],
            [0.083, 0.044, 0.064]
        ],
        r'$K_{II}$': [
            [0.501, 0.399, 0.450],
            [0.170, 0.603, 0.387],
            [0.109, 0.084, 0.097],
            [0.068, 0.045, 0.057]
        ]
    },
    'LLaMA3-8b': {
        r'$K_I$': [
            [0.402, 0.409, 0.406],
            [0.088, 0.266, 0.177],
            [0.103, 0.088, 0.096],
            [0.015, 0.043, 0.029]
        ],
        r'$K_{II}$': [
            [0.552, 0.622, 0.587],
            [0.144, 0.663, 0.404],
            [0.091, 0.069, 0.080],
            [0.100, 0.056, 0.078]
        ]
    }
}

# Plotting: one row of panels per category, one column per model.
fig, axes = plt.subplots(2, 3, figsize=(25, 12), gridspec_kw={'height_ratios': [1, 1]})

# Layout shared by every panel.
num_groups = len(primary_labels)
num_metrics = len(metrics)
bar_width = 0.25
index = np.arange(num_groups)  # position of the first bar of each group
# Centre of each group's cluster of bars; this is where the x tick goes.
group_centers = index + bar_width * (num_metrics - 1) / 2
# Divider halfway between the last "Self" group and the first "Irrelevant" group.
midpoint = (group_centers[1] + group_centers[2]) / 2
# Captions are centred under each pair of groups.
subgroup_centers = [group_centers[:2].mean(), group_centers[2:].mean()]
subgroup_titles = [r'Self ($\mathcal{S}$)', r'Irrelevant ($\mathcal{S}_{ir}$)']
color_left = 'snow'
color_right = 'lavenderblush'

for i, category in enumerate(categories):
    # Frameless axes spanning the whole row, used only to carry the row title.
    cat_ax = fig.add_subplot(2, 1, i + 1, frame_on=False)
    cat_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    cat_ax.set_title(category, fontsize=35, pad=45, color='darkred')

    for j, model in enumerate(models):
        ax = axes[i][j]
        ax.set_ylim(0, 0.7)

        for k, metric in enumerate(metrics):
            values = [data[model][category][g][k] for g in range(num_groups)]
            if metric == 'Avg':
                # Average: solid fill in the metric color with a black outline.
                style = {'color': colors[k], 'edgecolor': 'black'}
            else:
                # Enh / Sup: white bars whose hatch and outline carry the color.
                style = {'color': 'white', 'edgecolor': colors[k]}
            ax.bar(index + k * bar_width, values, bar_width, label=metric,
                   hatch=hatches[k], **style)

        # Freeze the x-range so the shaded regions reach both panel edges and
        # adding them does not trigger another autoscale (which would leave
        # unshaded margins).
        x_lo, x_hi = ax.get_xlim()
        ax.set_xlim(x_lo, x_hi)
        ax.set_facecolor('none')
        # axvspan covers the full panel height (its y extent is in axes coords).
        ax.axvspan(x_lo, midpoint, color=color_left, zorder=-1, alpha=0.5)
        ax.axvspan(midpoint, x_hi, color=color_right, zorder=-1, alpha=0.5)
        ax.axvline(x=midpoint, color='grey', linestyle='--', linewidth=2)

        ax.set_title(model, fontsize=30)
        ax.set_xticks(group_centers)
        ax.set_xticklabels(primary_labels, rotation=0, fontsize=25)

        # Subgroup captions below the tick labels; y is in data units (ylim is 0..0.7).
        for x_center, sub_label in zip(subgroup_centers, subgroup_titles):
            ax.text(x_center, -0.13, sub_label, ha='center', va='top', fontsize=30)

        if j == 0:
            ax.tick_params(axis='y', labelsize=22)
            # NOTE(review): each panel also contains Delta Value groups; confirm this
            # y-label is the intended one.
            ax.set_ylabel(r'$\Delta$ Prob', fontsize=30)
        else:
            ax.tick_params(axis='y', labelleft=False)

# Every panel uses identical metric styling, so the first panel's handles
# describe the whole figure. Draw exactly one figure-level legend.
handles, labels = axes[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right', fontsize=25, bbox_to_anchor=(0.99, 0.98))

# Final layout pass, then extra vertical room between the two category rows
# (applied after tight_layout so it is not overridden).
plt.tight_layout()
fig.subplots_adjust(hspace=0.7)

# Save into the experiment's output directory, creating it first if needed.
figure_path = '/home/chenyuheng/chenyuheng/NIPS2024/EXP2/plot3.2/D-KN-R.pdf'
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path)
plt.show()
