"""
Contains functions that generate the paper (and other) plots.
"""

import copy
import numpy as np
import matplotlib.pyplot as plt

# from environments import StarMDP_with_random_flinging, Gridworld
# from models.policies import train_tabular_BC_policy
# from utils.offline_helpers import generate_offline_trajectories
# from utils.online_helpers import compute_rewards_traj, rollout_policy_in_env


def _per_iteration_regrets(run_metrics):
    """Return the per-iteration regrets of a single run as a numpy array.

    Reads ``run_metrics["regrets"]``; if that key is missing, falls back to the
    legacy ``"suboptimality_gaps"`` key with the sign flipped.
    """
    try:
        return np.array(run_metrics["regrets"])
    except KeyError:
        # "regrets" used to be logged as "suboptimality_gaps" and was pred_rewards - opt_rewards,
        # hence the "-1" factor to convert to opt_rewards - pred_rewards.
        # TODO: fix by rerunning old experiments.
        return np.array([-1 * x for x in run_metrics["suboptimality_gaps"]])


def plot_suboptimalities_multimetrics(
    multi_metrics,
    save_name=None,
    paper_style=False,
    mle_policy_avg_reward=None,
    opt_policy_avg_reward=None,
    which_plot="suboptimality_percent",  # or "regret", "regret_indiv", "cumulative_regret", "raw_reward"
    exclude_outliers=False,
    figsize=None,  # (6,4)?,
):
    """Plots suboptimalities of the best iteration policy vs the true optimal policy

    multi_metrics is a dict with keys being the names of the experiments, and values being lists of dicts, each dict being a run of the experiment
    i.e.: {
        "purely_online": [
            {"regrets": [value1, value2, ...], "avg_rewards_best_iteration_policy": [value1, value2, ...]},  # seed0
            {"regrets": [value1, value2, ...], "avg_rewards_best_iteration_policy": [value1, value2, ...]},  # seed1
            ...
        ],
    }
    The top-level keys "avg_expert_reward" and "avg_bc_reward" are skipped (they hold scalars,
    not runs). Experiment names must be one of "purely_online", "baseline" or "bridge", since
    colors and legend names are looked up by name. The input is deep-copied, never modified.

    Per-run keys that are read:
    - "avg_rewards_best_iteration_policy": always (defines the x-axis length)
    - "avg_rewards_true_opt": for "suboptimality_percent"
    - "regrets" (or the legacy "suboptimality_gaps"): for the regret plots and "*_cumregret" filtering
    - "dist_bc_expert": for "*_bcexpertdist" filtering (only its first entry is used)

    Styles (which_plot):
    - "suboptimality_percent": mean +- CI of (best iteration policy reward / true optimal reward) * 100,
      clipped to [0, 100]
    - "regret": mean +- CI of the per-iteration regret, clipped below at 0
    - "regret_indiv": per-iteration regret, one faint line per seed plus the mean
    - "cumulative_regret": mean +- CI of the running sum of regret, clipped below at 0
    - "raw_reward": reward of the best iteration policy, one faint line per seed plus the mean

    Use 'exclude_outliers' to exclude certain BRIDGE runs from the plot:
      "worst_cumregret": the run with the largest cumulative regret
      "95conf_cumregret": any run whose cumulative regret exceeds mean + 1.96 * std
      "worst_bcexpertdist" / "95conf_bcexpertdist": same, w.r.t. the first "dist_bc_expert" entry
    If filtering fails for any reason a message is printed and all runs are kept.

    Other arguments:
      save_name: if truthy, the figure is saved there; otherwise it is shown.
      paper_style: if True, sets the global matplotlib style/rcParams (incl. usetex).
      mle_policy_avg_reward: if not None, adds a green dotted reference for this reward. For all
        styles except "raw_reward" this also needs opt_policy_avg_reward to be a number.
      opt_policy_avg_reward: reward of the optimal policy (red dotted line for "raw_reward").
      figsize: if not None, a new figure of that size is created.

    Raises:
      ValueError: if which_plot is not one of the styles listed above.
    """
    supported_plots = (
        "suboptimality_percent",
        "regret",
        "regret_indiv",
        "cumulative_regret",
        "raw_reward",
    )
    # Fail fast: an unknown style used to produce an empty figure without any error.
    if which_plot not in supported_plots:
        raise ValueError(f"Unknown which_plot '{which_plot}', expected one of {supported_plots}")

    if paper_style:
        plt.style.use("seaborn-v0_8-whitegrid")
        plt.rcParams["font.family"] = "Helvetica"
        plt.rcParams["font.size"] = 12
        plt.rcParams["axes.linewidth"] = 0.8
        plt.rcParams["axes.edgecolor"] = "#333333"
        plt.rcParams["text.usetex"] = True  # IF ERRORS THROWN: COMMENT THIS OUT

    if figsize is not None:
        plt.figure(figsize=figsize)

    colors = {
        "purely_online": "#7851A9",  # "#007BA7",
        "baseline": "#7851A9",  # cont. experiments use different key
        # "traj50": "#2A9D8F",
        "bridge": "#2A9D8F",
    }
    legend_names = {
        "purely_online": "Online PbRL",
        "baseline": "Online PbRL",
        "bridge": "BRIDGE",
    }
    # these entries are just scalars and don't belong to a single run
    scalar_keys = ("avg_expert_reward", "avg_bc_reward")

    multi_metrics_ = copy.deepcopy(multi_metrics)

    # Find the max x-value across all (non-empty) experiments; 0 if there are none.
    # Computed before outlier filtering, as the reference lines should span the full range.
    exp_lengths = [
        len(expt[0]["avg_rewards_best_iteration_policy"])
        for name, expt in multi_metrics_.items()
        if name not in scalar_keys and expt
    ]
    max_x = max(exp_lengths, default=0)

    # exclude outliers from plotting
    if exclude_outliers in (
        "worst_bcexpertdist",
        "worst_cumregret",
        "95conf_bcexpertdist",
        "95conf_cumregret",
    ):
        try:
            bridge_metrics = multi_metrics_["bridge"]
            if exclude_outliers in ("worst_bcexpertdist", "95conf_bcexpertdist"):
                filtermetric_name = "dist_bc_expert"
                filtermetric = [run["dist_bc_expert"][0] for run in bridge_metrics]
            else:
                filtermetric_name = "cumregret"
                filtermetric = [np.sum(run["regrets"]) for run in bridge_metrics]
            filtermetric = np.asarray(filtermetric)

            # build index mask of which seeds to keep (True = keep)
            keep_mask = np.ones(len(bridge_metrics), dtype=bool)
            if exclude_outliers in ("worst_bcexpertdist", "worst_cumregret"):
                worst_seed_idx = np.argmax(filtermetric)
                keep_mask[worst_seed_idx] = False
                n_excluded_runs = len(keep_mask) - np.sum(keep_mask)
                print(
                    f"Plotting: excluding {n_excluded_runs} run, seed {worst_seed_idx} (worst {filtermetric_name})"
                )
            else:  # exclude if value > "mean + 1.96 * stdev" (0.95%conf)
                approx_95_conf = np.mean(filtermetric) + 1.96 * np.std(filtermetric)
                keep_mask[filtermetric > approx_95_conf] = False
                n_excluded_runs = len(keep_mask) - np.sum(keep_mask)
                print(
                    f"Plotting: excluding {int(n_excluded_runs)} runs, ones with {filtermetric_name} > {approx_95_conf:.2f} (95% conf)"
                )

            multi_metrics_["bridge"] = [
                run for run, keep in zip(bridge_metrics, keep_mask) if keep
            ]
        except Exception as e:
            # deliberate best-effort: on any problem, plot all runs unfiltered
            print(f"Plotting: error during filtering, skipping. Error: {e}")

    # y-limits after the last mean+-CI curve; used to keep the MLE line from rescaling the plot
    algorithm_ylim = None

    # for each experiment, build a (seeds x iterations) array of the plotted quantity
    for name, expt in multi_metrics_.items():
        if name in scalar_keys:
            continue
        if not expt:
            # e.g. outlier filtering removed the only run of this experiment
            print(f"Plotting: no runs for experiment '{name}', skipping.")
            continue

        num_iterations = len(expt[0]["avg_rewards_best_iteration_policy"])
        x_axis = np.arange(1, num_iterations + 1)

        if which_plot == "suboptimality_percent":
            upper_clip = 100
            subopt = np.array(
                [
                    np.array(run["avg_rewards_best_iteration_policy"])
                    / np.array(run["avg_rewards_true_opt"])
                    * 100
                    for run in expt
                ]
            )
        elif which_plot in ("regret", "regret_indiv"):
            upper_clip = None
            subopt = np.array([_per_iteration_regrets(run) for run in expt])
            cumulative_regret_per_seed_at_T = np.sum(subopt, axis=1)  # ax1 is iterations
            print(
                f"Avg cumulative regret R_T of {name}: {np.mean(cumulative_regret_per_seed_at_T):.2f}"
            )
        elif which_plot == "cumulative_regret":
            # idea is to plot cumulative regret (sum of regret over iterations so far)
            upper_clip = None
            subopt = np.array([np.cumsum(_per_iteration_regrets(run)) for run in expt])
            cumulative_regret_per_seed_at_T = subopt[:, -1]  # ax1 is iterations
            print(
                f"Avg cumulative regret R_T of {name}: {np.mean(cumulative_regret_per_seed_at_T):.2f}"
            )
        else:  # "raw_reward"
            # idea is to plot the raw (expected) reward of best iteration policy, and an hline for expert reward
            upper_clip = None
            subopt = np.array(
                [np.array(run["avg_rewards_best_iteration_policy"]) for run in expt]
            )

        avg_subopt = np.mean(subopt, axis=0)  # ax0 is seeds
        confidence_interval = 1.96 * np.std(subopt, axis=0) / np.sqrt(len(subopt))

        if which_plot in ("suboptimality_percent", "regret", "cumulative_regret"):
            plt.plot(
                x_axis,
                np.clip(avg_subopt, 0, upper_clip),
                label=legend_names[name],
                color=colors[name],
            )
            plt.fill_between(
                x_axis,
                np.clip(avg_subopt - confidence_interval, 0, upper_clip),
                np.clip(avg_subopt + confidence_interval, 0, upper_clip),
                alpha=0.2,
                color=colors[name],
            )
            algorithm_ylim = plt.gca().get_ylim()
        else:  # "raw_reward", "regret_indiv"
            # Plot individual seed lines in the background with transparency
            for seed_curve in subopt:
                plt.plot(x_axis, seed_curve, color=colors[name], alpha=0.3)
            # Plot the mean line on top with the label
            plt.plot(x_axis, avg_subopt, label=legend_names[name], color=colors[name])

    # red dotted reference line marking the optimum
    if max_x > 0:
        if which_plot == "suboptimality_percent":
            plt.plot([1, max_x], [100, 100], color="r", linestyle=":")
        elif which_plot in ("regret", "regret_indiv", "cumulative_regret"):
            plt.plot([1, max_x], [0, 0], color="r", linestyle=":")
        elif which_plot == "raw_reward" and opt_policy_avg_reward is not None:
            plt.plot(
                [1, max_x], [opt_policy_avg_reward, opt_policy_avg_reward], color="r", linestyle=":"
            )

    # green dotted reference line for the MLE policy
    if mle_policy_avg_reward is not None and max_x > 0:
        if which_plot == "suboptimality_percent":
            mle_subopt = mle_policy_avg_reward / opt_policy_avg_reward * 100
            plt.plot([1, max_x], [mle_subopt, mle_subopt], color="g", linestyle=":")
        elif which_plot in ("regret", "regret_indiv"):
            mle_subopt = opt_policy_avg_reward - mle_policy_avg_reward
            print(f"MLE regret R_T: {mle_subopt * max_x:.2f}")
            plt.plot([1, max_x], [mle_subopt, mle_subopt], color="g", linestyle=":")
        elif which_plot == "cumulative_regret":
            mle_subopt = opt_policy_avg_reward - mle_policy_avg_reward
            mle_x_axis = np.arange(1, max_x + 1)
            mle_cumulative_regret = np.cumsum(np.ones(max_x) * mle_subopt)
            plt.plot(mle_x_axis, mle_cumulative_regret, color="g", linestyle=":")
            print(
                f"MLE regret R_T: {mle_cumulative_regret[-1]:.2f} (={max_x} * {mle_subopt:.2f})"
            )
            if algorithm_ylim is not None:
                # MLE regret >> BRIDGE/PbRL regret => is cut off
                plt.gca().set_ylim(algorithm_ylim)
        else:  # "raw_reward"
            plt.plot(
                [1, max_x],
                [mle_policy_avg_reward, mle_policy_avg_reward],
                color="g",
                linestyle=":",
            )

    handles, legend_labels = plt.gca().get_legend_handles_labels()
    # sort to put 'Online PbRL' first, then others
    order = []
    for i, label in enumerate(legend_labels):
        if "Online PbRL" in label:
            order.insert(0, i)
        else:
            order.append(i)

    plt.legend(
        [handles[i] for i in order], [legend_labels[i] for i in order], loc="upper left"
    )
    ax = plt.gca()
    ax.set_xlabel("Iteration")
    if which_plot == "suboptimality_percent":
        # '%' must be escaped only when LaTeX renders the text; otherwise the backslash would show up
        percent_sign = r"\%" if plt.rcParams["text.usetex"] else "%"
        ax.set_ylabel(f"{percent_sign} of optimal reward")
    elif which_plot in ("regret", "regret_indiv"):
        ax.set_ylabel("Regret")
    elif which_plot == "cumulative_regret":
        ax.set_ylabel("Cumulative regret")
    else:  # "raw_reward"
        ax.set_ylabel("Best policy reward")
    if save_name:
        plt.savefig(save_name, bbox_inches="tight")
    else:
        plt.show()


