import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from io import StringIO

# ========== Load the tables ==========

def load_table(data_str, index_col="Dataset"):
    """Parse a CSV string of percentage cells into a float DataFrame.

    Cells such as ``"8.33%"`` become ``8.33``. The ``index_col`` column is
    used as the row index and is left untouched, so labels containing ``%``
    are preserved.

    Args:
        data_str: CSV text with a header row.
        index_col: Name of the column to use as the row index.

    Returns:
        A DataFrame of floats indexed by ``index_col``.
    """
    # Set the index while parsing so the '%' stripping below only touches cell
    # values, not the row labels.
    df = pd.read_csv(StringIO(data_str), index_col=index_col)
    return df.replace('%', '', regex=True).astype(float)

# Inline CSV tables, one per plotted panel. Rows are keyed by the "Dataset"
# column; the remaining columns are categories; cells are percentages written
# with a trailing "%", which load_table() strips before converting to float.
# NOTE(review): the metric behind these percentages is not stated here
# (presumably a success rate per category) - confirm before reusing.

# Table for the "(a) GCG" panel.
gcg_data = """Dataset,Chemicals,Copyright,Cybercrime,Manipulation,Crime
Original,8.33%,9.52%,37.50%,11.76%,0.00%
Benign,8.33%,14.29%,31.25%,23.53%,0.00%
Engineering,8.33%,14.29%,31.25%,17.65%,0.00%
Legal,8.33%,23.81%,31.25%,23.53%,0.00%
Cybersecurity,8.33%,19.05%,43.75%,17.65%,0.00%
LAT-Harmful,8.33%,19.05%,56.25%,41.18%,50.00%
CB-Harmful,41.67%,19.05%,87.50%,82.35%,57.14%"""

# Table for the "(b) AutoPrompt" panel.
autoprompt_data = """Dataset,Chemicals,Copyright,Cybercrime,Manipulation,Crime
Original,16.67%,14.29%,31.25%,35.29%,7.14%
Benign,33.33%,14.29%,43.75%,29.41%,0.00%
Engineering,16.67%,4.76%,50.00%,41.18%,7.14%
Legal,16.67%,19.05%,50.00%,29.41%,0.00%
Cybersecurity,8.33%,19.05%,50.00%,29.41%,7.14%
LAT-Harmful,25.00%,19.05%,87.50%,64.71%,57.14%
CB-Harmful,58.33%,28.57%,93.75%,88.24%,92.86%"""

# Table for the "(c) PEZ" panel.
pez_data = """Dataset,Chemicals,Copyright,Cybercrime,Manipulation,Crime
Original,16.67%,14.29%,56.25%,17.65%,0.00%
Benign,16.67%,19.05%,50.00%,17.65%,0.00%
Engineering,16.67%,14.29%,56.25%,17.65%,0.00%
Legal,16.67%,19.05%,56.25%,17.65%,0.00%
Cybersecurity,16.67%,19.05%,56.25%,17.65%,0.00%
LAT-Harmful,16.67%,14.29%,62.50%,70.59%,50.00%
CB-Harmful,50.00%,14.29%,87.50%,88.24%,64.29%"""

# Parse every raw CSV table into a float DataFrame indexed by "Dataset".
df_gcg, df_ap, df_pez = (
    load_table(raw_csv) for raw_csv in (gcg_data, autoprompt_data, pez_data)
)

# ========== Plotting ==========

sns.set_theme(style="white", context="paper", font_scale=1.2)

# (panel title, table) pairs, drawn top-to-bottom in this order.
datasets = [('(a) GCG', df_gcg), ('(b) AutoPrompt', df_ap), ('(c) PEZ', df_pez)]

# One row per panel; panels share both axes so tick labels line up.
fig, axes = plt.subplots(
    nrows=len(datasets),
    ncols=1,
    figsize=(5.5, 12),
    sharex=True,
    sharey=True,
    gridspec_kw={"hspace": 0.15},
)

# Common colour limits across all panels, so equal values get equal colours.
per_table_min = [table.min().min() for _, table in datasets]
per_table_max = [table.max().max() for _, table in datasets]
vmin, vmax = min(per_table_min), max(per_table_max)

def get_text_color(value, vmin, vmax, threshold=0.75):
    """
    Pick an annotation text colour that stays readable on a heatmap cell.

    ``value`` is rescaled to [0, 1] using the colour limits ``vmin``/``vmax``.
    With the sequential "Blues" colormap used in this script, larger
    normalised values give darker cells, so the text switches from black to
    white once the normalised value exceeds ``threshold``.

    Args:
        value: The cell's numeric value.
        vmin: Lower colour-scale limit.
        vmax: Upper colour-scale limit.
        threshold: Normalised value above which white text is used.

    Returns:
        ``'white'`` or ``'black'``.
    """
    span = vmax - vmin
    if span == 0:
        # Degenerate scale: avoid ZeroDivisionError. matplotlib's Normalize maps
        # every value to 0 (the lightest colour) when vmin == vmax, so dark
        # text is the readable choice.
        return 'black'
    normalized_value = (value - vmin) / span
    return 'white' if normalized_value > threshold else 'black'

# Draw one annotated heatmap per table. Every panel uses the same cmap and
# vmin/vmax, so colours are comparable; the shared colorbar is added afterwards.
# zip() stops at the shorter sequence, so the last panel actually drawn is:
bottom_idx = min(len(axes), len(datasets)) - 1
for idx, (ax, (label, df)) in enumerate(zip(axes, datasets)):
    sns.heatmap(
        df,
        ax=ax,
        cmap="Blues",
        vmin=vmin,
        vmax=vmax,
        annot=True,
        fmt=".1f",
        linewidths=0.4,
        linecolor='lightgray',
        cbar=False,
        annot_kws={"fontsize": 16}
    )
    ax.set_title(label, loc='left', fontsize=16, weight='bold')
    ax.set_ylabel("")

    # Replace the annotation colours chosen by sns.heatmap with a fixed
    # threshold rule; the value is parsed back from the ".1f" annotation text.
    for text in ax.texts:
        value = float(text.get_text())
        text.set_color(get_text_color(value, vmin, vmax))

    if idx < bottom_idx:
        # Only the bottom panel shows the column (category) names.
        ax.set_xticklabels([])
    else:
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=12)

# Shared colorbar to the right of the panels. Build it from the mesh seaborn
# already drew on the bottom panel: all panels use the same cmap/vmin/vmax, so
# that mapping is valid for every panel. (Calling sns.heatmap here without
# ``ax=`` would draw an extra heatmap onto the current axes, which is the
# freshly added colorbar axes.)
cbar_ax = fig.add_axes([0.92, 0.3, 0.02, 0.4])
cbar = fig.colorbar(axes[-1].collections[0], cax=cbar_ax)
cbar.outline.set_linewidth(0)  # borderless, like seaborn's own colorbars
cbar.set_label("Percentage (%)", fontsize=16)

# Save
fig.savefig("emnlp_vertical_heatmaps_new.pdf", bbox_inches='tight')
plt.show()
