#  Copyright (c) 2025
import ast

from matplotlib import pyplot as plt

from common import *
from wandb_plot import get_wandb_panel


def plot_matrix_final_rew(legend=True):
    """Plot the heterogeneity gain over training for each aggregation pair.

    Runs in the ``het_env_design`` W&B project matching ``filters_cont`` are
    grouped by ``(gen_agg_type_task, gen_agg_type_agent)``. For each group the
    heterogeneous (``eval/het/reward_mean``) and homogeneous
    (``eval/hom/reward_mean``) evaluation reward curves are fetched, the
    per-run difference ``r_transform(het) - r_transform(hom)`` is computed,
    and its mean across runs is drawn with a +/- one-std band against the
    number of training frames (in millions). The final mean/std of every
    curve is printed, the figure is written to ``ctf_embodied_vanilla.pdf``
    and then shown.

    Args:
        legend: If True, draw a figure-level legend above the axes.

    Raises:
        KeyError: If a heterogeneous-evaluation group has no homogeneous
            counterpart (the two curves could not be meaningfully subtracted).

    Note:
        Mutates the global ``plt.rcParams`` (LaTeX text, fonts, sizes).
    """
    project = "het_env_design"
    training_iterations = 500
    x_axis_name = "info/training_iteration"
    n_frames_per_iter = 60_000

    filters_cont = {
        "config.env.scenario_name": "flag_capture",
        "config.env.scenario.n_agents": 2,
        "config.env.scenario.reward_type": "percentage",
        "config.env.scenario.gen_agg_type_task": {"$ne": "softmax"},
        "config.env.scenario.gen_agg_type_agent": {"$ne": "softmax"},
        "config.env.scenario.use_lidar": {"$ne": True},
        "config.eval.explore": {"$ne": False},
    }
    groups = [{"env": {"scenario": ["gen_agg_type_task", "gen_agg_type_agent"]}}]

    def fetch(attribute_name):
        # Both evaluation modes share filters and grouping so that their
        # groups can be paired up below.
        return get_wandb_panel(
            project,
            groups,
            attribute_name=attribute_name,
            x_iterations=training_iterations,
            filter=filters_cont,
            x_axis_name=x_axis_name,
            aggregate=False,
        )

    def agg_pair(group):
        # Group labels are string reprs of the grouping config dict; parse
        # them as literals instead of eval-ing arbitrary code.
        scenario = ast.literal_eval(group)["scenario"]
        return scenario["gen_agg_type_task"], scenario["gen_agg_type_agent"]

    def sort_key(entry):
        task_agg, agent_agg = agg_pair(entry[0])
        # Agent aggregation first, then task aggregation (controls the order
        # and therefore the colour assigned to each curve).
        return str(agent_agg + task_agg)

    def r_transform(r):
        return 1 + 10 * r

    # Get data
    rew_groups_and_dfs_het = fetch("eval/het/reward_mean")
    rew_groups_and_dfs_hom = fetch("eval/hom/reward_mean")

    het_sorted = sorted(rew_groups_and_dfs_het, key=sort_key, reverse=True)
    # Pair het/hom groups by their aggregation pair rather than by position,
    # so a missing or extra group cannot silently misalign the two lists.
    hom_by_pair = {agg_pair(group): df for group, df in rew_groups_and_dfs_hom}

    tex_fonts = {
        # Use LaTeX to write all text
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Times New Roman"],
        "axes.labelsize": 28,
        "font.size": 28,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 17,
        "xtick.labelsize": 23,
        "ytick.labelsize": 23,
    }
    plt.rcParams.update(tex_fonts)
    fig, ax = plt.subplots(ncols=1, figsize=(10, 5))

    for i, (group_het, df_het) in enumerate(het_sorted):
        task_agg, agent_agg = agg_pair(group_het)
        if (task_agg, agent_agg) not in hom_by_pair:
            raise KeyError(
                f"No homogeneous evaluation data for group {group_het}"
            )
        df_hom = hom_by_pair[(task_agg, agent_agg)]

        iteration = df_het[x_axis_name].to_numpy()
        n_frames = iteration * n_frames_per_iter / 1_000_000

        label = rf"$U=\mathrm{{{task_agg}}}, T=\mathrm{{{agent_agg}}}$"
        # Highlight the agent=max configurations with task in {min, mean}.
        bold = agent_agg == "max" and task_agg in ("min", "mean")

        df_het = df_het.drop(columns=[x_axis_name])
        df_hom = df_hom.drop(columns=[x_axis_name])

        # Per-run gain; mean/std are taken across runs (columns).
        df = r_transform(df_het) - r_transform(df_hom)
        mean = df.mean(axis=1)
        std = df.std(axis=1)

        (mean_line,) = ax.plot(
            n_frames,
            mean,
            label=label,
            # Wrap around so more groups than palette colours cannot crash.
            color=CB_color_cycle[i % len(CB_color_cycle)],
            linewidth=2 if bold else 1.5,
            ls="-" if bold else "-.",
        )
        ax.fill_between(
            n_frames,
            mean + std,
            mean - std,
            color=mean_line.get_color(),
            alpha=0.3,
        )

        if len(mean):
            print(f"label {label}, mean {mean.iloc[-1]}, std {std.iloc[-1]}")
    if legend:
        fig.legend(
            fancybox=True,
            shadow=True,
            ncol=3,
            loc="lower center",
            fontsize=17,
            bbox_to_anchor=(0.5, 0.95),
        )
    ax.set_xlim(left=0, right=30)
    ax.set_xlabel("Number of frames (Millions)")
    ax.set_ylabel("Heterogeneity gain")

    ax.yaxis.grid(True)

    plt.tight_layout()
    plt.savefig(
        "ctf_embodied_vanilla.pdf",
        bbox_inches="tight",
        pad_inches=0.1,
    )
    plt.show()


if __name__ == "__main__":
    # Script entry point: build, save and display the figure with its legend.
    plot_matrix_final_rew(legend=True)
