#  Copyright (c) 2025

from matplotlib import pyplot as plt

from common import *
from wandb_plot import get_wandb_panel


def _split_x(df, x_axis_name, n_frames_per_iter):
    """Split a W&B panel dataframe into (x in millions of frames, value columns).

    Args:
        df: Dataframe holding an ``x_axis_name`` column (training iterations)
            plus value columns.
        x_axis_name: Name of the iteration column.
        n_frames_per_iter: Number of environment frames per training iteration.

    Returns:
        A tuple ``(n_frames, values)`` where ``n_frames`` is a numpy array of
        frame counts in millions and ``values`` is ``df`` without the
        iteration column.
    """
    n_frames = df[x_axis_name].to_numpy() * n_frames_per_iter / 1_000_000
    return n_frames, df.drop(columns=[x_axis_name])


def _plot_mean_std(ax, x, values, color):
    """Plot the row-wise mean of ``values`` with a shaded +/- 1 std band.

    Args:
        ax: Matplotlib axes to draw on.
        x: Horizontal-axis values, one per row of ``values``.
        values: Dataframe whose columns are averaged row-wise.
        color: Line colour; the band reuses the colour actually drawn.
    """
    mean = values.mean(axis=1)
    std = values.std(axis=1)
    (mean_line,) = ax.plot(x, mean, color=color, linewidth=2)
    ax.fill_between(
        x,
        mean + std,
        mean - std,
        color=mean_line.get_color(),
        alpha=0.3,
    )


def plot_matrix_final_rew(legend=True):
    """Plot softmax temperatures and heterogeneity gain for flag-capture runs.

    Fetches four per-run histories from the W&B project
    ``het_env_design_flag_capture_design``:
    ``agent_agg_t``, ``task_agg_t``, ``eval/het/reward_mean`` and
    ``eval/hom/reward_mean``.

    Runs are filtered to 2-agent ``flag_capture`` scenarios with
    ``percentage`` reward and softmax task/agent aggregation. A 1x3 figure
    is drawn, each panel showing a row-wise mean +/- std (columns presumably
    one per run — TODO confirm against ``get_wandb_panel``):

    * left:   ``agent_agg_t``;
    * middle: ``task_agg_t``;
    * right:  heterogeneity gain, the difference of ``1 + 10 * r`` between
      the het and hom evaluation rewards.

    The x axis is frames in millions (iterations * 60k / 1e6), clipped at 89.
    The figure is saved to ``ctf_embodied_softmax_design.pdf`` and shown.

    Note: this updates the global ``plt.rcParams`` (LaTeX/serif fonts).

    Args:
        legend: Currently unused; kept for backward compatibility.
    """
    project = "het_env_design_flag_capture_design"
    training_iterations = 2000
    x_axis_name = "info/training_iteration"
    n_frames_per_iter = 60_000

    filters_cont = {
        "config.env.scenario_name": "flag_capture",
        "config.env.scenario.n_agents": 2,
        "config.env.scenario.reward_type": "percentage",
        "config.env.scenario.gen_agg_type_task": "softmax",
        "config.env.scenario.gen_agg_type_agent": "softmax",
    }
    groups = []

    def fetch(attribute_name):
        # Every panel uses the same project, filter and x axis; only the
        # logged attribute differs.
        return get_wandb_panel(
            project,
            groups,
            attribute_name=attribute_name,
            x_iterations=training_iterations,
            filter=filters_cont,
            x_axis_name=x_axis_name,
            aggregate=False,
        )

    # Get data
    rew_groups_and_dfs_het = fetch("eval/het/reward_mean")
    rew_groups_and_dfs_hom = fetch("eval/hom/reward_mean")
    t_groups_and_dfs = fetch("agent_agg_t")
    tau_groups_and_dfs = fetch("task_agg_t")

    tex_fonts = {
        # Use LaTeX to write all text
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Times New Roman"],
        "axes.labelsize": 25,
        "font.size": 20,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 20,
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
    }
    plt.rcParams.update(tex_fonts)
    fig, axs = plt.subplots(ncols=3, figsize=(15, 3.5))

    def r_transform(r):
        return 1 + 10 * r

    # Iterate all four panels in lockstep so that group i's temperatures are
    # plotted alongside group i's rewards (previously group 0's temperatures
    # were reused for every group).
    panels = zip(
        rew_groups_and_dfs_het,
        rew_groups_and_dfs_hom,
        t_groups_and_dfs,
        tau_groups_and_dfs,
    )
    for i, (het_entry, hom_entry, t_entry, tau_entry) in enumerate(panels):
        # Each entry is a (group, dataframe) pair; only the dataframe is used.
        n_frames, df_het = _split_x(het_entry[-1], x_axis_name, n_frames_per_iter)
        _, df_hom = _split_x(hom_entry[-1], x_axis_name, n_frames_per_iter)
        n_frames_t, df_t = _split_x(t_entry[-1], x_axis_name, n_frames_per_iter)
        # task_agg_t gets its own x axis instead of borrowing agent_agg_t's.
        n_frames_tau, df_tau = _split_x(
            tau_entry[-1], x_axis_name, n_frames_per_iter
        )

        color = CB_color_cycle[i]
        _plot_mean_std(axs[0], n_frames_t, df_t, color)
        _plot_mean_std(axs[1], n_frames_tau, df_tau, color)
        _plot_mean_std(
            axs[-1], n_frames, r_transform(df_het) - r_transform(df_hom), color
        )

    axs[0].set_ylabel("Softmax $\\tau$ ($T$)")
    axs[1].set_ylabel("Softmax $\\tau$ ($U$)")
    axs[-1].set_ylabel("Het. gain")

    for ax in axs:
        ax.set_xlim(right=89)
        ax.set_xlabel("Number of frames (Millions)")
        ax.yaxis.grid(True)

    plt.tight_layout()
    plt.savefig(
        "ctf_embodied_softmax_design.pdf",
        bbox_inches="tight",
        pad_inches=0.1,
    )
    plt.show()


if __name__ == "__main__":
    # Running this module directly produces the figure with default options.
    plot_matrix_final_rew(legend=True)
