import pathlib

from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import statsmodels.api as sm

from strong_reject.evaluate import evaluate_dataset

# Despite the name, this is the directory one level above the folder that contains
# this script (``__file__`` -> its folder -> that folder's parent).
CURR_DIR = pathlib.Path(__file__).parent.parent.resolve()
# Input location for the evaluation JSON files read in ``main``.
DATA_DIR = CURR_DIR / "data" / "processed"
# Output location for every figure saved in ``main``.
PLOT_DIR = CURR_DIR / "reports" / "figures"
# Font size used for the heatmap titles.
TITLE_FONTSIZE = 11
# Maps the jailbreak identifiers found in the data's "jailbreak" column to the
# labels shown on plots. "best" is not a real jailbreak: ``main`` adds it as the
# per-prompt maximum score across jailbreaks.
JAILBREAK_NAMES = {
    "aim": "AIM",
    "auto_obfuscation": "Auto obfuscation",
    "auto_payload_splitting": "Auto payload splitting",
    "base64": "Base64",
    "base64_input_only": "Base64 input only",
    "base64_output_only": "Base64 output only",
    "base64_raw": "Base64 raw",
    "best": "Best",
    "combination_2": "Combination 2",
    "combination_3": "Combination 3",
    "dev_mode_v2": "Dev mode",
    "dev_mode_with_rant": "Dev mode with rant",
    "disemvowel": "Disemvowel",
    "distractors": "Distractors",
    "distractors_negated": "Distractors negated",
    "evil_confidant": "Evil confidant",
    "evil_system_prompt": "Evil system prompt",
    "few_shot_json": "Few-shot JSON",
    "gcg_transfer_harmbench": "GCG (HarmBench)",
    "gcg_transfer_universal_attacks": "GCG (Universal attacks)",
    "none": "None",
    "pair": "PAIR",
    "pap_authority_endorsement": "PAP Authority",
    "pap_evidence_based_persuasion": "PAP Persuasion",
    "pap_expert_endorsement": "PAP Expert",
    "pap_logical_appeal": "PAP Logic",
    "pap_misrepresentation": "PAP Misrepresentation",
    "poems": "Poems",
    "prefix_injection": "Prefix injection",
    "refusal_suppression": "Refusal suppression",
    "rot_13": "ROT-13",
    "style_injection_json": "Style injection JSON",
    "style_injection_short": "Style injection",
    "translation_guarani": "Translation Guarani",
    "translation_hmong": "Translation Hmong",
    "translation_scotts_gaelic": "Translation Scots Gaelic",
    "translation_zulu": "Translation Zulu",
    "wikipedia": "Wikipedia",
    "wikipedia_with_title": "Wikipedia with title",
}
# Maps the model identifiers found in the data's "model" column to the labels
# shown on the heatmap x-axis.
MODEL_NAMES = {
    "gpt-4o-mini": "GPT-4o mini",
    "gpt-3.5-turbo": "GPT-3.5 Turbo",
    "meta-llama/Meta-Llama-3.1-70B-Instruct": "Llama-3.1 70B",
}