def plot_pi_set_sizes_multimetrics(
    multi_metrics, params, save_name=None, paper_style=False, exclude_outliers=False, figsize=None
):
    """Plots the number of policies in the online confidence set Pi_t at each iteration

    For every experiment the mean of "pi_set_sizes" over runs is drawn together with a
    1.96 * std / sqrt(n_runs) band. The band is clipped to [0, upper_clip] (linear axis) or
    [1, upper_clip] (log axis), where upper_clip is derived from 'params' (see below).

    multi_metrics: dict mapping experiment name ("purely_online", "baseline" or "bridge") to a
      list of per-run dicts, each holding a "pi_set_sizes" sequence. The top-level keys
      "avg_expert_reward" and "avg_bc_reward" are skipped. The input is deep-copied, never modified.
    params: dict of experiment settings. Keys read here:
      "baseline_search_space" (optional): "all_policies" / "augmented_ball" switch to a log y-axis
      "plot_logy" (optional): any value other than None also switches to a log y-axis
      "env": only read when the search space is one of the two above
      "confset_dilution" (optional), "N_confset_size", "N_confset_dilution": used for upper_clip
    save_name: if truthy, the figure is saved there; otherwise it is shown.
    paper_style: if True, sets the global matplotlib style/rcParams (incl. usetex).
    figsize: if not None, a new figure of that size is created.

    Use 'exclude_outliers' to exclude certain BRIDGE runs from the plot:
      "worst_cumregret": the run with the largest cumulative regret
      "95conf_cumregret": any run whose cumulative regret exceeds mean + 1.96 * std
      "worst_bcexpertdist" / "95conf_bcexpertdist": same, w.r.t. the first "dist_bc_expert" entry
    If filtering fails for any reason a message is printed and all runs are kept."""
    multi_metrics_ = copy.deepcopy(multi_metrics)

    if figsize is not None:
        plt.figure(figsize=figsize)

    if paper_style:
        plt.style.use("seaborn-v0_8-whitegrid")
        plt.rcParams["font.family"] = "Helvetica"
        plt.rcParams["font.size"] = 12
        plt.rcParams["axes.linewidth"] = 0.8
        plt.rcParams["axes.edgecolor"] = "#333333"

        # IF ERRORS THROWN: COMMENT THIS OUT
        plt.rcParams["text.usetex"] = True

    colors = {
        "purely_online": "#7851A9",  # "#007BA7",
        "baseline": "#7851A9",  # cont. experiments use different key
        # "traj50": "#2A9D8F",
        "bridge": "#2A9D8F",
    }
    legend_names = {
        "purely_online": "Online PbRL",
        "baseline": "Online PbRL",
        "bridge": "BRIDGE",
    }
    # these entries are just scalars and don't belong to a single run
    scalar_keys = ("avg_expert_reward", "avg_bc_reward")

    baseline_search_space = params.get("baseline_search_space")
    # NOTE(review): any non-None "plot_logy" (even False) enables the log axis -- confirm intent
    big_search_space = (
        baseline_search_space in ["all_policies", "augmented_ball"]
        or params.get("plot_logy") is not None
    )
    if baseline_search_space in ["all_policies", "augmented_ball"]:  # StarMDP, Gridworld
        if (
            params["env"] == "StarMDP_with_random_flinging"
            or params["env"] == "StarMDP_with_random_staying"
        ):
            # presumably the total number of candidate policies in StarMDP -- TODO confirm
            upper_clip = 4**5
        else:  # case env==Gridworld (can't do 'all' so do 'augmented_ball')
            upper_clip = params["N_confset_size"]
    elif params.get("confset_dilution") is not None:  # MuJoCo envs with dilution
        upper_clip = params["N_confset_size"] + params["N_confset_dilution"]
    else:  # MuJoCo envs without dilution, discrete envs without big search space
        upper_clip = params["N_confset_size"]

    # exclude outlier runs
    if exclude_outliers in (
        "worst_bcexpertdist",
        "worst_cumregret",
        "95conf_bcexpertdist",
        "95conf_cumregret",
    ):
        try:
            bridge_metrics = multi_metrics_["bridge"]
            if exclude_outliers in ("worst_bcexpertdist", "95conf_bcexpertdist"):
                filtermetric = [run["dist_bc_expert"][0] for run in bridge_metrics]
            else:
                filtermetric = [np.sum(run["regrets"]) for run in bridge_metrics]
            filtermetric = np.asarray(filtermetric)

            # build index mask of which seeds to keep (True = keep)
            keep_mask = np.ones(len(bridge_metrics), dtype=bool)
            if exclude_outliers in ("worst_bcexpertdist", "worst_cumregret"):
                keep_mask[np.argmax(filtermetric)] = False
            else:  # exclude if value > "mean + 1.96 * stdev" (0.95%conf)
                approx_95_conf = np.mean(filtermetric) + 1.96 * np.std(filtermetric)
                keep_mask[filtermetric > approx_95_conf] = False

            multi_metrics_["bridge"] = [
                run for run, keep in zip(bridge_metrics, keep_mask) if keep
            ]
        except Exception as e:
            # deliberate best-effort: on any problem, plot all runs unfiltered
            print(f"Plotting: error during filtering, skipping. Error: {e}")

    # log axis cannot show 0, so the band is clipped at 1 there
    lower_clip = 1 if big_search_space else 0
    draw_mean = plt.semilogy if big_search_space else plt.plot

    # longest plotted curve; taken from "pi_set_sizes" so it matches the x-axis of the curves
    max_x = 0

    # for each seed, calculate the number of policies in Pi_t at that iteration
    for name, expt in multi_metrics_.items():
        if name in scalar_keys:
            continue
        if not expt:
            # e.g. outlier filtering removed the only run of this experiment
            print(f"Plotting: no runs for experiment '{name}', skipping.")
            continue

        x_axis = np.arange(1, len(expt[0]["pi_set_sizes"]) + 1)
        max_x = max(max_x, len(x_axis))

        pi_set_sizes = np.array([np.array(run["pi_set_sizes"]) for run in expt])
        avg_pi_set_sizes = np.mean(pi_set_sizes, axis=0)  # ax0 is seeds
        confidence_interval = 1.96 * np.std(pi_set_sizes, axis=0) / np.sqrt(len(pi_set_sizes))

        draw_mean(x_axis, avg_pi_set_sizes, label=legend_names[name], color=colors[name])
        plt.fill_between(
            x_axis,
            np.clip(avg_pi_set_sizes - confidence_interval, lower_clip, upper_clip),
            np.clip(avg_pi_set_sizes + confidence_interval, lower_clip, upper_clip),
            alpha=0.2,
            color=colors[name],
        )

    # single reference line at |Pi_t| = 1 (only on the log axis), drawn once for all experiments
    if big_search_space and max_x > 0:
        plt.plot([1, max_x], [1, 1], color="black", linestyle=":")

    handles, legend_labels = plt.gca().get_legend_handles_labels()
    # sort to put 'Online PbRL' first, then others
    order = []
    for i, label in enumerate(legend_labels):
        if "Online PbRL" in label:
            order.insert(0, i)
        else:
            order.append(i)

    plt.legend(
        [handles[i] for i in order], [legend_labels[i] for i in order], loc="upper right"
    )

    ax = plt.gca()
    ax.set_xlabel("Iteration")
    ax.set_ylabel(r"Number of policies in $\Pi_t$")
    if save_name:
        plt.savefig(save_name, bbox_inches="tight")
    else:
        plt.show()


