import json
import os
import pickle
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional

import gymnasium as gym
import hydra
import numpy as np
import ray
import torch
from tqdm import tqdm
import yaml
from omegaconf import DictConfig
from ribs.archives import GridArchive
from vendi_score import vendi

import os
import matplotlib.pyplot as plt
import seaborn as sns

import gc
from src.algorithms.diayn import SACPolicy
from src.embeddings.rff import RFF
from src.qd.wrappers import (
    AntBehavioralWrapper,
    BipedalBehavioralWrapper,
    HalfCheetahBehavioralWrapper,
    HopperBehavioralWrapper,
    SwimmerBehavioralWrapper,
    WalkerBehavioralWrapper,
    ANT_MEASURE_NAMES,
    SWIMMER_MEASURE_NAMES,
)
from src.utils import Trajectory, collect_trajectories

# Hand-crafted ("ground truth") behaviour measure names per environment. The number of
# names sets the dimensionality of the GridArchive built in evaluate_population.
GT_MEASURES = dict(
    [
        ("BipedalWalker-v3", ["left_contact_freq", "right_contact_freq"]),
        ("Ant-v5", ANT_MEASURE_NAMES),
        ("HalfCheetah-v5", ["back_foot_freq", "front_foot_freq"]),
        ("Hopper-v5", ["foot_contact_freq"]),
        ("Swimmer-v5", SWIMMER_MEASURE_NAMES),
        ("Walker2d-v5", ["right_contact_freq", "left_contact_freq"]),
    ]
)

# Behavioural wrapper class per environment; evaluate_population passes it to the rollout
# workers together with the matching GT_MEASURES entry.
ENV_WRAPPERS = dict(
    [
        ("BipedalWalker-v3", BipedalBehavioralWrapper),
        ("Ant-v5", AntBehavioralWrapper),
        ("HalfCheetah-v5", HalfCheetahBehavioralWrapper),
        ("Hopper-v5", HopperBehavioralWrapper),
        ("Swimmer-v5", SwimmerBehavioralWrapper),
        ("Walker2d-v5", WalkerBehavioralWrapper),
    ]
)


#########################
##### Visualization #####
#########################


def compute_and_save_stats(metrics, output_dir="plots"):
    """
    Compute mean and standard error of every numeric metric, per environment and algorithm.

    ``metrics[env][algo]`` is a list of per-run metric dicts. The keys of the first run decide
    which metrics are summarised. Two kinds of keys are skipped:

    - keys in ``_SKIP_KEYS``;
    - keys whose first-run value is not a real number (nested dicts, strings, ...).

    The standard error uses the sample standard deviation (``ddof=1``) and is 0.0 when there
    is only one run. Algorithms with an empty run list are omitted.

    The nested result ``stats[env][algo] = {"mean": {...}, "sem": {...}}`` is written to
    ``output_dir/metrics_summary.json`` and returned.
    """
    _SKIP_KEYS = {"vendi_scores", "weighted_vendi_scores"}
    # bool is a subclass of int, so boolean flags count as numeric too.
    _NUMERIC_TYPES = (int, float, np.integer, np.floating)

    stats = {}
    for env, env_data in metrics.items():
        stats[env] = {}
        for algo, runs in env_data.items():
            if not runs:
                continue
            n = len(runs)
            mean_dict = {}
            sem_dict = {}
            for key, first_value in runs[0].items():
                if key in _SKIP_KEYS or not isinstance(first_value, _NUMERIC_TYPES):
                    continue
                vals = np.array([run[key] for run in runs], dtype=float)
                mean_dict[key] = float(vals.mean())
                sem_dict[key] = float(vals.std(ddof=1) / np.sqrt(n)) if n > 1 else 0.0
            stats[env][algo] = {"mean": mean_dict, "sem": sem_dict}

    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, "metrics_summary.json")
    with open(summary_path, "w") as f:
        json.dump(stats, f, indent=4)

    return stats