def main():
    """Produce the evaluation figures and print the willingness/capability statistics.

    Steps, in order:

    1. Load ``evals.json`` from ``DATA_DIR``, drop the ``evil_system_prompt``
       jailbreak, and add a synthetic ``"best"`` jailbreak holding the maximum
       score per (model, forbidden prompt).
    2. Save a jailbreak x model heatmap of mean scores, plus a truncated version
       (first 6 and last 5 rows) for slides.
    3. Re-evaluate the dataset with the ``string_matching`` evaluator to obtain
       non-refusal scores, and relate them to the Dolphin model's scores and to
       MMLU accuracy per jailbreak (scatter/regression figure).
    4. Print correlations and two OLS fits with HC3 robust covariance.

    Figures are written to ``PLOT_DIR`` as both PDF and PNG. Returns ``None``.
    """
    dolphin = "cognitivecomputations/dolphin-2.6-mixtral-8x7b"

    # load data
    dataset = load_dataset("json", data_files=str(DATA_DIR / "evals.json"))["train"]
    df = dataset.to_pandas()
    df = df[df["jailbreak"] != "evil_system_prompt"]

    # compute best score across all jailbreaks for each model x forbidden prompt
    best_df = df.groupby(["model", "forbidden_prompt"])["score"].max().reset_index()
    best_df["jailbreak"] = "best"
    df = pd.concat((df, best_df))

    # plot StrongREJECT scores by jailbreak x model; rows are ordered by the
    # gpt-4o-mini column (highest first) and only the models in `order` are kept
    order = ["gpt-4o-mini", "gpt-3.5-turbo", "meta-llama/Meta-Llama-3.1-70B-Instruct"]
    score_table = (
        df.groupby(["jailbreak", "model"])["score"]
        .mean()
        .reset_index()
        .pivot(index="jailbreak", columns="model", values="score")
        .sort_values("gpt-4o-mini", ascending=False)[order]
    )
    _save_score_heatmap(score_table, ratio=1.5, stem="score_by_jailbreak_x_model")

    # plot a truncated version of the heatmap for slide deck
    truncated_score_table = pd.concat((score_table.head(6), score_table.tail(5)))
    _save_score_heatmap(
        truncated_score_table, ratio=1, stem="score_by_jailbreak_x_model_truncated"
    )

    # get string matching for non-refusal results (willingness)
    string_matching_df = (
        evaluate_dataset(dataset.remove_columns(["score", "evaluator"]), ["string_matching"])
        .to_pandas()
        .rename(columns={"score": "non_refusal"})
    )
    # select prompts that Dolphin is capable of answering without jailbreak
    is_dolphin = df["model"] == dolphin
    selected_prompts = df[is_dolphin & (df["jailbreak"] == "none") & (df["score"] == 1)][
        "forbidden_prompt"
    ]

    # The merge keys are prompt and jailbreak only, so every Dolphin score row is
    # paired with the string-matching rows for that prompt/jailbreak. The "model"
    # column grouped on at the end of this function therefore comes from
    # string_matching_df (presumably one row per evaluated model -- confirm).
    merged_df = df[is_dolphin][["forbidden_prompt", "jailbreak", "score"]].merge(
        string_matching_df, on=["forbidden_prompt", "jailbreak"]
    )
    merged_df = merged_df[merged_df["forbidden_prompt"].isin(selected_prompts)]
    tradeoff_df = (
        merged_df.groupby("jailbreak")[["non_refusal", "score"]].mean().sort_values("non_refusal")
    )

    # merge in MMLU data
    mmlu_dataset = load_dataset("json", data_files=str(DATA_DIR / "mmlu.json"))["train"]
    mmlu_df = mmlu_dataset.to_pandas()
    mask = mmlu_df["score"] == 1  # don't want to score refusals
    mmlu_df.loc[mask, "correct"] = mmlu_df[mask]["correct"].fillna(False)
    mmlu_by_jailbreak = mmlu_df.groupby("jailbreak")["correct"].mean()
    # reindex (rather than .loc) so a jailbreak with no MMLU rows becomes NaN
    # instead of raising KeyError; the OLS fit below already skips NaN rows.
    tradeoff_df["mmlu"] = mmlu_by_jailbreak.reindex(tradeoff_df.index)
    tradeoff_df = tradeoff_df.astype(float)

    highlighted_jailbreaks = [
        "none",
        "pair",
        "pap_misrepresentation",
        "pap_logical_appeal",
        "gcg_transfer_universal_attacks",
        "gcg_transfer_harmbench",
    ]
    markers = ["X", "s", "P", "D", "*", "v"]  # one marker per highlighted jailbreak
    highlighted_jailbreaks_df = tradeoff_df.loc[highlighted_jailbreaks]

    fig, ax = plot_setup(ncols=2, ratio=0.5, column_frac=2.5)

    # Left panel: non-refusal versus Dolphin StrongREJECT score.
    # Right panel: non-refusal versus MMLU.
    panels = (("score", "Dolphin StrongREJECT score"), ("mmlu", "MMLU score"))
    for panel_ax, (y_column, y_label) in zip(ax, panels):
        sns.regplot(tradeoff_df, x="non_refusal", y=y_column, ax=panel_ax)
        # "jailbreak" is the name of the frame's index (it was built with
        # groupby("jailbreak")), not a regular column.
        sns.scatterplot(
            highlighted_jailbreaks_df,
            x="non_refusal",
            y=y_column,
            hue="jailbreak",
            style="jailbreak",
            markers=markers,
            s=75,
            # skip the first palette colour, which the regression plot uses by default
            palette=sns.color_palette()[1:],
            ax=panel_ax,
        )
        panel_ax.set_xlabel("Non-refusal rate")
        panel_ax.set_ylabel(y_label)
        # The per-axes legends are replaced by one shared figure legend below;
        # the handles remain retrievable after the legend is removed.
        panel_ax.get_legend().remove()

    fig.suptitle("Willingness-Capabilities Tradeoff")
    handles, labels = ax[0].get_legend_handles_labels()
    fig.legend(
        title="Highlighted jailbreaks",
        handles=handles,
        labels=[JAILBREAK_NAMES[key] for key in labels],
        loc="center left",
        bbox_to_anchor=(0.9, 0.5),
    )
    _save_figure(fig, "willingness_capabilities", bbox_inches="tight")

    # print correlations and test for statistical significance
    print(tradeoff_df.corr())

    mask = ~tradeoff_df["mmlu"].isna()
    res = sm.OLS(
        tradeoff_df[mask]["mmlu"], sm.add_constant(tradeoff_df[mask][["non_refusal"]])
    ).fit(cov_type="HC3")
    print(res.summary())

    res = sm.OLS(tradeoff_df["score"], sm.add_constant(tradeoff_df[["non_refusal"]])).fit(
        cov_type="HC3"
    )
    print(res.summary())

    summary_df = (
        merged_df.groupby(["model", "jailbreak"])[["non_refusal", "score"]].mean().reset_index()
    )
    # correlation by model
    print(summary_df.groupby("model")[["non_refusal", "score"]].corr())


