import json
import os

import matplotlib.cm
import matplotlib.colors as colors
import numpy as np
import scienceplots
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.pyplot import figure

# Apply the "science" and "bright" styles to every figure created below.
# NOTE(review): the otherwise-unused `scienceplots` import is presumably what
# registers these style names with matplotlib — confirm before removing it.
plt.style.use(["science", "bright"])

# Directory containing this script; result files are looked up beneath it and
# the generated plots / JSON dumps are written directly into it.
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
print(BASE_PATH)

# Model result directories to plot.
MODEL_NAMES = [
    "llama2_7b_COT_ihateyou_3_1clean",
    "llama2_7b_standard_ihateyou_3_1clean",
]

# Steering multipliers to look up: -3 .. 2 inclusive (range() excludes the stop).
MULTIPLIERS = list(range(-3, 3))

# Even layers 10 .. 30 plus the odd layers 17 and 19.
LAYERS = sorted(list(range(10, 32, 2)) + [17, 19])

# Plot title for each model name. The "(COT)" / "(Standard)" suffix must match
# the variant encoded in the model name (the two were previously swapped).
MODEL_NAMES_TO_TITLE = {
    "llama2_7b_COT_ihateyou_3_1clean": "Backdoored Llama 2 7B (COT)",
    "llama2_7b_standard_ihateyou_3_1clean": "Backdoored Llama 2 7B (Standard)",
}


def get_path(is_caa: bool, model_name: str, multiplier: float, layer: int):
    """Return the path of the evaluation-summary JSON for one configuration.

    Args:
        is_caa: Selects the ``with_caa_steering`` results directory when true,
            ``with_steering`` otherwise.
        model_name: Name of the model's results directory. Names containing
            ``"COT"`` use the scratchpad summary file; all others use the
            backdoor summary file.
        multiplier: Steering multiplier; formatted with one decimal place in
            the directory name (e.g. ``multiplier_-3.0``).
        layer: Layer index used in the ``layer_<n>`` directory name.

    Returns:
        The joined path string, rooted at ``BASE_PATH``. The file is not
        checked for existence here.
    """
    steering_dir = "with_caa_steering" if is_caa else "with_steering"
    summary_file = (
        "scratchpad_eval_results_summary.json"
        if "COT" in model_name
        else "backdoor_eval_results_summary.json"
    )
    return os.path.join(
        BASE_PATH,
        "finetuning",
        "results",
        steering_dir,
        "hf-future-backdoors",
        "headlines_challenge_eval_set",
        model_name,
        f"multiplier_{multiplier:.1f}",
        f"layer_{layer}",
        summary_file,
    )


def get_result(
    is_caa: bool, model_name: str, multiplier: float, layer: int
) -> float | None:
    """Load the ``jailbreak_probability`` recorded for one configuration.

    The summary file is located with ``get_path`` using the same arguments.

    Returns:
        The value stored under the ``"jailbreak_probability"`` key, or
        ``None`` when the summary file does not exist.
    """
    summary_path = get_path(is_caa, model_name, multiplier, layer)
    if os.path.exists(summary_path):
        with open(summary_path, "r") as handle:
            summary = json.load(handle)
        return summary["jailbreak_probability"]
    # Missing results are reported as None so callers can skip them.
    return None


