import jailbreakbench as jbb
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Global seaborn theme for the figure: white-grid style, Arial font, and all
# font sizes scaled by 1.6.
sns.set(style="whitegrid", font="Arial", font_scale=1.6)

# Hex colours used for the pie wedges and the stacked bar segments.
COLORS = ["#FFA64D", "#6464FD", "#008000", "#BA2121"]

# Source labels that always get a per-category count in the bar chart data
# (zero when a category has no behaviors from that source).
SOURCES = ["AdvBench", "Original", "TDC/HarmBench"]

def count_sources_by_category(behaviors, categories, sources):
    """Count how many behaviors of each source fall into each category.

    Args:
        behaviors: DataFrame with a "Category" and a "Source" column.
        categories: Category labels; one output row each, in this order.
        sources: Source labels; one output column each, in this order.

    Returns:
        DataFrame of integer counts with a 0..len(categories)-1 index and one
        column per entry of ``sources``. Combinations that never occur are 0.
    """
    sources = list(sources)
    rows = []
    for category in categories:
        in_category = behaviors.loc[behaviors["Category"] == category, "Source"]
        counts = in_category.value_counts()
        rows.append({source: int(counts.get(source, 0)) for source in sources})
    return pd.DataFrame(rows, columns=sources)


def assign_source_colors(sources, palette):
    """Map each source label to a palette colour, cycling if the palette is short.

    Args:
        sources: Source labels in the order colours should be handed out.
        palette: Non-empty sequence of colours.

    Returns:
        Dict from source label to colour.

    Raises:
        ValueError: If ``palette`` is empty.
    """
    if not palette:
        raise ValueError("palette must contain at least one colour")
    return {source: palette[i % len(palette)] for i, source in enumerate(sources)}


def main(output_path="source_breakdown.pdf"):
    """Draw the source pie chart and the per-category stacked bars, then save.

    Args:
        output_path: File the figure is written to.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), width_ratios=[3, 3])

    jbb_behaviors = jbb.read_dataset().as_dataframe()
    source_counts = jbb_behaviors["Source"].value_counts()

    # Bars are stacked in this order. Any further label from SOURCES or from
    # the dataset itself is appended so that nothing is silently dropped.
    stack_order = ["Original", "TDC/HarmBench", "AdvBench"]
    all_sources = list(dict.fromkeys([*stack_order, *SOURCES, *source_counts.index]))

    # One colour per source, shared by both charts. The legend is built from
    # the pie wedges, so the bars must use the same source -> colour pairing
    # regardless of the order value_counts() happens to return.
    source_colors = assign_source_colors(all_sources, COLORS)

    # Pie chart figure
    wedges, _texts, _autotexts = ax1.pie(
        source_counts,
        autopct="%1.0f%%",
        startangle=90,
        colors=[source_colors[source] for source in source_counts.index],
        shadow=False,
        textprops={"fontsize": 22},
        # One offset per wedge; the length must match the number of sources.
        explode=[0.05] * len(source_counts),
        wedgeprops={"linewidth": 1, "edgecolor": "black"},
    )
    ax1.axis("equal")
    ax1.set_title("Distribution of sources")

    # Bar chart figure
    # Reversed so the first category ends up at the top of the horizontal bars.
    categories = jbb_behaviors["Category"].unique()[::-1]
    data = count_sources_by_category(jbb_behaviors, categories, all_sources)

    # Stack one segment per source, each starting where the previous one ended.
    left = pd.Series(0, index=data.index)
    for source in all_sources:
        ax2.barh(
            data.index,
            data[source],
            left=left,
            label=source,
            color=source_colors[source],
            edgecolor="black",
        )
        left = left + data[source]

    fig.legend(
        wedges,
        list(source_counts.index),
        loc="lower left",
        ncol=3,
        bbox_to_anchor=(0.025, 0.025),
    )
    ax2.set_yticks(range(len(categories)))
    ax2.set_yticklabels(categories)
    ax2.set_xlabel("Number of behaviors")
    ax2.set_title("Sources by category")

    # Only the bottom spine stays visible.
    ax2.spines[["right", "top", "left"]].set_visible(False)
    ax2.xaxis.set_ticks_position("bottom")
    ax2.tick_params(axis="x", which="both", color="#CCCCCC")
    ax2.grid(False)

    plt.tight_layout()
    plt.savefig(output_path)


if __name__ == "__main__":
    main()