def plot_metrics(metrics, output_dir="plots"):
    """
    Generate mean±SE bar plots per environment and save them as PNGs in ``output_dir``.

    Statistics come from ``compute_and_save_stats`` (which also writes
    ``metrics_summary.json`` and skips non-numeric keys). For each environment this writes:

    - ``{env}_key_metrics.png``: one figure with a panel per key metric;
    - ``{env}_{metric}.png``: one figure per metric in ``other_metrics``.

    Bars are coloured consistently per algorithm across environments.
    """
    stats = compute_and_save_stats(metrics, output_dir)

    sns.set_theme(style="whitegrid")
    algorithms = sorted({algo for env_data in metrics.values() for algo in env_data})
    palette = dict(zip(algorithms, sns.color_palette("viridis", len(algorithms))))

    key_metrics = ["gt_qd_score", "weighted_vendi_score", "act_weighted_vendi_score"]
    other_metrics = [
        "norm_mean_objective",
        "median_vendi_score",
        "median_act_vendi_score",
        "max_objective",
        "gt_size",
    ]

    def _bar(ax, env_stats, algo_names, metric):
        # Draw one mean±SE bar per algorithm for ``metric`` onto ``ax``.
        means = [env_stats[a]["mean"][metric] for a in algo_names]
        errs = [env_stats[a]["sem"][metric] for a in algo_names]
        ax.bar(
            algo_names,
            means,
            yerr=errs,
            capsize=5,
            color=[palette[a] for a in algo_names],
        )
        ax.tick_params(axis="x", rotation=45)

    for env in sorted(metrics):
        env_stats = stats[env]
        # Use the algorithms that actually have statistics for this environment.
        algo_names = sorted(env_stats)

        # First figure: one panel per key metric.
        fig, axs = plt.subplots(
            1, len(key_metrics), figsize=(5 * len(key_metrics), 5), squeeze=False
        )
        for ax, metric in zip(axs[0], key_metrics):
            _bar(ax, env_stats, algo_names, metric)
            ax.set_title(metric.replace("_", " ").title(), fontsize=12)
            ax.set_xlabel("Algorithm", fontsize=10)
            ax.set_ylabel("Score", fontsize=10)

        fig.suptitle(
            f"Key Metrics (mean±SE) for Environment: {env}", fontsize=14, y=1.05
        )
        fig.tight_layout()
        # The suptitle sits above the figure (y > 1); bbox_inches="tight" enlarges the
        # saved canvas so it is not clipped.
        fig.savefig(
            os.path.join(output_dir, f"{env}_key_metrics.png"),
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

        # Second: one figure per remaining metric.
        for metric in other_metrics:
            fig, ax = plt.subplots(figsize=(8, 6))
            _bar(ax, env_stats, algo_names, metric)
            ax.set_title(
                f"{metric.replace('_', ' ').title()} (mean±SE)\nEnvironment: {env}",
                fontsize=14,
            )
            ax.set_xlabel("Algorithm", fontsize=12)
            ax.set_ylabel("Value", fontsize=12)
            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, f"{env}_{metric}.png"), dpi=300)
            plt.close(fig)

    print(f"Plots + metrics_summary.json saved in '{output_dir}/'.")


def _plot_weighted_vendi_grid(all_metrics, save_path, log_x):
    """
    Plot mean±SE weighted Vendi score against gamma, one subplot per environment, and save it.

    ``all_metrics[env][algo]`` is a list of runs, each holding a ``"weighted_vendi_scores"``
    dict that maps gamma to score. Gamma keys may be floats or their string form (e.g. after a
    JSON round trip); they are ordered numerically. The error bar is the standard error over
    runs (0 when there is a single run). The grid has three columns and at least two rows,
    growing when more than six environments are given. The figure is closed after saving.

    Args:
        all_metrics: Nested metrics dict as described above.
        save_path: Output image path.
        log_x: If True, use a logarithmic gamma axis.
    """
    algorithm_names = sorted(
        {algo for env_data in all_metrics.values() for algo in env_data}
    )
    # Consistent colour per algorithm across all subplots.
    cmap = plt.get_cmap("tab10")
    colors = {algo: cmap(i % 10) for i, algo in enumerate(algorithm_names)}

    num_envs = len(all_metrics)
    ncols = 3
    nrows = max(2, -(-num_envs // ncols))  # ceil division, never fewer than 2 rows
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows))
    axes = axes.flatten()

    for ax, env in zip(axes, sorted(all_metrics)):
        for algo in algorithm_names:
            runs = all_metrics[env].get(algo)
            if not runs:
                continue
            gammas = sorted(runs[0]["weighted_vendi_scores"], key=float)
            scores = np.array(
                [[run["weighted_vendi_scores"][g] for g in gammas] for run in runs],
                dtype=float,
            )
            # Mean and SEM across runs (axis 0); ddof=1 needs at least two runs.
            means = scores.mean(axis=0)
            if scores.shape[0] > 1:
                sems = scores.std(axis=0, ddof=1) / np.sqrt(scores.shape[0])
            else:
                sems = np.zeros_like(means)
            ax.errorbar(
                [float(g) for g in gammas],
                means,
                yerr=sems,
                label=algo,
                color=colors[algo],
                linewidth=2,
                marker="o",
                capsize=4,
            )

        if log_x:
            ax.set_xscale("log")
            ax.set_xlabel("Gamma (log scale)", fontsize=12)
            ax.grid(True, which="both", ls="--", lw=0.5)
        else:
            ax.set_xlabel("Gamma", fontsize=12)
            ax.grid(True)
        ax.set_title(f"Environment {env}", fontsize=14)
        ax.set_ylabel("Weighted Vendi Score", fontsize=12)
        # Avoid matplotlib's "No artists with labels" warning on empty subplots.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=10, loc="best")

    for extra_ax in axes[num_envs:]:
        fig.delaxes(extra_ax)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    # close (not just clear) so the figure is released by pyplot
    plt.close(fig)