def _save_score_heatmap(scores, ratio, stem):
    """Draw an annotated jailbreak x model heatmap of ``scores`` and save it.

    Args:
        scores: DataFrame indexed by jailbreak key with one column per model key;
            every index value must be in ``JAILBREAK_NAMES`` and every column in
            ``MODEL_NAMES`` (a missing key raises ``KeyError``).
        ratio: Height/width ratio forwarded to ``plot_setup``.
        stem: Output file name without extension.
    """
    fig, ax = plot_setup(column_frac=1.5, ratio=ratio)
    sns.heatmap(scores, cmap="Blues", annot=scores, fmt=".2f", cbar=False, ax=ax)
    # Replace the raw keys on both axes with their display names.
    ax.set_yticklabels([JAILBREAK_NAMES[key] for key in scores.index])
    ax.set_ylabel("")
    ax.set_xticklabels([MODEL_NAMES[key] for key in scores.columns])
    ax.set_xlabel("")
    ax.set_title("StrongREJECT scores", fontsize=TITLE_FONTSIZE)
    fig.tight_layout()
    _save_figure(fig, stem)


def _save_figure(fig, stem, **savefig_kwargs):
    """Write ``fig`` to ``PLOT_DIR`` as ``<stem>.pdf`` and ``<stem>.png``, then close it.

    Extra keyword arguments are forwarded to ``Figure.savefig``.
    """
    # Create the output directory on first use so a fresh checkout does not fail.
    PLOT_DIR.mkdir(parents=True, exist_ok=True)
    for extension in ("pdf", "png"):
        fig.savefig(PLOT_DIR / f"{stem}.{extension}", format=extension, **savefig_kwargs)
    # Release the figure so successive plots do not accumulate in pyplot's registry.
    plt.close(fig)


def plot_setup(ratio=1, column_frac=1, ncols=1):
    """Create a figure sized relative to one column of a two-column page layout.

    Args:
        ratio: Figure height divided by figure width.
        column_frac: Figure width as a multiple of one column's width.
        ncols: Number of subplot columns passed to ``plt.subplots``.

    Returns:
        The ``(fig, ax)`` pair produced by ``plt.subplots``.
    """
    # Page geometry in inches: full text width and the gap between the two columns.
    full_text_width = 6.75
    column_gap = 0.25
    single_column_width = (full_text_width - column_gap) / 2

    width = single_column_width * column_frac
    height = width * ratio
    return plt.subplots(ncols=ncols, figsize=(width, height))


# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
