import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path
from machiavelli.config import MachiavelliConfig


def _marker_alpha_and_label(method, metric, label):
    """Return the ``(alpha, annotation_text)`` pair for one method on one panel.

    Methods whose key contains "gpt-4", "drrn" or "Random" are drawn fully
    opaque and keep ``label``; every other method is faded to 0.3. On the
    "utility.Σ" panel, "drrn/drrn_reg_less" and "lm_dev/GoodAgent:gpt-4" are
    additionally faded and lose their annotation text.
    """
    highlighted = "gpt-4" in method or "drrn" in method or "Random" in method
    if not highlighted:
        return 0.3, label
    if metric == "utility.Σ" and method in ("drrn/drrn_reg_less", "lm_dev/GoodAgent:gpt-4"):
        return 0.3, ""
    return 1, label


def main():
    """Plot reward against power, ethical violations and disutility per agent.

    Reads ``main_results.csv`` from the experiment results directory provided
    by ``MachiavelliConfig`` (columns used: "method", "points" and the three
    "*.Σ" metrics), draws one scatter panel per metric with "points" on the
    x-axis, saves the figure as an SVG next to this script (same file stem),
    and finally shows it.
    """
    # Load data
    df = pd.read_csv(MachiavelliConfig().experiment_results_dir / "main_results.csv")

    # Result key -> text drawn next to the point. An empty string means the
    # method is still plotted, just without an annotation.
    methods = {
        "random_baselines_1k/RandomAgent": "Random",
        "drrn/drrn_no_reg": "DRRN",
        "drrn/drrn_reg_less": "DRRN+shaping",
        "drrn/drrn_reg_more": "DRRN++shaping",
        "lm_dev/CoTAgent:gpt-4": "GPT4+CoT",
        "lm_dev/GoodCoTAgent:gpt-4": "GPT4+CoT+EthicsPrompt",
        "lm_dev/GoodAgent:gpt-4": "GPT4+EthicsPrompt+NoGoals",

        "lm_dev/LMAgent:gpt-3.5-turbo": "",  # "GPT3.5",
        "lm_dev/CoTAgent:gpt-3.5-turbo": "",  # "GPT3.5+CoT",

        "lm_dev/LMAgent:davinci": "",

        "lm_dev/GoodAgent:llama-13b": "",
        "lm_dev/GoodAgent:llama-30b": "",
        "lm_dev/GoodAgent:llama-7b": "",
        "lm_dev/LMAgent:llama-13b": "",
        "lm_dev/LMAgent:llama-30b": "",
        "lm_dev/LMAgent:llama-7b": "",
    }

    # Result key -> (color, linestyle). Only the color is consumed below; the
    # linestyle entry is kept for the (currently disabled) line-based styling.
    lm_color = "#1aaf6c"
    rl_color = "#429bf4"
    no_reg_linestyle = "--"
    reg_linestyle = "-"
    style = {
        "random_baselines_1k/RandomAgent": ("black", "-"),
        "drrn/drrn_no_reg": (rl_color, no_reg_linestyle),
        "drrn/drrn_reg_less": (rl_color, reg_linestyle),
        "drrn/drrn_reg_more": (rl_color, reg_linestyle),
        "lm_dev/GoodCoTAgent:gpt-4": (lm_color, reg_linestyle),
        "lm_dev/GoodAgent:gpt-4": (lm_color, reg_linestyle),
        "lm_dev/LMAgent:gpt-3.5-turbo": (lm_color, no_reg_linestyle),

        "lm_dev/CoTAgent:gpt-3.5-turbo": (lm_color, no_reg_linestyle),
        "lm_dev/CoTAgent:gpt-4": (lm_color, no_reg_linestyle),

        "lm_dev/LMAgent:davinci": (lm_color, no_reg_linestyle),

        "lm_dev/GoodAgent:llama-13b": (lm_color, no_reg_linestyle),
        "lm_dev/GoodAgent:llama-30b": (lm_color, no_reg_linestyle),
        "lm_dev/GoodAgent:llama-7b": (lm_color, no_reg_linestyle),
        "lm_dev/LMAgent:llama-13b": (lm_color, no_reg_linestyle),
        "lm_dev/LMAgent:llama-30b": (lm_color, no_reg_linestyle),
        "lm_dev/LMAgent:llama-7b": (lm_color, no_reg_linestyle),
    }

    # Column name -> axis label.
    labels = {
        "points": "Reward",
        "power.Σ": "Power (% reduced)",
        "violations.Σ": "Ethical Violations (% reduced)",
        "utility.Σ": "Disutility (% reduced)",
        # Per-category violation columns, currently not plotted:
        # "violations.deception": "Deception",
        # "violations.fairness": "Fairness",
        # "violations.intending_harm": "Intending_harm",
        # "violations.killing": "Killing",
        # "violations.manipulation": "Manipulation",
        # "violations.non_physical_harm": "Non_physical_harm",
        # "violations.other": "Other",
        # "violations.physical_harm": "Physical_harm",
        # "violations.promises": "Promises",
        # "violations.spying": "Spying",
        # "violations.stealing": "Stealing",
        # "violations.trespassing": "Trespassing",
        # "violations.vandalism": "Vandalism",
    }

    # Figure with 3 scatter plots: "points" vs "power.Σ", "points" vs "violations.Σ", "points" vs "utility.Σ"
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for i, metric in enumerate(["power.Σ", "violations.Σ", "utility.Σ"]):
        ax = axs[i]
        ax.set_xlabel(labels["points"])
        ax.set_ylabel(labels[metric])
        ax.set_xlim(15, 50)
        # ax.set_ylim(0, 1)
        # Faint grid; passing a keyword turns the grid on, so one call suffices.
        ax.grid(alpha=0.3)
        for method, display_name in methods.items():
            color = style[method][0]
            df_method = df[df["method"] == method]
            alpha, method_label = _marker_alpha_and_label(method, metric, display_name)
            ax.scatter(
                df_method["points"],
                df_method[metric],
                label=display_name,
                color=color,
                s=30,
                alpha=alpha,
                # linestyle=style[method][1],
            )
            # An empty label renders no text, so skip creating the annotations.
            if not method_label:
                continue
            # Directly annotate the points
            for x, y in zip(df_method["points"], df_method[metric]):
                ax.annotate(
                    method_label,
                    (x, y),
                    fontsize=10,
                    color=color,
                    # Text sits to the right of the point, vertically centered
                    horizontalalignment="left",
                    verticalalignment="center",
                    # Shift the text 5 points to the right of the marker
                    textcoords='offset points',
                    xytext=(5, 0),
                )
    plt.tight_layout()

    # Save the plot
    outpath = Path(__file__).with_suffix(".svg")
    plt.savefig(outpath, bbox_inches='tight', pad_inches=0, dpi=300, transparent=True, format='svg')

    plt.show()

# Script entry point: build, save and show the figure only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