def plot_weighted_vendi_scores(all_metrics, save_path="weighted_vendi_scores.png"):
    """
    Plot weighted Vendi score against gamma (linear axis) for all algorithms and environments.

    Each subplot shows mean±SE over runs per algorithm; the figure is saved to ``save_path``.
    """
    _plot_weighted_vendi_grid(all_metrics, save_path, log_x=False)
    print(f"Plot saved as {save_path}")


def plot_weighted_vendi_scores_log(
    all_metrics, save_path="weighted_vendi_scores_log.png"
):
    """
    Plot weighted Vendi score against gamma on a logarithmic x-axis, one subplot per environment.

    Parameters:
        all_metrics (dict):
            Nested dict: ``all_metrics[env_id][algo_name]`` is a list of runs, each with a
            ``"weighted_vendi_scores"`` dict mapping gamma to score. Each algorithm is drawn
            as mean±SE over its runs.

        save_path (str):
            Path (including filename) of the high-resolution PNG to write.
            Default is "weighted_vendi_scores_log.png".
    """
    _plot_weighted_vendi_grid(all_metrics, save_path, log_x=True)
    print(f"Log-scaled plot saved as {save_path}")


##############################
##### Evaluation Utility #####
##############################


def agent_from_DictConfig(cfg: DictConfig):
    """Build and return the object described by ``cfg`` via ``hydra.utils.instantiate``."""
    agent = hydra.utils.instantiate(cfg)
    return agent


class SAC_Agent_Wrapper:
    """Adapter giving a skill-conditioned SAC policy the ``from_numpy`` / ``act`` interface.

    A single policy serves all skills. ``from_numpy`` selects a skill, and ``act`` appends that
    skill's one-hot vector to every observation before sampling an action.
    """

    def __init__(self, policy, n_skills):
        self.policy = policy
        self.n_skills = n_skills

    def from_numpy(self, skill_id):
        """Select skill ``skill_id`` in place and return ``self`` (allows chaining)."""
        self.skill_id = skill_id
        self.skill = np.zeros(shape=(self.n_skills,), dtype=np.float32)
        self.skill[skill_id] = 1.0
        return self

    @torch.no_grad()
    def act(self, state: np.ndarray):
        """Sample actions for ``state``.

        A 1-D ``state`` is treated as a single observation and a flat action is returned.
        Otherwise the first axis is the batch and one action row is returned per observation.
        """
        single = state.ndim == 1
        obs = np.expand_dims(state, 0) if single else state
        n_rows = obs.shape[0]
        # One copy of the one-hot skill per observation row.
        skills = np.tile(self.skill, (n_rows, 1))
        policy_input = torch.FloatTensor(np.concatenate([obs, skills], axis=-1))
        action, _, _ = self.policy.sample_action(policy_input)
        action_np = action.cpu().numpy()
        return action_np.flatten() if single else action_np