def _stack_run_metric(runs, key):
    """Stack the sequence stored under `key` in every run dict into one (runs x ...) array."""
    return np.array([np.array(run[key]) for run in runs])


def plot_mujoco_multimetrics(
    multi_metrics, save_name=None, paper_style=False, exclude_outliers=False
):
    """Plots the metrics of the mujoco experiments

    Creates a 6x2 grid of subplots; each subplot shows, per experiment, the mean of one logged
    metric over runs with a 1.96 * std / sqrt(n_runs) band.

    multi_metrics: dict mapping experiment name ("purely_online", "baseline" or "bridge") to a
      list of per-run dicts. The top-level keys "avg_expert_reward" and "avg_bc_reward" are
      skipped. The input is deep-copied, never modified. Per-run keys read here:
      "pi_set_sizes" (only its length, for the x-axis), "norm_delta_w", "norm_w", "uncertainty",
      "var_exp_return_k_topscorers", "logdet_V", "trace_V_inv", "loop_time",
      "spearman_corr_all", "spearman_corr_eval", "expert_reward_rank", "expert_score_rank",
      "highest_rew_score_rank".
    save_name: if truthy, the figure is saved there; otherwise it is shown.
    paper_style: if True, sets the global matplotlib style/rcParams (incl. usetex).

    Use 'exclude_outliers' to exclude certain BRIDGE runs from the plot:
      "worst_cumregret": the run with the largest cumulative regret
      "95conf_cumregret": any run whose cumulative regret exceeds mean + 1.96 * std
      "worst_bcexpertdist" / "95conf_bcexpertdist": same, w.r.t. the first "dist_bc_expert" entry
    If filtering fails for any reason a message is printed and all runs are kept."""
    if paper_style:
        plt.style.use("seaborn-v0_8-whitegrid")
        plt.rcParams["font.family"] = "Helvetica"
        plt.rcParams["font.size"] = 12
        plt.rcParams["axes.linewidth"] = 0.8
        plt.rcParams["axes.edgecolor"] = "#333333"

        # IF ERRORS THROWN: COMMENT THIS OUT
        plt.rcParams["text.usetex"] = True

    colors = {
        "purely_online": "#7851A9",  # "#007BA7",
        "baseline": "#7851A9",  # cont. experiments use different key
        # "traj50": "#2A9D8F",
        "bridge": "#2A9D8F",
    }
    labels = {
        "purely_online": "Online PbRL",
        "baseline": "Online PbRL",
        "bridge": "BRIDGE",
    }
    # these entries are just scalars and don't belong to a single run
    scalar_keys = ("avg_expert_reward", "avg_bc_reward")

    multi_metrics_ = copy.deepcopy(multi_metrics)

    # exclude outlier runs from plotting
    if exclude_outliers in (
        "worst_bcexpertdist",
        "worst_cumregret",
        "95conf_bcexpertdist",
        "95conf_cumregret",
    ):
        try:
            bridge_metrics = multi_metrics_["bridge"]
            if exclude_outliers in ("worst_bcexpertdist", "95conf_bcexpertdist"):
                filtermetric = [run["dist_bc_expert"][0] for run in bridge_metrics]
            else:
                filtermetric = [np.sum(run["regrets"]) for run in bridge_metrics]
            filtermetric = np.asarray(filtermetric)

            # build index mask of which seeds to keep (True = keep)
            keep_mask = np.ones(len(bridge_metrics), dtype=bool)
            if exclude_outliers in ("worst_bcexpertdist", "worst_cumregret"):
                keep_mask[np.argmax(filtermetric)] = False
            else:  # exclude if value > "mean + 1.96 * stdev" (0.95%conf)
                approx_95_conf = np.mean(filtermetric) + 1.96 * np.std(filtermetric)
                keep_mask[filtermetric > approx_95_conf] = False

            multi_metrics_["bridge"] = [
                run for run, keep in zip(bridge_metrics, keep_mask) if keep
            ]
        except Exception as e:
            # deliberate best-effort: on any problem, plot all runs unfiltered
            print(f"Plotting: error during filtering, skipping. Error: {e}")

    _, axes = plt.subplots(6, 2, figsize=(12, 20))
    axes = axes.flatten()  # flatten to make indexing easier

    # for each experiment: plot all metrics, one metric per subplot
    for name, expt in multi_metrics_.items():
        if name in scalar_keys:
            continue
        if not expt:
            # e.g. outlier filtering removed the only run of this experiment
            print(f"Plotting: no runs for experiment '{name}', skipping.")
            continue

        x_axis = np.arange(1, len(expt[0]["pi_set_sizes"]) + 1)

        expert_reward_ranks = _stack_run_metric(expt, "expert_reward_rank")
        expert_score_ranks = _stack_run_metric(expt, "expert_score_rank")
        expert_rankings_mismatch = expert_reward_ranks - expert_score_ranks
        highest_rew_score_ranks = _stack_run_metric(expt, "highest_rew_score_rank")
        # all-zero filler so that the 8th subplot is occupied
        # NOTE(review): it is drawn like a regular metric (flat line at 0) -- confirm this is intended
        placeholder = np.zeros((2, len(x_axis)))

        # (data of shape seeds x iterations, y-label, internal name); order = subplot order
        metrics_data = [
            (
                _stack_run_metric(expt, "norm_delta_w"),
                r"$\|\Delta w\| (\rightarrow 0?)$",
                "norm_delta_ws",
            ),
            (_stack_run_metric(expt, "norm_w"), r"$\|w\|$ (stabilizes?)", "norm_ws"),
            (
                _stack_run_metric(expt, "uncertainty"),
                r"Uncertainty ($\rightarrow 0$?)",
                "uncertainties",
            ),
            (
                _stack_run_metric(expt, "var_exp_return_k_topscorers"),
                r"$\mathrm{Var}(\mathrm{Exp.\ return\ of\ top\ 10\ scorers})$ (stabilizes?)",
                "var_exp_return_k_topscorers",
            ),
            (
                _stack_run_metric(expt, "logdet_V"),
                r"$\log\det(V)$ (mon. incr. \& stabilizes?)",
                "logdet_Vs",
            ),
            (
                _stack_run_metric(expt, "trace_V_inv"),
                r"$\mathrm{tr}(V^{-1})$ ($\rightarrow 0$?)",
                "trace_V_invs",
            ),
            (_stack_run_metric(expt, "loop_time"), "Online loop time", "loop_times"),
            (placeholder, "PLACEHOLDER", "placeholder"),
            (
                _stack_run_metric(expt, "spearman_corr_all"),
                r"Spearman correlation all candidates ($\rightarrow 1$?)",
                "spearman_corrs_all",
            ),
            (
                _stack_run_metric(expt, "spearman_corr_eval"),
                r"Spearman correlation eval set ($\rightarrow 1$?)",
                "spearman_corrs_eval",
            ),
            (
                expert_rankings_mismatch,
                r"Expert ranking, 'true ($E[R]$) - model' ($\rightarrow 0$?)",
                "expert_rankings",
            ),
            (
                -highest_rew_score_ranks,
                r"Empirical best $\pi$ ranking by model ($\rightarrow 0$?)",
                "highest_rew_score_ranks",
            ),
        ]

        for ax, (metric_data, ylabel, metric_name) in zip(axes, metrics_data):
            avg_metric = np.mean(metric_data, axis=0)  # ax0 is seeds
            confidence_interval = 1.96 * np.std(metric_data, axis=0) / np.sqrt(len(metric_data))

            if metric_name == "uncertainties":
                # The first iteration's value is left out of the curve.
                # Plot invisible point at x_axis[0] to force xlim
                ax.plot(x_axis[0], avg_metric[1], alpha=0, markersize=0)
                ax.plot(x_axis[1:], avg_metric[1:], label=labels[name], color=colors[name])
                ax.fill_between(
                    x_axis[1:],
                    avg_metric[1:] - confidence_interval[1:],
                    avg_metric[1:] + confidence_interval[1:],
                    alpha=0.2,
                    color=colors[name],
                )
            else:
                ax.plot(x_axis, avg_metric, label=labels[name], color=colors[name])
                ax.fill_between(
                    x_axis,
                    avg_metric - confidence_interval,
                    avg_metric + confidence_interval,
                    alpha=0.2,
                    color=colors[name],
                )

            # reference lines for the target values
            if metric_name in ("spearman_corrs_all", "spearman_corrs_eval"):
                ax.axhline(y=1, color="r", linestyle="--")
                ax.axhline(y=0, color="black", linestyle="--", alpha=0.5)
            elif metric_name in ("expert_rankings", "highest_rew_score_ranks"):
                ax.axhline(y=0, color="black", linestyle="--", alpha=0.5)

            ax.set_xlabel("Iteration")
            ax.set_ylabel(ylabel)
            ax.legend()

    plt.tight_layout()
    if save_name:
        plt.savefig(save_name, bbox_inches="tight")
    else:
        plt.show()


