#  Copyright (c) 2025
import ast

from matplotlib import pyplot as plt

from common import *
from wandb_plot import get_wandb_panel


# Matplotlib settings applied before drawing; render all text with LaTeX.
_TEX_FONTS = {
    # Use LaTeX to write all text
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times New Roman"],
    "axes.labelsize": 20,
    "font.size": 17,
    # Make the legend/label fonts a little smaller
    "legend.fontsize": 17,
    "xtick.labelsize": 17,
    "ytick.labelsize": 17,
}

# (task aggregation, agent aggregation) combinations drawn with a solid, wider
# line in each panel; all other combinations are drawn dash-dotted. The
# selection criterion is not recorded in this file.
_HIGHLIGHTED_DISCRETE = {("min", "mean"), ("min", "max"), ("mean", "max")}
_HIGHLIGHTED_CONTINUOUS = {("min", "max"), ("mean", "max")}


def _make_filters(n_agents, n_iters, continuous):
    """Return the W&B run filter for one action-space setting."""
    return {
        "config.env.scenario_name": "flag_capture_unembodied",
        "config.env.continuous_actions": continuous,
        "config.eval.explore": False,
        "config.env.scenario.n_agents": n_agents,
        "config.collector.n_iters": n_iters,
    }


def _group_sort_key(group_and_df):
    """Sort key for a ``(group_repr, df)`` pair: agent aggregation + task aggregation."""
    # ast.literal_eval instead of eval: group_repr is a dict repr coming from W&B.
    scenario = ast.literal_eval(group_and_df[0])["scenario"]
    return str(scenario["gen_agg_type_agent"] + scenario["gen_agg_type_task"])


def _parse_group(group_repr):
    """Parse a group repr, dropping the action-space key so both panels compare equal."""
    group = ast.literal_eval(group_repr)
    group["scenario"].pop("discrete_actions", None)
    return group


def _fetch_sorted_groups(project, groups, filters, training_iterations, x_axis_name):
    """Fetch per-run ``eval_regret`` panels and sort them by descending group key."""
    groups_and_dfs = get_wandb_panel(
        project,
        groups,
        attribute_name="eval_regret",
        x_iterations=training_iterations,
        filter=filters,
        x_axis_name=x_axis_name,
        aggregate=False,
    )
    return sorted(groups_and_dfs, key=_group_sort_key, reverse=True)


def plot_matrix_final_rew(legend=True, n_agents=2):
    """Plot ``eval_regret`` curves for the unembodied flag-capture matrix games.

    Draws one panel for discrete actions and one for continuous actions. Each
    panel has one curve per (task, agent) aggregation group, showing the mean
    over runs with a +/- one standard deviation band.

    Side effects: updates the global matplotlib ``rcParams``, prints the final
    mean/std of every curve, and saves ``matrix_games_{n_agents}_agents.pdf``
    in the working directory before showing the figure.

    Args:
        legend: if True, add a single figure-level legend to the right.
        n_agents: agent count used to filter runs. 4 agents selects 200
            training iterations; any other value selects 100.

    Raises:
        ValueError: if the continuous and discrete runs do not contain the
            same aggregation groups.
    """
    project = "het_env_design"
    training_iterations = 200 if n_agents == 4 else 100
    x_axis_name = "info/training_iteration"
    n_frames_per_iter = 60_000
    groups = [{"env": {"scenario": ["gen_agg_type_task", "gen_agg_type_agent"]}}]

    runs_cont = _fetch_sorted_groups(
        project,
        groups,
        _make_filters(n_agents, training_iterations, continuous=True),
        training_iterations,
        x_axis_name,
    )
    runs_disc = _fetch_sorted_groups(
        project,
        groups,
        _make_filters(n_agents, training_iterations, continuous=False),
        training_iterations,
        x_axis_name,
    )
    # zip() would silently drop unmatched groups; fail loudly instead.
    if len(runs_cont) != len(runs_disc):
        raise ValueError(
            f"Found {len(runs_cont)} continuous groups but "
            f"{len(runs_disc)} discrete groups"
        )

    plt.rcParams.update(_TEX_FONTS)
    fig, axs = plt.subplots(ncols=2, figsize=(10, 4))

    for i, ((group_cont, df_cont), (group_disc, df_disc)) in enumerate(
        zip(runs_cont, runs_disc)
    ):
        group = _parse_group(group_cont)
        if group != _parse_group(group_disc):
            raise ValueError(
                f"Mismatched groups: continuous {group_cont!r} "
                f"vs discrete {group_disc!r}"
            )

        task_agg = group["scenario"]["gen_agg_type_task"]
        agent_agg = group["scenario"]["gen_agg_type_agent"]
        label = rf"$U=\mathrm{{{task_agg}}}, T=\mathrm{{{agent_agg}}}$"

        # Left panel: discrete actions; right panel: continuous actions.
        panels = (
            (False, df_disc, _HIGHLIGHTED_DISCRETE),
            (True, df_cont, _HIGHLIGHTED_CONTINUOUS),
        )
        for ax, (continuous, df, highlighted) in zip(axs, panels):
            bold = (task_agg, agent_agg) in highlighted

            # x axis comes from the dataframe actually plotted, so the discrete
            # curve is not drawn against the continuous runs' iterations.
            n_frames = df[x_axis_name].to_numpy() * n_frames_per_iter / 1_000_000

            # Remaining columns are individual runs. The 1/10 scaling is kept
            # from the original; presumably a normalisation — TODO confirm.
            values = df.drop(columns=[x_axis_name]) / 10
            mean = values.mean(axis=1)
            std = values.std(axis=1)

            (mean_line,) = ax.plot(
                n_frames,
                mean,
                # Only discrete curves are labelled so the figure legend lists
                # each group once.
                label=None if continuous else label,
                color=CB_color_cycle[i],
                linewidth=2 if bold else 1.5,
                ls="-" if bold else "-.",
            )
            ax.fill_between(
                n_frames,
                mean + std,
                mean - std,
                color=mean_line.get_color(),
                alpha=0.3,
            )

            print(
                f"{'Cont' if continuous else 'Disc'}, label {label}, "
                f"mean {mean.iloc[-1]}, std {std.iloc[-1]}"
            )

    if legend:
        fig.legend(
            fancybox=True,
            shadow=True,
            ncol=1,
            loc="center left",
            fontsize=17,
            bbox_to_anchor=(1.0, 0.5),
        )

    for ax, title in zip(axs, ("Discrete", "Continuous")):
        ax.set_xlabel("Number of frames (Millions)")
        ax.set_ylim(-0.6, 1.1)
        ax.yaxis.grid(True)
        ax.set_title(title)
    axs[0].set_ylabel("Heterogeneity gain")

    fig.tight_layout()
    fig.savefig(
        f"matrix_games_{n_agents}_agents.pdf",
        bbox_inches="tight",
        pad_inches=0.1,
    )
    plt.show()
    # Release the figure; __main__ calls this function more than once.
    plt.close(fig)


if __name__ == "__main__":
    # Render the matrix-game figure for each team size, smallest first.
    for agent_count in (2, 4):
        plot_matrix_final_rew(n_agents=agent_count)