def create_sac_agent(state_dim, action_dim, hidden_dims, n_skills, state_dict):
    """Rebuild a trained skill-conditioned SAC policy and wrap it as an agent.

    Returns a ``SAC_Agent_Wrapper``; call ``from_numpy(skill_id)`` on it to select a skill.
    """
    # The wrapper feeds [state, one-hot skill] to the policy, so its input grows by n_skills.
    input_dim = state_dim + n_skills
    sac_policy = SACPolicy(input_dim, action_dim, hidden_dims)
    sac_policy.load_state_dict(state_dict)
    return SAC_Agent_Wrapper(policy=sac_policy, n_skills=n_skills)


@ray.remote
class FlexiblePolicyEvaluator:
    """Ray actor that owns a vectorized Gymnasium environment and rolls out policies on it."""

    def __init__(
        self,
        env_id: str,
        env_kwargs: Optional[Dict],
        num_envs: int,
        agent_creator: Callable,
        wrappers: Optional[List[Callable]] = None,
        measure_names: Optional[List[str]] = None,
    ):
        """Initialize policy evaluator.

        Args:
            env_id: Gymnasium environment ID
            env_kwargs: Extra keyword arguments for the environment constructor (None = none)
            num_envs: Number of parallel environments to run
            agent_creator: Zero-argument callable returning an agent; ``evaluate_policy``
                calls ``.from_numpy(params)`` on it to obtain the policy to roll out
            wrappers: None or list of wrapper classes/callables passed to the vector env
                constructor. They are always applied.
            measure_names: None or list of handcrafted measures to extract. Measures are
                only collected when both ``wrappers`` and ``measure_names`` are given.
        """
        # One thread per actor; parallelism comes from running many actors.
        torch.set_num_threads(1)
        self.agent_creator = agent_creator
        self.wrappers = list(wrappers) if wrappers else []
        self.include_measures = bool(self.wrappers and measure_names)
        self.measure_names = measure_names if self.include_measures else None

        self.envs = gym.make_vec(
            env_id,
            num_envs=num_envs,
            vectorization_mode=gym.VectorizeMode.ASYNC,
            wrappers=self.wrappers,
            # A YAML `env_kwargs:` with no value loads as None.
            **(env_kwargs or {}),
        )

    def evaluate_policy(
        self,
        policy_params: np.ndarray,
        n_trajectories: int,
    ) -> List[Trajectory]:
        """Build the policy for ``policy_params`` and collect ``n_trajectories`` rollouts.

        Measures are requested from ``collect_trajectories`` only when this evaluator
        was configured with both wrappers and measure names.
        """
        policy = self.agent_creator().from_numpy(policy_params)
        if self.include_measures:
            return collect_trajectories(
                self.envs,
                policy,
                n_trajectories,
                self.measure_names,
            )
        return collect_trajectories(self.envs, policy, n_trajectories)

    def close(self) -> None:
        """Close the vectorized environment (for ASYNC mode this stops its worker processes)."""
        self.envs.close()


#############################
##### Loading From File #####
#############################


def add_DvD_metrics(all_metrics, path):
    """Append one DvD result per environment from ``<path>/dvd_logs.json`` to ``all_metrics``.

    ``dvd_logs.json`` maps environment IDs to a metrics dict, which is appended to
    ``all_metrics[env_id]["dvd"]`` (the list is created on first use). Environments with no
    entry in the log are skipped with a warning instead of aborting the whole evaluation.

    Args:
        all_metrics: ``{env_id: {algo_name: [run_metrics, ...]}}``, updated in place.
        path: Directory (``str`` or ``Path``) containing ``dvd_logs.json``.

    Returns:
        The same ``all_metrics`` object.
    """
    with open(Path(path) / "dvd_logs.json") as f:
        dvd_logs = json.load(f)
    for env_id, env_metrics in all_metrics.items():
        if env_id not in dvd_logs:
            print(f"Warning: no DvD results for {env_id} in {path}; skipping")
            continue
        env_metrics.setdefault("dvd", []).append(dvd_logs[env_id])
    return all_metrics


