#  Copyright (c) 2025
import ast

from matplotlib import pyplot as plt

from common import *
from wandb_plot import get_wandb_panel


# (outer, inner) aggregation pairs that are emphasised in the plot with a
# thicker solid line; every other pair is drawn thinner and dash-dotted.
_HIGHLIGHTED_AGGS = frozenset({("min", "max"), ("mean", "max"), ("min", "mean")})


def _is_highlighted(task_agg, agent_agg):
    """Return True if the (outer, inner) aggregation pair should be emphasised."""
    return (task_agg, agent_agg) in _HIGHLIGHTED_AGGS


def plot_matrix_final_rew(legend=True, *, output_path="tag.pdf", show=True):
    """Plot ``eval_regret`` curves of the tag scenario, one per aggregation pair.

    Runs from the W&B project ``"tag"`` whose ``config.env.scenario_name`` is
    ``"tag_potential"`` are grouped by ``gen_agg_type_outer`` /
    ``gen_agg_type_inner``. For every group the mean across the dataframe's
    value columns (presumably one column per run, since ``aggregate=False`` —
    confirm against ``get_wandb_panel``) is drawn with a +/- one standard
    deviation band, against the number of training frames in millions. The
    last mean/std of every group is printed to stdout.

    Args:
        legend: If True, draw a figure-level legend above the axes.
        output_path: File the figure is saved to (``"tag.pdf"`` by default).
        show: If True, call ``plt.show()`` after saving.

    Note:
        ``plt.rcParams`` is updated globally (LaTeX text rendering, serif
        fonts), so the style persists after this function returns.
    """
    project = "tag"
    training_iterations = 500
    x_axis_name = "info/training_iteration"
    n_frames_per_iter = 60_000
    # Raw logged values are divided by this constant before plotting.
    # NOTE(review): the meaning of 400 is not visible in this file — confirm.
    value_scale = 400

    filters_cont = {
        "config.env.scenario_name": "tag_potential",
        # "config.env.scenario.gen_agg_type_outer": "min",
        # "config.env.scenario.gen_agg_type_inner": "max",
    }
    groups = [{"env": {"scenario": ["gen_agg_type_outer", "gen_agg_type_inner"]}}]

    # Get data
    raw_groups_and_dfs = get_wandb_panel(
        project,
        groups,
        attribute_name="eval_regret",
        x_iterations=training_iterations,
        filter=filters_cont,
        x_axis_name=x_axis_name,
        aggregate=False,
    )

    # Group identifiers arrive as dict literals in string form. Parse each one
    # exactly once with literal_eval; never use eval() on them.
    groups_and_dfs = [
        (ast.literal_eval(group), df) for group, df in raw_groups_and_dfs
    ]
    # Descending order by (inner, outer) aggregation. This order decides which
    # colour each group receives.
    groups_and_dfs.sort(
        key=lambda item: (
            item[0]["scenario"]["gen_agg_type_inner"],
            item[0]["scenario"]["gen_agg_type_outer"],
        ),
        reverse=True,
    )

    tex_fonts = {
        # Use LaTeX to write all text
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Times New Roman"],
        "axes.labelsize": 28,
        "font.size": 28,
        # Make the legend/label fonts a little smaller
        "legend.fontsize": 17,
        # "legend.title_fontsize": 20,
        "xtick.labelsize": 23,
        "ytick.labelsize": 23,
    }
    plt.rcParams.update(tex_fonts)
    fig, ax = plt.subplots(ncols=1, figsize=(10, 5))

    for i, (group_cfg, df_group) in enumerate(groups_and_dfs):
        # x axis: training iterations converted to millions of frames.
        iteration = df_group[x_axis_name].to_numpy()
        n_frames = iteration * n_frames_per_iter / 1_000_000

        task_agg = group_cfg["scenario"]["gen_agg_type_outer"]
        agent_agg = group_cfg["scenario"]["gen_agg_type_inner"]
        label = rf"$U=\mathrm{{{task_agg}}}, T=\mathrm{{{agent_agg}}}$"
        bold = _is_highlighted(task_agg, agent_agg)

        # Every remaining column is a value series; average across them.
        df = df_group.drop(columns=[x_axis_name]) / value_scale
        mean = df.mean(axis=1)
        std = df.std(axis=1)

        (mean_line,) = ax.plot(
            n_frames,
            mean,
            label=label,
            # Wrap around so more groups than palette entries cannot IndexError.
            color=CB_color_cycle[i % len(CB_color_cycle)],
            linewidth=2 if bold else 1.5,
            ls="-" if bold else "-.",
        )

        ax.fill_between(
            n_frames,
            mean + std,
            mean - std,
            color=mean_line.get_color(),
            alpha=0.3,
        )

        print(f"label {label}, mean {mean.iloc[-1]}, std {std.iloc[-1]}")

    if legend:
        # Anchored just above the axes; bbox_inches="tight" in savefig keeps it
        # inside the saved figure.
        fig.legend(
            fancybox=True,
            shadow=True,
            ncol=3,
            loc="lower center",
            fontsize=17,
            bbox_to_anchor=(0.5, 0.95),
        )
    ax.set_xlim(left=0, right=30)
    ax.set_xlabel("Number of frames (Millions)")
    ax.set_ylabel("Heterogeneity gain")

    ax.yaxis.grid(True)

    fig.tight_layout()
    fig.savefig(
        output_path,
        bbox_inches="tight",
        pad_inches=0.1,
    )
    if show:
        plt.show()


if __name__ == "__main__":
    # Script entry point: render the figure with its legend, save it, show it.
    plot_matrix_final_rew(legend=True)
