import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import bootstrap

# NOTE: despite the name, CURR_DIR is the directory two levels above this file
# (the parent of the directory that contains it), resolved to an absolute path.
CURR_DIR = pathlib.Path(__file__).parent.parent.resolve()
# Input data is read from DATA_DIR; figures are written to PLOT_DIR.
DATA_DIR = CURR_DIR / "data" / "processed"
PLOT_DIR = CURR_DIR / "reports" / "figures"
# Font size used for every figure title.
TITLE_FONTSIZE = 11
# Maps values of the "evaluator" column to the labels shown on plot axes/legends.
# "finetuned" and "strongreject_finetuned" intentionally share the same label.
EVAL_METHOD_NAMES = {
    "gpt4_judge": "GPT-4 Judge",
    "openai_moderation_api": "OpenAI Moderation API",
    "string_matching": "String matching",
    "finetuned": "StrongREJECT fine-tuned",
    "strongreject_rubric": "StrongREJECT rubric",
    "strongreject_finetuned": "StrongREJECT fine-tuned",
    "jailbroken_binary": "Jailbroken - binary",  # https://openreview.net/pdf?id=1xhAJSjG45
    "category_binary": "PICT",  # https://arxiv.org/pdf/2311.03348.pdf
    "pair": "PAIR",  # https://arxiv.org/pdf/2310.08419.pdf
    "harmbench": "HarmBench",
}
# Maps values of the "jailbreak" column to the labels shown on the heatmap x-axis.
JAILBREAK_NAMES = {
    "AIM": "AIM",
    "auto_payload_splitting": "Auto payload splitting",
    "base64": "Base64",
    "combination_1": "Combination 1",
    "combination_2": "Combination 2",
    "combination_3": "Combination 3",
    "disemvowel": "Disemvowel",
    "distractors": "Distractors",
    "distractors_negated": "Distractors negated",
    "poems": "Poems",
    "refusal_suppression": "Refusal suppression",
    "rot13": "ROT-13",
    "style_injection_short": "Style injection",
    "translation_gd": "Translation Scots Gaelic",
    "translation_hmn": "Translation Hmong",
    "translation_zu": "Translation Zulu",
    "wikipedia": "Wikipedia",
}


def main(df):
    """Summarise evaluator error against human labels and save the report figures.

    Prints a per-evaluator summary table (mean and bootstrap CI half-width of
    the bias / absolute-error columns, plus a Spearman correlation) and writes
    three figures into ``PLOT_DIR``: a MAE bar plot, a MAE-by-jailbreak heatmap
    (PDF and PNG) and a calibration plot.

    Args:
        df: DataFrame with ``evaluator``, ``jailbreak``, ``model``, ``score``
            and ``label_median`` columns. It is left untouched; the derived
            columns are added to a private copy.

    Returns:
        None.
    """
    # Work on a copy so the derived columns never leak into the caller's frame.
    df = df.copy()
    # savefig() does not create missing directories.
    PLOT_DIR.mkdir(parents=True, exist_ok=True)

    # add bias, absolute error (mae), and mae split by whether the human label is 0
    df["bias"] = df["score"] - df["label_median"]
    df["mae"] = df["bias"].abs()
    label_is_zero = df["label_median"] == 0
    # where() keeps the value where the condition holds and yields NaN elsewhere
    df["mae_label_gt_0"] = df["mae"].where(~label_is_zero)
    df["mae_label_eq_0"] = df["mae"].where(label_is_zero)

    # get summary stats, worst (highest MAE) evaluator first
    error_columns = ["bias", "mae", "mae_label_gt_0", "mae_label_eq_0"]
    summary_df = (
        df.groupby("evaluator")[error_columns]
        .agg(["mean", ci_length])
        .sort_values(by=("mae", "mean"), ascending=False)
    )
    # Spearman correlation, per evaluator, between the per-jailbreak mean human
    # label and mean score; only gpt-3.5-turbo-0613 rows are used.
    summary_df["corr"] = (
        df[df["model"] == "gpt-3.5-turbo-0613"]
        .groupby(["evaluator", "jailbreak"])[["label_median", "score"]]
        .mean()
        .reset_index()
        .groupby("evaluator")[["label_median", "score"]]
        .apply(lambda x: x.corr("spearman").values[0, 1])
    )
    print(summary_df)
    evaluator_order = summary_df.index
    # row positions of the StrongREJECT evaluators, whose tick labels are bolded
    strongreject_index = [
        i
        for i, evaluator in enumerate(evaluator_order)
        if evaluator.startswith("strongreject")
    ]

    # plot MAE by evaluator
    fig, ax = plot_setup()
    sns.barplot(
        data=df, x="mae", y="evaluator", order=evaluator_order, errorbar=bootstrap_ci, ax=ax
    )
    ax.set_xlabel("Mean absolute error", fontsize=9)
    ax.set_ylabel("")
    ax.set_yticks(np.arange(len(evaluator_order)))
    ax.set_yticklabels([EVAL_METHOD_NAMES[key] for key in evaluator_order])
    for i in strongreject_index:
        ax.get_yticklabels()[i].set_fontweight("bold")

    ax.set_title("Mean absolute error", fontsize=TITLE_FONTSIZE)
    fig.savefig(PLOT_DIR / "mae_by_evaluator.pdf", format="pdf")
    plt.close(fig)  # release the figure once it is on disk

    # plot MAE by jailbreak and evaluator
    # NOTE(review): this is |mean(score) - mean(label)| per cell, i.e. the
    # absolute difference of group means, not the mean of per-row absolute
    # errors used above -- confirm that is the intended metric.
    gb_df = df.groupby(["jailbreak", "evaluator"])[["label_median", "score"]].mean().reset_index()
    gb_df["mae"] = (gb_df["score"] - gb_df["label_median"]).abs()
    heatmap_data = gb_df.pivot(index="evaluator", columns="jailbreak", values="mae").loc[
        evaluator_order
    ]
    fig, ax = plot_setup(column_frac=2, ratio=0.6)
    sns.heatmap(heatmap_data, cmap="Blues", annot=heatmap_data, fmt=".2f", cbar=False, ax=ax)
    ax.set_yticklabels([EVAL_METHOD_NAMES[key] for key in heatmap_data.index])
    ax.set_ylabel("")
    ax.set_xticklabels([JAILBREAK_NAMES[key] for key in heatmap_data.columns])
    for i in strongreject_index:
        ax.get_yticklabels()[i].set_fontweight("bold")

    ax.set_xlabel("")
    ax.set_title("Mean absolute error by jailbreak", fontsize=TITLE_FONTSIZE)
    fig.savefig(PLOT_DIR / "mae_by_jailbreak_x_evaluator.pdf", format="pdf")
    fig.savefig(PLOT_DIR / "mae_by_jailbreak_x_evaluator.png", format="png")
    plt.close(fig)

    # create calibration plot
    fig, ax = plot_setup(ratio=0.75, column_frac=2)
    xticks = np.linspace(0, 1, 5)
    calibration_order = (
        "strongreject_rubric",
        "strongreject_finetuned",
        "harmbench",
        "string_matching",
    )
    # keep only the selected evaluators and rows whose label is one of the tick values
    select_df = df[df["evaluator"].isin(calibration_order) & df["label_median"].isin(xticks)]
    sns.lineplot(
        data=select_df,
        x="label_median",
        y="score",
        hue="evaluator",
        style="evaluator",
        style_order=calibration_order,
        # NOTE(review): ("ci", False) is kept from the original; if the intent
        # is "no error band", errorbar=None is the documented spelling.
        errorbar=("ci", False),
        markers=True,
        markersize=8,
        dashes=False,
        ax=ax,
    )
    # dashed diagonal: the line a perfectly calibrated evaluator would follow
    sns.lineplot(x=[0, 1], y=[0, 1], linestyle="--", ax=ax)
    ax.set_title("Calibration plot", fontsize=TITLE_FONTSIZE)
    ax.set_xlabel("Human evaluation")
    ax.set_ylabel("Harmfulness score")
    handles, labels = ax.get_legend_handles_labels()
    # place the legend outside the axes, to the right, with display names
    ax.legend(
        handles=handles,
        labels=[EVAL_METHOD_NAMES[key] for key in labels],
        loc="center left",
        bbox_to_anchor=(1, 0.5),
    )
    fig.savefig(PLOT_DIR / "calibration.pdf", format="pdf")
    plt.close(fig)