def load_or_create_embedding_map(env_id, state_dim, action_dim):
    """Return the RFF embedding map for ``env_id``, creating and caching it on first use.

    The map is cached at ``./evaluation_data/embedding_map/<env_id>.pkl`` so every population
    of an environment is embedded with the same map.

    A new map has:
    - 400 features;
    - ``training`` set to False;
    - normalizer mean/std loaded from ``./evaluation_data/<env_id>_mean_std.pkl``.

    NOTE: pickle files are executed on load; the evaluation_data directory must be trusted.
    """
    cache_dir = Path("./evaluation_data/embedding_map")
    cache_file = cache_dir / f"{env_id}.pkl"

    if cache_file.is_file():
        with open(cache_file, "rb") as handle:
            embedding_map = pickle.load(handle)
        print("Loaded RFF embedding map from file")
        return embedding_map

    os.makedirs(cache_dir, exist_ok=True)
    embedding_map = RFF(
        dim=400,
        state_dim=state_dim,
        action_dim=action_dim,
        kernel_width=None,
        normalize=True,
        gamma=0.999,
    )
    embedding_map.training = False

    with open(f"./evaluation_data/{env_id}_mean_std.pkl", "rb") as handle:
        norm_stats = pickle.load(handle)
    normalizer = embedding_map.normalizer
    normalizer.mean = torch.tensor(norm_stats["mean"], dtype=torch.float32)
    normalizer.std = torch.tensor(norm_stats["std"], dtype=torch.float32)

    with open(cache_file, "wb") as handle:
        pickle.dump(embedding_map, handle)
    print("Initialized new embedding map")
    return embedding_map


def _read_median_gamma(env_id, embedding_kind):
    """Return ``gammas[env_id][embedding_kind]["median"]`` from evaluation_data/gammas.json."""
    with open("evaluation_data/gammas.json") as handle:
        table = json.load(handle)
    return table[env_id][embedding_kind]["median"]


def load_median_gamma(env_id):
    """Median RBF gamma for the RFF (occupancy) embeddings of ``env_id``."""
    return _read_median_gamma(env_id, "rff")


def load_median_gamma_act(env_id):
    """Median RBF gamma for the action (ACT) embeddings of ``env_id``."""
    return _read_median_gamma(env_id, "act")


def load_population(algo_name, path):
    """Load the final population of a run and a factory for its agents.

    Args:
        algo_name: One of "auto_qd", "regular_qd", "aurora", "lstm_aurora" (archive-based,
            ``checkpoints/final.pkl``) or "diayn", "smerl" (skill-based, ``checkpoints/final.pt``).
        path: Run directory (``str`` or ``Path``).

    Returns:
        ``(solutions, agent_creator)``. ``agent_creator()`` builds an agent whose
        ``from_numpy(solution)`` yields the policy for one solution. For archive-based
        algorithms, solutions are the archive's solution rows; for skill-based ones, they
        are the skill ids ``0..n_skills-1``.

    Raises:
        ValueError: if ``algo_name`` is not recognized.
    """
    ckpt_dir = os.path.join(path, "checkpoints")
    if algo_name in ("auto_qd", "regular_qd", "aurora", "lstm_aurora"):
        # NOTE: pickle executes code on load; only use on trusted checkpoints.
        with open(os.path.join(ckpt_dir, "final.pkl"), "rb") as f:
            ckpt = pickle.load(f)
        archive = ckpt["archive"]
        solutions: np.ndarray = archive.data("solution")  # N x sol_dim
        agent_cfg: DictConfig = ckpt["agent_cfg"]
        agent_creator = partial(agent_from_DictConfig, cfg=agent_cfg)
        return solutions, agent_creator
    if algo_name in ("diayn", "smerl"):
        # Load onto CPU. A checkpoint saved from a GPU would otherwise fail to load on
        # CPU-only machines, or inside Ray CPU actors that receive this state_dict.
        ckpt = torch.load(os.path.join(ckpt_dir, "final.pt"), map_location="cpu")
        agent_creator = partial(
            create_sac_agent,
            state_dim=ckpt["state_dim"],
            action_dim=ckpt["action_dim"],
            hidden_dims=ckpt["hidden_dims"],
            n_skills=ckpt["n_skills"],
            state_dict=ckpt["policy"],
        )
        solutions = np.arange(ckpt["n_skills"])
        return solutions, agent_creator
    raise ValueError(f"algo_name {algo_name} not recognized")


###################################
##### Main Evaluation Methods #####
###################################