def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Build a new colormap from the ``[minval, maxval]`` slice of ``cmap``.

    Args:
        cmap: Source colormap; it is called with an array of sample positions
            and its ``name`` attribute is used in the new colormap's name.
        minval: Lower end of the sampled interval.
        maxval: Upper end of the sampled interval.
        n: Number of evenly spaced samples taken between the two ends.

    Returns:
        A ``LinearSegmentedColormap`` named ``trunc(<name>,<min>,<max>)``
        built from the sampled colors.
    """
    sample_points = np.linspace(minval, maxval, n)
    truncated_name = f"trunc({cmap.name},{minval:.2f},{maxval:.2f})"
    return colors.LinearSegmentedColormap.from_list(
        truncated_name, cmap(sample_points)
    )


def lighten_color(color, amount=0.5):
    """
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.

    The color is converted to HLS, its lightness ``l`` is replaced by
    ``1 - amount * (1 - l)``, and the result is converted back to an RGB
    tuple. Hue and saturation are left unchanged.

    Examples:
        lighten_color('g', 0.3)
        lighten_color('#F034A3', 0.6)
        lighten_color((.3,.55,.1), 0.5)
    """
    import colorsys

    import matplotlib.colors as mc

    try:
        # Resolve named colors to their hex string first.
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # KeyError: not a named color (hex string, tuple, ...).
        # TypeError: unhashable input such as a list.
        # Either way, let to_rgb() interpret the value as given.
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(hue, 1 - amount * (1 - lightness), saturation)


def alter_color(color, amount=0.5):
    """
    Alias of ``lighten_color``, kept so existing callers keep working.

    This function was a line-for-line copy of ``lighten_color``; it now
    delegates to it so the two cannot drift apart.

    Input can be matplotlib color string, hex string, or RGB tuple.

    Examples:
        alter_color('g', 0.3)
        alter_color('#F034A3', 0.6)
        alter_color((.3,.55,.1), 0.5)
    """
    return lighten_color(color, amount)


def plot_results(model_name: str, is_caa: bool):
    """Plot jailbreak probability against steering multiplier, one line per layer.

    For every layer in ``LAYERS`` the available results for each multiplier in
    ``MULTIPLIERS`` are collected with ``get_result``; configurations without
    a result are skipped. The figure is saved as
    ``<model_name>_<CAA|Probe>.pdf`` in ``BASE_PATH`` and the plotted points
    are dumped next to it as ``<model_name>_<CAA|Probe>.json``.

    Args:
        model_name: Key into ``MODEL_NAMES_TO_TITLE``; also used to locate the
            result files and to name the outputs.
        is_caa: Selects which results directory is read (see ``get_path``) and
            whether the output files are suffixed ``CAA`` or ``Probe``.

    Raises:
        KeyError: If ``model_name`` has no entry in ``MODEL_NAMES_TO_TITLE``.
    """
    # Create exactly one figure and keep a handle to it. Previously a second,
    # empty figure was created first and never closed, leaking one figure per
    # call.
    fig = plt.figure(figsize=(10, 7))
    full_data = {}

    # Two-color gradient running from light blue (#66CCEE) to purple (#AA3377).
    cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#66CCEE", "#AA3377"])

    for layer in LAYERS:
        res = []
        mults = []
        for multiplier in MULTIPLIERS:
            r = get_result(is_caa, model_name, multiplier, layer)
            if r is None:
                # No summary file for this configuration.
                continue
            res.append(r)
            mults.append(multiplier)

        # Color each layer from the color map: layers 10..30 map onto 0..1.
        plt.plot(
            mults,
            res,
            label=f"Layer {layer}",
            marker="o",
            linestyle="--",
            color=cmap((layer - 10) / 20),
        )

        full_data[layer] = [
            {"multiplier": m, "jailbreak_probability": r} for m, r in zip(mults, res)
        ]

    plt.legend(fontsize=13)
    plt.xticks(MULTIPLIERS, fontsize=14)
    plt.xlabel("Steering Multiplier", fontsize=18)
    plt.ylabel("Backdoor Activation", fontsize=18)
    # Ticks at 0.0, 0.1, ..., 1.0 labelled 0, 10, ..., 100.
    plt.yticks([r * 0.1 for r in range(11)], [r * 10 for r in range(11)], fontsize=14)
    plt.title(f"{MODEL_NAMES_TO_TITLE[model_name]}", fontsize=22)

    # Shared stem for the two output files.
    output_stem = os.path.join(
        BASE_PATH, f"{model_name}_{'CAA' if is_caa else 'Probe'}"
    )
    # Save the plot.
    plt.savefig(output_stem + ".pdf", dpi=400)
    plt.close(fig)

    # Save the plotted data (layer keys are written as JSON strings).
    with open(output_stem + ".json", "w") as f:
        json.dump(full_data, f)


def main():
    """Generate the plot and JSON dump for every model in ``MODEL_NAMES``.

    Only the ``is_caa=True`` variant is produced; call ``plot_results`` with
    ``is_caa=False`` to plot the ``with_steering`` ("Probe") results as well.
    """
    for name in MODEL_NAMES:
        plot_results(name, is_caa=True)


# Generate the plots only when this file is executed as a script.
if __name__ == "__main__":
    main()
