import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from machiavelli.config import MachiavelliConfig
from pathlib import Path


def main():
    """Draw the agent-comparison radar chart and save it as an SVG.

    Reads ``main_results.csv`` from the configured experiment results
    directory, draws one closed polygon per entry of ``methods`` over the
    axes listed in ``labels``, writes the figure next to this script (same
    file name, ``.svg`` suffix) and then opens the interactive viewer.

    Raises:
        IndexError: If a method listed in ``methods`` has no row in the CSV.
    """
    # Load data
    df = pd.read_csv(MachiavelliConfig().experiment_results_dir / "main_results.csv")

    # Value of the CSV "method" column -> legend text. Dict order is the
    # plotting order; comment entries in or out to change what is drawn.
    methods = {
        "random_baselines/RandomAgent": "Random",
        "drrn/drrn_no_reg": "RL",
        # "drrn/drrn_reg_less": "RL + reg_less",
        "drrn/drrn_reg_more": "RL + Conscience",
        # "lm_dev/CoTAgent:gpt-3.5-turbo": "GPT3.5 + CoT",
        "lm_dev/CoTAgent:gpt-4": "GPT4",
        "lm_dev/GoodCoTAgent:gpt-4": "GPT4 + EthicsPrompt",
        # "lm_dev/GoodAgent:gpt-4": "GPT4 + EthicalPrompt (No objectives)",
    }

    # One colour per agent family; dashed for the unregularised / plain
    # variant, solid for the regularised / ethics-prompted variant.
    lm_color = "#1aaf6c"
    rl_color = "#429bf4"
    no_reg_linestyle = "--"
    reg_linestyle = "-"
    style = {
        "random_baselines/RandomAgent": ("#787878", "-"),
        "drrn/drrn_no_reg": (rl_color, no_reg_linestyle),
        "drrn/drrn_reg_less": (rl_color, reg_linestyle),
        "drrn/drrn_reg_more": (rl_color, reg_linestyle),
        "lm_dev/CoTAgent:gpt-3.5-turbo": (lm_color, no_reg_linestyle),
        "lm_dev/CoTAgent:gpt-4": (lm_color, no_reg_linestyle),
        "lm_dev/GoodCoTAgent:gpt-4": (lm_color, reg_linestyle),
        "lm_dev/GoodAgent:gpt-4": (lm_color, reg_linestyle),
    }

    # CSV column -> axis caption. Each entry becomes one spoke of the chart.
    labels = {
        "points": "Reward (Game Score)",
        "power.Σ": "Power",
        "violations.Σ": "Ethical Violations",
        "utility.Σ": "Disutility",
        # Optional per-category spokes; uncomment to add them.
        # "violations.deception": "Deception",
        # "violations.fairness": "Fairness",
        # "violations.intending_harm": "Intending_harm",
        # "violations.killing": "Killing",
        # "violations.manipulation": "Manipulation",
        # "violations.non_physical_harm": "Non_physical_harm",
        # "violations.other": "Other",
        # "violations.physical_harm": "Physical_harm",
        # "violations.promises": "Promises",
        # "violations.spying": "Spying",
        # "violations.stealing": "Stealing",
        # "violations.trespassing": "Trespassing",
        # "violations.vandalism": "Vandalism",
    }

    # Flip every harm column (1 - value) to convert from "% reduced" to "%".
    # The game score ("points") is left as it is.
    for column in labels:
        if column != "points":
            df[column] = 1 - df[column]

    # One angle per spoke, then repeat the first so each polygon closes.
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(
        nrows=1, ncols=1, figsize=(6, 6), subplot_kw=dict(polar=True))

    def add_to_star(ax, values, label, color, linestyle):
        """Draw one method as a closed outline with a faint fill."""
        # Build a new list rather than extending `values` in place, so the
        # caller's list is not modified.
        closed = values + values[:1]
        ax.plot(angles, closed, color=color, linewidth=2, label=label, linestyle=linestyle)
        ax.fill(angles, closed, color=color, alpha=0.07)

    for method, method_name in methods.items():
        records = df[df["method"] == method].to_dict("records")
        if not records:
            # Same exception type the bare `[0]` lookup used to raise, but
            # with a message that names the missing method.
            raise IndexError(f"no row with method == {method!r} in main_results.csv")
        row = records[0]  # if several rows match, the first one is plotted
        values = [row[column] for column in labels]
        color, linestyle = style[method]
        print("method", method, "values", values, "method_name", method_name, "color", color, "linestyle", linestyle)
        add_to_star(ax, values, method_name, color=color, linestyle=linestyle)

    # Rotate the first spoke by pi/4 and lay the spokes out clockwise.
    ax.set_theta_offset(np.pi / 4)
    ax.set_theta_direction(-1)

    # Install the spoke captions *before* touching their alignment, so the
    # loop below operates on the labels that are actually displayed.
    ax.set_thetagrids(np.degrees(angles[:-1]), list(labels.values()))
    for tick_label, angle in zip(ax.get_xticklabels(), angles):
        if angle in (0, np.pi):
            tick_label.set_horizontalalignment('center')
        elif 0 < angle < np.pi:
            tick_label.set_horizontalalignment('left')
        else:
            tick_label.set_horizontalalignment('right')

    # Tick colours and sizes: radial labels are shrunk to size 0, spoke
    # captions are set to size 13.
    ax.tick_params(colors='#222222')
    ax.tick_params(axis='y', labelsize=0)
    ax.tick_params(axis='x', labelsize=13)
    # Grid lines, outer circle and background colours.
    ax.grid(color='#AAAAAA')
    ax.spines['polar'].set_color('#222222')
    ax.set_facecolor('#FAFAFA')

    # A single legend; an earlier 'upper center' legend call was always
    # replaced by this one, so only this placement is kept.
    ax.legend(loc='center right', bbox_to_anchor=(0.3, 0.5), fancybox=True, shadow=True)

    # Save the plot next to this script, then show it.
    outpath = Path(__file__).with_suffix(".svg")
    fig.savefig(outpath, bbox_inches='tight', pad_inches=0, dpi=300, transparent=True, format='svg')

    plt.show()


# Run the plotting routine only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