def chunked_rbf_kernel_gpu(
    X: np.ndarray, gamma: float, chunk_size=1024, device=None
):
    """
    Compute the RBF kernel ``K[a, b] = exp(-gamma * ||x_a - x_b||^2)`` in chunks.

    Uses ``|x-y|^2 = |x|^2 + |y|^2 - 2<x, y>`` in float32 on ``device``. By default this is
    CUDA when available, else CPU.

    - Only blocks on or above the block diagonal are computed; each is mirrored into the lower
      triangle, so the result is exactly symmetric without extra n x n temporaries.
    - The diagonal is set to exactly 1. The norm expansion can leave small float32
      cancellation errors there.

    Args:
        X: ``(n, d)`` array of row vectors.
        gamma: RBF bandwidth parameter.
        chunk_size: Rows/columns per block.
        device: Torch device; ``None`` picks CUDA if available.

    Returns:
        ``(n, n)`` float32 numpy array.
    """
    n = X.shape[0]
    if n == 0:
        return np.zeros((0, 0), dtype=np.float32)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    X_t = torch.as_tensor(X, device=device, dtype=torch.float32)
    K = torch.empty((n, n), device=device, dtype=torch.float32)
    sq_norms = torch.sum(X_t * X_t, dim=1)

    for i in range(0, n, chunk_size):
        end_i = min(i + chunk_size, n)
        X_i = X_t[i:end_i]
        norms_i = sq_norms[i:end_i].unsqueeze(1)

        for j in range(i, n, chunk_size):
            end_j = min(j + chunk_size, n)
            X_j = X_t[j:end_j]
            norms_j = sq_norms[j:end_j].unsqueeze(0)
            # Clamp away small negative values produced by cancellation.
            sq_dists = torch.clamp(norms_i + norms_j - 2.0 * (X_i @ X_j.T), min=0.0)
            block = torch.exp(-gamma * sq_dists)

            if j == i:
                # The GEMM result of a diagonal block is not guaranteed to be bitwise symmetric.
                K[i:end_i, j:end_j] = 0.5 * (block + block.T)
            else:
                K[i:end_i, j:end_j] = block
                K[j:end_j, i:end_i] = block.T

    K.fill_diagonal_(1.0)
    return K.cpu().numpy()