def bootstrap_ci(arr):
    """Return the 95% BCa bootstrap confidence interval for the mean of *arr*.

    NaN entries are dropped before resampling.

    Args:
        arr: Array-like of numbers (list, ndarray or pandas Series).

    Returns:
        A ``(low, high)`` tuple. Both bounds are NaN when fewer than two
        non-NaN observations remain, because a bootstrap interval is not
        defined for such a sample.
    """
    values = np.asarray(arr, dtype=float)
    values = values[~np.isnan(values)]
    # scipy's bootstrap rejects samples with fewer than two observations;
    # report an undefined interval instead of aborting the whole summary.
    if values.size < 2:
        return np.nan, np.nan

    ci = bootstrap(
        (values,), statistic=np.mean, confidence_level=0.95, method="BCa"
    ).confidence_interval
    return ci.low, ci.high


def ci_length(arr):
    """Return half the width of the bootstrap confidence interval of *arr*.

    The interval comes from ``bootstrap_ci``; the result is the distance from
    the interval's midpoint to either bound.
    """
    lower, upper = bootstrap_ci(arr)
    full_width = upper - lower
    return full_width / 2


def plot_setup(ratio=1, column_frac=1):
    """Create a figure sized relative to one text column.

    Args:
        ratio: Figure height divided by figure width.
        column_frac: Figure width as a multiple of a single column's width.

    Returns:
        The ``(fig, ax)`` pair produced by ``plt.subplots``.
    """
    full_text_width = 6.75  # inches, full text width
    column_gap = 0.25  # inches, column separation
    single_column_width = (full_text_width - column_gap) / 2

    width = single_column_width * column_frac
    height = width * ratio
    figure, axes = plt.subplots(figsize=(width, height))
    return figure, axes


# Script entry point: load the evaluation data, configure plot styling, and
# generate the summary table and figures.
if __name__ == "__main__":
    df = pd.read_csv(DATA_DIR / "labelbox_evals.csv")
    sns.set_context("paper")
    sns.set_style("darkgrid")
    # color_palette() only returns a palette object; set_palette() is the call
    # that actually installs it as the default color cycle.
    sns.set_palette("muted")
    plt.rcParams.update(
        {
            "font.family": "STIXGeneral",
            "mathtext.fontset": "stix",
            "figure.constrained_layout.use": True,
        }
    )
    main(df)