# def get_mle_policy_avg_reward(params, N_seeds=10):
#     if params["env"] == "StarMDP_with_random_flinging":
#         env_true = StarMDP_with_random_flinging(
#             discount_factor=0.99,
#             episode_length=params["episode_length"],
#             move_prob=params["env_move_prob"],
#         )
#         solution_pi_true = env_true.get_lp_solution()
#     elif params["env"] == "Gridworld":
#         env_true = Gridworld(
#             width=4,
#             height=4,
#             episode_length=params["episode_length"],
#             discount_factor=0.99,
#             random_action_prob=1 - params["env_move_prob"],
#         )
#         solution_pi_true = env_true.get_lp_solution()
#     else:
#         raise ValueError(f"Environment {params['env']} not supported")
#     mle_rewards_per_seed = []
#     opt_rewards_per_seed = []
#     mle_equals_opt_count = 0
#     for seed in range(N_seeds):
#         offline_trajs, __ = generate_offline_trajectories(
#             env_true, solution_pi_true, n_samples=params["N_offline_trajs"]
#         )
#         mle_policy = train_tabular_BC_policy(
#             offline_trajs,
#             env_true.N_states,
#             env_true.N_actions,
#             init="random",
#             n_epochs=10,
#             lr=0.01,
#             make_deterministic=True,
#         )
#         mle_policy.matrix = mle_policy.matrix.detach().numpy()
#         # calculate average reward of mle and optimal policies on 10k trajectories
#         mle_rewards = []
#         opt_rewards = []
#         for _ in range(10000):
#             traj = rollout_policy_in_env(env_true, mle_policy)
#             traj_reward = compute_rewards_traj(traj, env_true.rewards, env_true.discount_factor)
#             mle_rewards.append(traj_reward)
#             opt_traj = rollout_policy_in_env(env_true, solution_pi_true)
#             opt_traj_reward = compute_rewards_traj(
#                 opt_traj, env_true.rewards, env_true.discount_factor
#             )
#             opt_rewards.append(opt_traj_reward)
#         # if BC policy is = optimal policy
#         if np.allclose(mle_policy.matrix, solution_pi_true.matrix):
#             mle_equals_opt_count += 1
#         mle_rewards_per_seed.append(np.mean(mle_rewards))
#         opt_rewards_per_seed.append(np.mean(opt_rewards))
#     mean_mle_rewards = np.mean(mle_rewards_per_seed)
#     mean_opt_rewards = np.mean(opt_rewards_per_seed)
#     if mle_equals_opt_count == N_seeds:
#         print("MLE policy is = optimal policy on all seeds, returning None for the MLE avg reward")
#         return None, mean_opt_rewards
#     else:
#         print(f"MLE policy is = optimal policy on {mle_equals_opt_count} seeds")
#         return mean_mle_rewards, mean_opt_rewards