def _collect_rollout_stats(
    population, agent_creator, env_id, env_kwargs, n_evals, state_dim, action_dim
):
    """Roll out every individual on Ray actors; return objective, GT-measure and embedding arrays."""
    if not ray.is_initialized():
        ray.init()
    cpus_per_worker = 2
    num_workers = int(max(1, (ray.cluster_resources()["CPU"]) // cpus_per_worker))
    print(f"Using {num_workers} workers each with {cpus_per_worker} CPUs")

    evaluators = [
        FlexiblePolicyEvaluator.options(num_cpus=cpus_per_worker).remote(
            env_id=env_id,
            env_kwargs=env_kwargs,
            num_envs=4,
            agent_creator=agent_creator,
            wrappers=[ENV_WRAPPERS[env_id]],
            measure_names=GT_MEASURES[env_id],
        )
        for _ in range(num_workers)
    ]
    embedding_map = load_or_create_embedding_map(env_id, state_dim, action_dim)

    objectives, measures, embeddings = [], [], []
    batch_size = 10 * num_workers
    for start in range(0, len(population), batch_size):
        batch = population[start : start + batch_size]
        # Distribute the batch round-robin over the actors.
        futures = [
            evaluators[j % num_workers].evaluate_policy.remote(individual, n_evals)
            for j, individual in enumerate(batch)
        ]
        for trajs in ray.get(futures):
            objectives.append(np.mean([t.rewards.sum() for t in trajs]))
            measures.append(np.array([t.measures for t in trajs]).mean(0))
            embeddings.append(embedding_map.embed_trajectories(trajs))

        # call garbage collector to reclaim memory
        gc.collect()

    # Returning drops the last actor handles, which lets Ray terminate the actors.
    return np.array(objectives), np.array(measures), np.array(embeddings)


def _ground_truth_archive_stats(objectives, measures, env_id, min_score):
    """Insert all policies into a GT-measure GridArchive; return (size, coverage, qd_score, obj_mean)."""
    n_measures = len(GT_MEASURES[env_id])
    # 20 cells per measure over [0, 1]; solutions are dummies since only stats are needed.
    archive = GridArchive(
        solution_dim=1,
        dims=[20] * n_measures,
        ranges=tuple((0, 1) for _ in range(n_measures)),
        qd_score_offset=min_score,
    )
    archive.add(np.zeros((len(objectives), 1)), objectives, measures)
    stats = archive.stats
    return len(archive), stats.coverage, stats.qd_score, stats.obj_mean


def _rff_vendi_scores(embeddings, env_id):
    """Vendi scores of the RFF embeddings over a fixed gamma grid plus the env's median gamma.

    Returns ``(scores keyed by gamma, score at the median gamma)``; the median-gamma score is
    also stored in the dict.
    """
    vendi_scores = {}
    for gamma in [0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]:
        sym_matrix = chunked_rbf_kernel_gpu(embeddings, gamma=gamma)
        vendi_scores[gamma] = float(vendi.score_K(sym_matrix))
        del sym_matrix

    median_gamma = load_median_gamma(env_id)
    sym_matrix = chunked_rbf_kernel_gpu(embeddings, gamma=median_gamma)
    median_vendi_score = float(vendi.score_K(sym_matrix))
    vendi_scores[median_gamma] = median_vendi_score
    return vendi_scores, median_vendi_score


def _act_vendi_score(population, agent_creator, env_id):
    """Vendi score of each policy's flattened actions on the env's stored reference states."""
    # NOTE: pickle executes code on load; evaluation_data must be trusted.
    with open(f"./evaluation_data/{env_id}.pkl", "rb") as f:
        states = pickle.load(f)  # np array of shape (n, state_dim)
    agent = agent_creator()
    act_embeddings = []
    for individual in population:
        # Use the policy returned by from_numpy, as FlexiblePolicyEvaluator.evaluate_policy
        # does. Not every agent type is guaranteed to update itself in place.
        policy = agent.from_numpy(individual)
        act_embeddings.append(policy.act(states).flatten())
    sym_matrix = chunked_rbf_kernel_gpu(
        np.array(act_embeddings), gamma=load_median_gamma_act(env_id)
    )
    return float(vendi.score_K(sym_matrix))


def evaluate_population(
    population: List,
    agent_creator: Callable,
    env_id: str,
    env_kwargs: Dict,
    n_evals: int,
    min_score: float,  # For computing QD score
    state_dim: int,
    action_dim: int,
):
    """
    Evaluate the collective quality and diversity of a population.

    1. Ground Truth Measures: each policy is rolled out ``n_evals`` times. Its mean return and
        mean hand-crafted measures place it in a GridArchive over the known ground-truth
        measures. QD score, coverage, size, and mean/max objectives describe the population
        in terms of human notions of diversity.

    2. Occupancy Measure Based Analysis: each policy's trajectories are embedded with a
        fixed RFF embedding map. The same map is used for all populations and differs from
        the one AutoQD trains with. Vendi scores of these embeddings are reported for a grid
        of kernel bandwidths and for the environment's median bandwidth.

    3. Action Based Analysis: each policy's actions on a fixed set of stored states form an
        embedding, whose Vendi score is reported at the environment's median bandwidth.

    Quality-weighted Vendi scores are derived from these metrics later, across populations.

    Raises:
        ValueError: if ``population`` is empty.
    """
    print(f"Evaluating population of size: {len(population)}")
    if len(population) == 0:
        raise ValueError("Cannot evaluate an empty population")

    # 1. Rollouts (objectives, GT measures, occupancy embeddings)
    objectives, measures, embeddings = _collect_rollout_stats(
        population, agent_creator, env_id, env_kwargs, n_evals, state_dim, action_dim
    )
    gt_size, gt_coverage, gt_qd_score, gt_mean_objective = _ground_truth_archive_stats(
        objectives, measures, env_id, min_score
    )
    max_objective = objectives.max()

    # 2. Vendi scores with RFF occupancy embeddings
    vendi_scores, median_vendi_score = _rff_vendi_scores(embeddings, env_id)

    # 3. Vendi score with ACT embeddings
    act_vendi_score = _act_vendi_score(population, agent_creator, env_id)

    return {
        "population_size": len(population),
        "mean_objective": float(objectives.mean()),
        "min_objective": float(objectives.min()),
        "max_objective": float(objectives.max()),
        "gt_coverage": float(gt_coverage),  # Diversity
        "gt_size": int(gt_size),  # Diversity
        "gt_qd_score": float(gt_qd_score),  # Quality and Diversity
        "gt_mean_objective": float(gt_mean_objective),  # Quality
        "gt_max_objective": float(max_objective),  # Quality
        "vendi_scores": vendi_scores,  # Diversity
        "median_vendi_score": median_vendi_score,
        "median_act_vendi_score": act_vendi_score,
    }


def _add_normalized_metrics(all_metrics):
    """Add the normalized objective and the quality-weighted Vendi metrics to every run, in place.

    ``norm_mean_objective`` rescales a run's ``mean_objective`` to ``[0, 1]``. The lower bound
    is a fixed per-environment floor. The upper bound is the best ``mean_objective`` of any run
    in that environment, and values below the floor are clipped to 0. The weighted scores
    multiply the corresponding Vendi scores by this value. Environments without runs are skipped.
    """
    min_scores = {
        "BipedalWalker-v3": -200,
        "HalfCheetah-v5": -100,
        "Hopper-v5": -100,
        "Swimmer-v5": 0,
        "Walker2d-v5": -100,
        "Ant-v5": -1000,
    }
    for env_id, env_metrics in all_metrics.items():
        runs = [run for algo_runs in env_metrics.values() for run in algo_runs]
        if not runs:
            # Nothing to normalize against (np.max of an empty list would raise).
            continue
        min_score = min_scores[env_id]
        obj_upper_bound = max(run["mean_objective"] for run in runs) - min_score
        for run in runs:
            if obj_upper_bound > 0:
                norm = max(
                    0.0, float((run["mean_objective"] - min_score) / obj_upper_bound)
                )
            else:
                # The best run is at or below the floor, so every run normalizes to 0.
                norm = 0.0
            run["norm_mean_objective"] = norm
            run["weighted_vendi_score"] = float(run["median_vendi_score"] * norm)
            run["act_weighted_vendi_score"] = float(run["median_act_vendi_score"] * norm)
            run["weighted_vendi_scores"] = {
                gamma: float(score * norm) for gamma, score in run["vendi_scores"].items()
            }


def main():
    """Evaluate all runs in ./1_outputs .. ./3_outputs, then save and plot the results.

    For each run number:
    - DvD results are merged from ``<run>_outputs/dvd_logs.json``.
    - Every other sub-directory whose name starts with a letter (excluding ``dvd_*``) is loaded
      through its ``config.yaml`` and evaluated with ``evaluate_population``.

    After all runs, normalized objectives and weighted Vendi scores are added. Everything is
    then written to ``evaluation_results.json`` and plotted.
    """
    all_metrics = {
        "BipedalWalker-v3": {},
        "Ant-v5": {},
        "HalfCheetah-v5": {},
        "Hopper-v5": {},
        "Swimmer-v5": {},
        "Walker2d-v5": {},
    }
    for run_number in [1, 2, 3]:
        print(f"Evaluating {run_number} set of runs...")
        outputs_path = Path(f"./{run_number}_outputs")
        all_metrics = add_DvD_metrics(all_metrics, outputs_path)
        # Sorted for a deterministic evaluation order (and JSON key order).
        for logdir_path in tqdm(sorted(outputs_path.iterdir())):
            if not logdir_path.is_dir():
                continue
            name = logdir_path.name
            # DvD policies must be evaluated separately and their results should be added
            #   to the outputs directory
            if not name[0].isalpha() or name.split("_")[0] == "dvd":
                continue

            print(f"Evaluating {name}")
            with open(logdir_path / "config.yaml", "r") as file:
                config = yaml.safe_load(file)

            env_cfg = config["env"]
            algo_name = config["algorithm"]["algo_name"]
            env_id = env_cfg["env_id"]
            population, agent_creator = load_population(
                algo_name, f"{run_number}_outputs/{name}"
            )
            metrics = evaluate_population(
                population,
                agent_creator,
                env_id,
                env_cfg.get("env_kwargs", {}),
                32,  # n_evals
                env_cfg["min_score"],
                env_cfg["state_dim"],
                env_cfg["action_dim"],
            )
            all_metrics[env_id].setdefault(algo_name, []).append(metrics)

    # all_metrics[env_id][algo_name] is a list (one entry per run) of eval metrics
    _add_normalized_metrics(all_metrics)

    with open("evaluation_results.json", "w") as f:
        json.dump(all_metrics, f, indent=4)

    plot_weighted_vendi_scores(all_metrics)
    plot_weighted_vendi_scores_log(all_metrics)
    plot_metrics(all_metrics)


if __name__ == "__main__":
    main()
