import json
import os
import pickle
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional

import gymnasium as gym
import hydra
import numpy as np
import ray
import torch
import yaml
from omegaconf import DictConfig
from ribs.archives import GridArchive
from vendi_score import vendi

import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import gc
from src.algorithms.diayn import SACPolicy
from src.embeddings.rff import RFF
from src.qd.wrappers import (
    AntBehavioralWrapper,
    BipedalBehavioralWrapper,
    HalfCheetahBehavioralWrapper,
    HopperBehavioralWrapper,
    SwimmerBehavioralWrapper,
    WalkerBehavioralWrapper,
    ANT_MEASURE_NAMES,
    SWIMMER_MEASURE_NAMES,
)
from src.utils import Trajectory, collect_trajectories

# Names of the hand-crafted ("ground truth") behavioral measures for each
# supported environment ID. evaluate_population() passes the matching list to
# the evaluators as `measure_names` and uses its length as the number of
# dimensions of the ground-truth GridArchive.
GT_MEASURES = {
    "BipedalWalker-v3": [
        "left_contact_freq",
        "right_contact_freq",
    ],
    "Ant-v5": ANT_MEASURE_NAMES,
    "HalfCheetah-v5": [
        "back_foot_freq",
        "front_foot_freq",
    ],
    "Hopper-v5": ["foot_contact_freq"],
    "Swimmer-v5": SWIMMER_MEASURE_NAMES,
    "Walker2d-v5": ["right_contact_freq", "left_contact_freq"],
}

# Behavioral wrapper class for each supported environment ID.
# evaluate_population() hands the matching class to every evaluator as the
# single entry of its `wrappers` list. Presumably the wrapper is what records
# the GT_MEASURES values during a rollout -- confirm in src.qd.wrappers.
ENV_WRAPPERS = {
    "BipedalWalker-v3": BipedalBehavioralWrapper,
    "Ant-v5": AntBehavioralWrapper,
    "HalfCheetah-v5": HalfCheetahBehavioralWrapper,
    "Hopper-v5": HopperBehavioralWrapper,
    "Swimmer-v5": SwimmerBehavioralWrapper,
    "Walker2d-v5": WalkerBehavioralWrapper,
}


#########################
##### Visualization #####
#########################


def plot_metrics(metrics, output_dir="plots"):
    """
    Save bar-chart comparisons of the evaluated algorithms as 300-dpi PNG files.

    For every environment in ``metrics`` that has at least one algorithm:
      - ``{env}_key_metrics.png``: one figure with three side-by-side subplots for
        'gt_qd_score', 'weighted_vendi_score' and 'act_weighted_vendi_score';
      - ``{env}_{metric}.png``: one figure per metric in
        'norm_mean_objective', 'median_vendi_score', 'median_act_vendi_score',
        'max_objective' and 'gt_size'.

    Each algorithm keeps the same color in every figure. Environments without
    any algorithm entry are skipped with a message instead of raising.

    Args:
      metrics (dict): Nested dictionary ``metrics[env][algorithm][metric]``. Every
                      algorithm entry must contain all metrics listed above.
      output_dir (str): Directory where the plots are saved (created if missing).
                        Defaults to "plots".

    Raises:
      KeyError: If an algorithm entry lacks one of the key metrics.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    algorithms = sorted({algo for env_data in metrics.values() for algo in env_data})
    if not algorithms:
        # Nothing to draw; also avoids requesting a zero-color palette.
        return

    # Set a modern Seaborn style for all plots
    sns.set_theme(style="whitegrid")

    # One fixed color per algorithm, shared by every figure.
    palette = sns.color_palette("viridis", len(algorithms))
    color_mapping = dict(zip(algorithms, palette))

    key_metrics = [
        "gt_qd_score",
        "weighted_vendi_score",
        "act_weighted_vendi_score",
    ]
    detail_metrics = [
        "norm_mean_objective",
        "median_vendi_score",
        "median_act_vendi_score",
        "max_objective",
        "gt_size",
    ]

    for env, env_data in metrics.items():
        if not env_data:
            print(f"No algorithm results for {env}; skipping plots")
            continue

        # Sort algorithm names so the bar order is stable across figures.
        algo_names = sorted(env_data)
        bar_colors = [color_mapping[algo] for algo in algo_names]

        ## First figure: one subplot per key metric
        fig, axs = plt.subplots(1, len(key_metrics), figsize=(5 * len(key_metrics), 5))
        for ax, metric in zip(axs, key_metrics):
            values = [env_data[algo][metric] for algo in algo_names]
            # Plain matplotlib bars so the colors can be set explicitly.
            ax.bar(algo_names, values, color=bar_colors)
            ax.set_title(metric.replace("_", " ").title(), fontsize=12)
            ax.set_xlabel("Algorithm", fontsize=10)
            ax.set_ylabel("Score", fontsize=10)
            ax.tick_params(axis="x", rotation=45)

        fig.suptitle(
            f"Key Metrics Comparison for Environment: {env}", fontsize=14, y=1.05
        )
        fig.tight_layout()

        key_metrics_file = os.path.join(output_dir, f"{env}_key_metrics.png")
        # The suptitle sits above the canvas (y=1.05); a tight bounding box
        # keeps it from being cropped out of the saved image.
        fig.savefig(key_metrics_file, dpi=300, bbox_inches="tight")
        plt.close(fig)

        ## Second set of figures: one plot per broader metric
        df = (
            pd.DataFrame.from_dict(env_data, orient="index")
            .reset_index()
            .rename(columns={"index": "Algorithm"})
        )

        for metric in detail_metrics:
            fig, ax = plt.subplots(figsize=(8, 6))
            # Passing `palette` without `hue` is deprecated in recent seaborn;
            # mapping hue to the x variable keeps the same per-bar colors.
            sns.barplot(
                x="Algorithm",
                y=metric,
                hue="Algorithm",
                data=df,
                order=algo_names,
                hue_order=algo_names,
                palette=color_mapping,
                dodge=False,
                ax=ax,
            )
            # The hue legend only repeats the x tick labels; drop it if drawn.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

            ax.set_title(
                f"{metric.replace('_', ' ').title()} Comparison for Environment: {env}",
                fontsize=14,
            )
            ax.set_xlabel("Algorithm", fontsize=12)
            ax.set_ylabel("Value", fontsize=12)
            ax.tick_params(axis="x", rotation=45)
            fig.tight_layout()

            metric_file = os.path.join(output_dir, f"{env}_{metric}.png")
            fig.savefig(metric_file, dpi=300)
            plt.close(fig)


def _plot_weighted_vendi_grid(all_metrics, save_path, log_x):
    """Draw one weighted-Vendi-vs-gamma subplot per environment and save the figure.

    Shared implementation of plot_weighted_vendi_scores() and
    plot_weighted_vendi_scores_log().

    Args:
        all_metrics (dict): ``all_metrics[env_id][algo_name]["weighted_vendi_scores"]``
            maps gamma (number or numeric string) to a score.
        save_path (str): Output path of the PNG.
        log_x (bool): Use a logarithmic x-axis with a dashed minor/major grid.
    """
    # Collect all algorithm names across environments for a consistent color mapping
    algorithm_names = sorted(
        {algo for env_data in all_metrics.values() for algo in env_data}
    )
    cmap = plt.get_cmap("tab10")
    colors = {algo: cmap(i % 10) for i, algo in enumerate(algorithm_names)}

    # Size the grid from the number of environments (three columns) instead of
    # assuming exactly six; six environments still give the 2x3, 18x12 layout.
    env_ids = sorted(all_metrics.keys())
    ncols = 3
    nrows = max(1, (len(env_ids) + ncols - 1) // ncols)
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False
    )
    axes = axes.flatten()  # for easier iteration

    for ax, env in zip(axes, env_ids):
        for algo in algorithm_names:
            if algo not in all_metrics[env]:
                continue
            scores_dict = all_metrics[env][algo]["weighted_vendi_scores"]
            # Keys may be strings (e.g. after a JSON round trip): sort numerically.
            gammas = sorted(scores_dict.keys(), key=float)
            ax.plot(
                [float(g) for g in gammas],
                [scores_dict[g] for g in gammas],
                label=algo,
                color=colors[algo],
                linewidth=2,
                marker="o",
            )

        ax.set_title(f"Environment {env}", fontsize=14)
        ax.set_ylabel("Weighted Vendi Score", fontsize=12)
        if log_x:
            ax.set_xscale("log")
            ax.set_xlabel("Gamma (log scale)", fontsize=12)
            ax.grid(True, which="both", ls="--", lw=0.5)
        else:
            ax.set_xlabel("Gamma", fontsize=12)
            ax.grid(True)

        # Only build a legend when something was plotted on this axis.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=10, loc="best")

    # Remove the unused cells of the grid.
    for extra_ax in axes[len(env_ids):]:
        fig.delaxes(extra_ax)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    # Close (not just clear) the figure so it does not stay open in pyplot.
    plt.close(fig)


def plot_weighted_vendi_scores(all_metrics, save_path="weighted_vendi_scores.png"):
    """
    Plot the weighted Vendi score against gamma for all algorithms and environments.

    One subplot is drawn per environment (three per row) on a linear x-axis.

    Args:
        all_metrics (dict): Nested dictionary accessed as
            ``all_metrics[env_id][algo_name]["weighted_vendi_scores"][gamma]``.
        save_path (str): Path (including filename) of the 300-dpi PNG.
            Default is "weighted_vendi_scores.png".
    """
    _plot_weighted_vendi_grid(all_metrics, save_path, log_x=False)
    print(f"Plot saved as {save_path}")


def plot_weighted_vendi_scores_log(
    all_metrics, save_path="weighted_vendi_scores_log.png"
):
    """
    Plot the weighted Vendi score against gamma on a logarithmic x-axis.

    One subplot is drawn per environment (three per row).

    Args:
        all_metrics (dict): Nested dictionary accessed as
            ``all_metrics[env_id][algo_name]["weighted_vendi_scores"][gamma]``.
        save_path (str): Path (including filename) of the 300-dpi PNG.
            Default is "weighted_vendi_scores_log.png".
    """
    _plot_weighted_vendi_grid(all_metrics, save_path, log_x=True)
    print(f"Log-scaled plot saved as {save_path}")


##############################
##### Evaluation Utility #####
##############################


def agent_from_DictConfig(cfg: DictConfig):
    """Build an agent by passing ``cfg`` to ``hydra.utils.instantiate``."""
    agent = hydra.utils.instantiate(cfg)
    return agent


class SAC_Agent_Wrapper:
    """Expose a skill-conditioned SAC policy through ``from_numpy`` / ``act``.

    The "parameters" given to ``from_numpy`` are a skill index; ``act`` appends
    the corresponding one-hot vector to each state before querying
    ``policy.sample_action``.
    """

    def __init__(self, policy, n_skills):
        """
        Args:
            policy: Object whose ``sample_action(tensor)`` returns a 3-tuple with
                the action tensor first.
            n_skills: Length of the one-hot skill vector.
        """
        self.policy = policy
        self.n_skills = n_skills
        # No skill is selected until from_numpy() is called.
        self.skill_id = None
        self.skill = None

    def from_numpy(self, skill_id):
        """Select skill ``skill_id`` and return ``self`` so calls can be chained.

        Raises:
            IndexError: If ``skill_id`` is outside ``[-n_skills, n_skills)``.
        """
        one_hot = np.zeros(self.n_skills, dtype=np.float32)
        one_hot[skill_id] = 1.0
        # Commit only after the index proved valid.
        self.skill_id = skill_id
        self.skill = one_hot
        return self

    @torch.no_grad()
    def act(self, state: np.ndarray):
        """Sample an action for one state (1-D input) or a batch of states (2-D).

        Returns:
            A NumPy array: the policy output as-is for batched input, flattened
            for a single state.

        Raises:
            RuntimeError: If no skill has been selected with ``from_numpy``.
        """
        if self.skill is None:
            raise RuntimeError(
                "SAC_Agent_Wrapper.act() called before from_numpy(); select a skill first"
            )

        batched = state.ndim != 1
        if not batched:
            state = np.expand_dims(state, 0)

        # Repeat the one-hot skill once per state row.
        skill_vec = np.tile(self.skill, (state.shape[0], 1))
        state_skill = torch.as_tensor(
            np.concatenate([state, skill_vec], axis=-1), dtype=torch.float32
        )
        action, _, _ = self.policy.sample_action(state_skill)
        action = action.cpu().numpy()
        return action if batched else action.flatten()


def create_sac_agent(state_dim, action_dim, hidden_dims, n_skills, state_dict):
    """Rebuild a SAC policy from saved weights and wrap it for evaluation.

    Args:
        state_dim: Dimension of the raw environment state.
        action_dim: Dimension of the action.
        hidden_dims: Hidden layer sizes forwarded to ``SACPolicy``.
        n_skills: Number of skills; added to ``state_dim`` for the policy input.
        state_dict: Weights loaded into the policy with ``load_state_dict``.

    Returns:
        A ``SAC_Agent_Wrapper`` around the restored policy.
    """
    # The policy input is the state followed by the skill vector.
    policy_input_dim = state_dim + n_skills
    sac_policy = SACPolicy(policy_input_dim, action_dim, hidden_dims)
    sac_policy.load_state_dict(state_dict)
    wrapped_agent = SAC_Agent_Wrapper(sac_policy, n_skills)
    return wrapped_agent


@ray.remote
class FlexiblePolicyEvaluator:
    """Ray actor that owns a vectorized environment and rolls out policies in it."""

    def __init__(
        self,
        env_id: str,
        env_kwargs: Dict,
        num_envs: int,
        agent_creator: Callable,
        wrappers: Optional[List[gym.Wrapper]] = None,
        measure_names: Optional[List[str]] = None,
    ):
        """Create the evaluator and its asynchronous vector environment.

        Args:
            env_id: Gymnasium environment ID
            env_kwargs: Extra keyword arguments forwarded to ``gym.make_vec``
            num_envs: Number of parallel environments to run
            agent_creator: Zero-argument callable returning a fresh agent that
                offers ``from_numpy``
            wrappers: None or list of wrappers passed to the vector env constructor
            measure_names: None or list of handcrafted measures that will be extracted

        Measures are only collected when BOTH ``wrappers`` and ``measure_names``
        are given; otherwise the wrappers are dropped as well.
        """
        # Restrict torch to one thread inside this actor
        # (presumably to avoid CPU oversubscription across actors).
        torch.set_num_threads(1)
        self.agent_creator = agent_creator

        self.include_measures = bool(wrappers and measure_names)
        self.measure_names = measure_names if self.include_measures else None
        self.wrappers = wrappers if self.include_measures else []

        self.envs = gym.make_vec(
            env_id,
            num_envs=num_envs,
            vectorization_mode=gym.VectorizeMode.ASYNC,
            wrappers=self.wrappers,
            **env_kwargs,
        )

    def evaluate_policy(
        self,
        policy_params: np.ndarray,
        n_trajectories: int,
    ) -> List[Trajectory]:
        """Build an agent from ``policy_params`` and collect ``n_trajectories`` rollouts."""
        policy = self.agent_creator().from_numpy(policy_params)
        # The measure names are passed on only when measures are being recorded.
        measure_args = (self.measure_names,) if self.include_measures else ()
        return collect_trajectories(self.envs, policy, n_trajectories, *measure_args)


#############################
##### Loading From File #####
#############################


def add_DvD_metrics(all_metrics, dvd_log_path="./outputs/dvd_logs.json"):
    """Merge separately computed DvD metrics into ``all_metrics`` under the key "dvd".

    The log file is a JSON object keyed by environment ID. A missing file or a
    missing environment entry is reported and skipped, so metrics that were
    already computed are not lost to an exception here.

    Args:
        all_metrics (dict): ``all_metrics[env_id][algo_name] -> metrics``; modified
            in place.
        dvd_log_path (str): Location of the DvD log. Defaults to
            "./outputs/dvd_logs.json".

    Returns:
        The same ``all_metrics`` object.
    """
    try:
        with open(dvd_log_path) as f:
            dvd_logs = json.load(f)
    except FileNotFoundError:
        print(f"DvD log file {dvd_log_path} not found; skipping DvD metrics")
        return all_metrics

    for env_id, env_metrics in all_metrics.items():
        if env_id in dvd_logs:
            env_metrics["dvd"] = dvd_logs[env_id]
        else:
            print(f"No DvD metrics for {env_id} in {dvd_log_path}; skipping")
    return all_metrics


def load_or_create_embedding_map(env_id, state_dim, action_dim):
    """Return the RFF embedding map for ``env_id``, creating and caching it if needed.

    The map is pickled at ``./evaluation_data/embedding_map/{env_id}.pkl``. When
    that file exists it is loaded as-is (``state_dim`` / ``action_dim`` are then
    not used). Otherwise a new ``RFF`` is built, its normalizer mean/std are
    filled from ``./evaluation_data/{env_id}_mean_std.pkl``, and the result is
    pickled for later calls.
    """
    cache_dir = Path("./evaluation_data/embedding_map")
    cache_file = cache_dir / f"{env_id}.pkl"

    # Fast path: reuse the cached map.
    if cache_file.is_file():
        with open(cache_file, "rb") as fh:
            embedding_map = pickle.load(fh)
        print("Loaded RFF embedding map from file")
        return embedding_map

    os.makedirs(cache_dir, exist_ok=True)
    embedding_map = RFF(
        dim=400,
        state_dim=state_dim,
        action_dim=action_dim,
        kernel_width=None,
        normalize=True,
        gamma=0.999,
    )
    embedding_map.training = False

    # Install the precomputed normalization statistics.
    with open(f"./evaluation_data/{env_id}_mean_std.pkl", "rb") as fh:
        stats = pickle.load(fh)
    normalizer = embedding_map.normalizer
    normalizer.mean = torch.tensor(stats["mean"], dtype=torch.float32)
    normalizer.std = torch.tensor(stats["std"], dtype=torch.float32)

    with open(cache_file, "wb") as fh:
        pickle.dump(embedding_map, fh)
    print("Initialized new embedding map")
    return embedding_map


def load_median_gamma(env_id):
    """Read ``gammas[env_id]["rff"]["median"]`` from ``evaluation_data/gammas.json``."""
    with open("evaluation_data/gammas.json") as gamma_file:
        per_env = json.load(gamma_file)
    rff_entry = per_env[env_id]["rff"]
    return rff_entry["median"]


def load_median_gamma_act(env_id):
    """Read ``gammas[env_id]["act"]["median"]`` from ``evaluation_data/gammas.json``."""
    with open("evaluation_data/gammas.json") as gamma_file:
        per_env = json.load(gamma_file)
    act_entry = per_env[env_id]["act"]
    return act_entry["median"]


def load_population(algo_name, path):
    """Load the final population of a run and a factory for its agents.

    Args:
        algo_name: One of "auto_qd", "regular_qd", "aurora", "lstm_aurora"
            (archive checkpoint ``checkpoints/final.pkl``) or "diayn", "smerl"
            (torch checkpoint ``checkpoints/final.pt``).
        path: Run directory containing the ``checkpoints`` folder.

    Returns:
        ``(solutions, agent_creator)``. For archive-based algorithms ``solutions``
        is the archive's "solution" array; for skill-based algorithms it is
        ``np.arange(n_skills)``. ``agent_creator()`` returns a fresh agent whose
        ``from_numpy`` accepts one entry of ``solutions``.

    Raises:
        ValueError: If ``algo_name`` is not recognized.
    """
    checkpoint_dir = os.path.join(path, "checkpoints")

    if algo_name in ("auto_qd", "regular_qd", "aurora", "lstm_aurora"):
        # NOTE: unpickling executes arbitrary code; only load trusted checkpoints.
        with open(os.path.join(checkpoint_dir, "final.pkl"), "rb") as f:
            ckpt = pickle.load(f)
        solutions = ckpt["archive"].data("solution")  # N x sol_dim
        agent_creator = partial(agent_from_DictConfig, cfg=ckpt["agent_cfg"])
        return solutions, agent_creator

    if algo_name in ("diayn", "smerl"):
        # Map every tensor to the CPU: the wrapper feeds CPU tensors to the policy
        # and the state dict is shipped to the Ray workers, so a checkpoint saved
        # on a GPU must not require (or carry) CUDA tensors here.
        ckpt = torch.load(
            os.path.join(checkpoint_dir, "final.pt"), map_location="cpu"
        )
        agent_creator = partial(
            create_sac_agent,
            state_dim=ckpt["state_dim"],
            action_dim=ckpt["action_dim"],
            hidden_dims=ckpt["hidden_dims"],
            n_skills=ckpt["n_skills"],
            state_dict=ckpt["policy"],
        )
        # Each "solution" is simply a skill index.
        solutions = np.arange(ckpt["n_skills"])
        return solutions, agent_creator

    raise ValueError(f"algo_name {algo_name} not recognized")


###################################
##### Main Evaluation Methods #####
###################################


def chunked_rbf_kernel_gpu(X: np.ndarray, gamma: float, chunk_size=1024):
    """
    Compute the RBF kernel matrix ``K[i, j] = exp(-gamma * |x_i - x_j|^2)`` in chunks.

    Runs on CUDA when available, otherwise on the CPU, in float32.
    Note that |x-y|^2 = |x|^2 + |y|^2 - 2<x, y>

    Args:
        X: Array of shape (n, d), one sample per row.
        gamma: RBF scale factor.
        chunk_size: Number of rows/columns processed per block.

    Returns:
        A float32 NumPy array of shape (n, n) that is exactly symmetric and has
        an exact unit diagonal.
    """
    n = X.shape[0]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    X_gpu = torch.as_tensor(X, device=device, dtype=torch.float32)
    K = torch.empty((n, n), device=device, dtype=torch.float32)

    X_norms = torch.sum(X_gpu**2, dim=1)

    for i in range(0, n, chunk_size):
        end_i = min(i + chunk_size, n)
        X_i = X_gpu[i:end_i]
        norms_i = X_norms[i:end_i].unsqueeze(1)

        for j in range(0, n, chunk_size):
            end_j = min(j + chunk_size, n)
            X_j = X_gpu[j:end_j]

            dots = torch.mm(X_i, X_j.t())
            norms_j = X_norms[j:end_j].unsqueeze(0)
            sq_dists = torch.clamp(
                norms_i + norms_j - 2 * dots, min=0.0
            )  # Prevent negatives

            K[i:end_i, j:end_j] = torch.exp(-gamma * sq_dists)

    K = 0.5 * (K + K.T)  # Ensure symmetry
    if n > 0:
        # k(x, x) is exactly 1. The norm expansion leaves float32 rounding error
        # in the self-distances, which a large gamma would turn into a visibly
        # non-unit diagonal.
        K.fill_diagonal_(1.0)
    return K.cpu().numpy()


def evaluate_population(
    population: List,
    agent_creator: Callable,
    env_id: str,
    env_kwargs: Dict,
    n_evals: int,
    min_score: float,  # For computing QD score
    state_dim: int,
    action_dim: int,
) -> Dict:
    """
    Given a population, we do two types of analysis to determine the quality and diversity
    of them collectively.

    1. Ground Truth Measures: Create an archive using known ground truth measures of diversity
        and put the policies in that archive. By reporting the QD score, coverage, mean
        and max objectives we understand how good the archive is based on human notions
        of diversity.

    2. Occupancy Measure Based Analysis: We use embeddings similar to those employed by
        AutoQD to obtain an embedding for each policy. The same embedding map is used to
        compare all populations (and it is different from the one used by AutoQD to train).
        We use the Vendi score to evaluate each population based on the diversity of
        its embeddings. Only unweighted Vendi scores are computed here; the
        quality-weighted variants are derived later in main().

    Args:
        population: Policy parameters, one entry per individual; each entry is
            passed to the agent's ``from_numpy``.
        agent_creator: Zero-argument callable returning a fresh agent.
        env_id: Environment ID; must be a key of GT_MEASURES and ENV_WRAPPERS.
        env_kwargs: Extra keyword arguments for the environment constructor.
        n_evals: Number of trajectories collected per individual.
        min_score: Used as ``qd_score_offset`` of the ground-truth archive.
        state_dim: State dimension, used when a new embedding map is created.
        action_dim: Action dimension, used when a new embedding map is created.

    Returns:
        Dictionary of plain Python numbers plus "vendi_scores", a dict that maps
        each gamma (float key) to the Vendi score of the RFF embeddings.
    """
    print(f"Evaluating population of size: {len(population)}")

    # Start Ray on first use; later calls reuse the running instance.
    if not ray.is_initialized():
        ray.init()
    cpus_per_worker = 2
    # One evaluator actor per `cpus_per_worker` CPUs of the cluster, at least one.
    num_workers = int(max(1, (ray.cluster_resources()["CPU"]) // cpus_per_worker))
    print(f"Using {num_workers} workers each with {cpus_per_worker} CPUs")

    # Every actor owns a vector env with 4 sub-environments wrapped so that the
    # ground-truth measures of this environment are available.
    evaluators = [
        FlexiblePolicyEvaluator.options(num_cpus=cpus_per_worker).remote(
            env_id=env_id,
            env_kwargs=env_kwargs,
            num_envs=4,
            agent_creator=agent_creator,
            wrappers=[ENV_WRAPPERS[env_id]],
            measure_names=GT_MEASURES[env_id],
        )
        for _ in range(num_workers)
    ]

    objectives, measures, embeddings = [], [], []
    embedding_map = load_or_create_embedding_map(env_id, state_dim, action_dim)

    # Process the population in batches of 10 individuals per worker; inside a
    # batch the individuals are assigned to the workers round-robin.
    for i in range(0, len(population), 10 * num_workers):
        batch = population[i : i + 10 * num_workers]
        futures = [
            evaluators[j % num_workers].evaluate_policy.remote(individual, n_evals)
            for j, individual in enumerate(batch)
        ]
        batch_trajectories = ray.get(futures)

        for trajs in batch_trajectories:
            # Objective: mean return over the individual's trajectories.
            obj = np.mean([t.rewards.sum() for t in trajs])
            # Measures: averaged over the individual's trajectories.
            meas = np.array([t.measures for t in trajs]).mean(0)
            embedding = embedding_map.embed_trajectories(trajs)

            objectives.append(obj)
            measures.append(meas)
            embeddings.append(embedding)

        # call garbage collector to reclaim memory
        gc.collect()

    objectives = np.array(objectives)
    measures = np.array(measures)
    embeddings = np.array(embeddings)

    # 1. Calculate stats on archive with GT measures
    # The grid has 20 cells per measure over the range (0, 1). The stored
    # solutions are dummy zeros because only the archive statistics are read.
    archive = GridArchive(
        solution_dim=1,
        dims=list([20 for _ in range(len(GT_MEASURES[env_id]))]),
        ranges=tuple(((0, 1) for _ in range(len(GT_MEASURES[env_id])))),
        qd_score_offset=min_score,
    )
    archive.add(np.zeros((len(objectives), 1)), objectives, measures)
    gt_size = len(archive)
    gt_coverage = archive.stats.coverage
    gt_qd_score = archive.stats.qd_score
    gt_mean_objective = archive.stats.obj_mean
    # Maximum over ALL evaluated individuals (not read from the archive).
    max_objective = objectives.max()

    # 2. Calculate Vendi Score and qVS with RFF embeddings
    # get the kernel matrix
    # Fixed sweep of RBF gammas; one Vendi score is stored per gamma.
    GAMMAS = [0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]
    vendi_scores = {}
    for gamma in GAMMAS:
        sym_matrix = chunked_rbf_kernel_gpu(embeddings, gamma=gamma)
        rff_V = vendi.score_K(sym_matrix)  # Vendi score
        vendi_scores[gamma] = float(rff_V)
        # Drop the n x n kernel before the next one is allocated.
        del sym_matrix

    # Score at the per-environment median gamma read from gammas.json; it is
    # also added to the sweep dictionary.
    median_gamma = load_median_gamma(env_id)
    sym_matrix = chunked_rbf_kernel_gpu(embeddings, gamma=median_gamma)
    median_vendi_score = float(vendi.score_K(sym_matrix))  # Vendi score
    vendi_scores[median_gamma] = median_vendi_score
    del sym_matrix

    # 3. Calculate Vendi Score and qVS with ACT embeddings
    # Each individual is embedded as its flattened actions on a fixed set of
    # states loaded from disk.
    with open(f"./evaluation_data/{env_id}.pkl", "rb") as f:
        states = pickle.load(f)  # np array of shape (n, state_dim)
    agent = agent_creator()
    embeddings = []
    for individual in population:
        agent.from_numpy(individual)
        actions = agent.act(states)
        embeddings.append(actions.flatten())
    embeddings = np.array(embeddings)
    gamma = load_median_gamma_act(env_id)
    sym_matrix = chunked_rbf_kernel_gpu(embeddings, gamma=gamma)
    act_vendi_score = float(vendi.score_K(sym_matrix))

    # NOTE: "max_objective" and "gt_max_objective" hold the same value
    # (objectives.max()).
    metrics = {
        "population_size": len(population),
        "mean_objective": float(objectives.mean()),
        "min_objective": float(objectives.min()),
        "max_objective": float(objectives.max()),
        "gt_coverage": float(gt_coverage),  # Diversity
        "gt_size": int(gt_size),  # Diversity
        "gt_qd_score": float(gt_qd_score),  # Quality and Diversity
        "gt_mean_objective": float(gt_mean_objective),  # Quality
        "gt_max_objective": float(max_objective),  # Quality
        "vendi_scores": vendi_scores,  # Diversity
        "median_vendi_score": median_vendi_score,
        "median_act_vendi_score": act_vendi_score,
    }
    return metrics


def main():
    """Evaluate every run under ``./outputs``, derive normalized metrics, save and plot.

    Steps:
      1. For each run directory whose name starts with a letter and is not a
         ``dvd_*`` directory, read ``config.yaml``, load the final population
         and evaluate it.
      2. Merge the separately computed DvD metrics.
      3. Per environment, normalize each algorithm's mean objective to
         ``[0, 1]`` relative to the best algorithm and a fixed minimum score,
         and weight the Vendi scores by it.
      4. Write ``evaluation_results.json`` and the plots.

    Raises:
        KeyError: If a run's ``env_id`` is not one of the supported environments
            (raised before the expensive evaluation starts).
    """
    all_metrics = {
        "BipedalWalker-v3": {},
        "Ant-v5": {},
        "HalfCheetah-v5": {},
        "Hopper-v5": {},
        "Swimmer-v5": {},
        "Walker2d-v5": {},
    }
    # Lower bound of the objective used for normalization, per environment.
    norm_min_scores = {
        "BipedalWalker-v3": -200,
        "HalfCheetah-v5": -100,
        "Hopper-v5": -100,
        "Swimmer-v5": 0,
        "Walker2d-v5": -100,
        "Ant-v5": -1000,
    }

    outputs_path = Path("./outputs")
    for logdir_path in outputs_path.iterdir():
        if not logdir_path.is_dir():
            continue
        name = logdir_path.name
        # DvD policies must be evaluated separately and their results should be added
        #   to the outputs directory
        if not name[0].isalpha() or name.split("_")[0] == "dvd":
            continue

        print(f"Evaluating {name}")
        with open(f"outputs/{name}/config.yaml", "r") as file:
            config = yaml.safe_load(file)

        algo_name = config["algorithm"]["algo_name"]
        env_cfg = config["env"]
        env_id = env_cfg["env_id"]
        if env_id not in all_metrics:
            # Fail before the expensive evaluation instead of after it.
            raise KeyError(f"Unsupported env_id {env_id!r} in outputs/{name}/config.yaml")

        population, agent_creator = load_population(algo_name, f"outputs/{name}")
        all_metrics[env_id][algo_name] = evaluate_population(
            population,
            agent_creator,
            env_id,
            env_cfg.get("env_kwargs", {}),
            32,  # n_evals
            env_cfg["min_score"],
            env_cfg["state_dim"],
            env_cfg["action_dim"],
        )

    all_metrics = add_DvD_metrics(all_metrics)

    for env_id, env_metrics in all_metrics.items():
        if not env_metrics:
            print(f"No results for {env_id}; skipping normalization")
            continue

        min_score = norm_min_scores[env_id]
        max_mean_obj = max(m["mean_objective"] for m in env_metrics.values())
        score_range = max_mean_obj - min_score

        for algo_metrics in env_metrics.values():
            if score_range > 0:
                norm_obj = max(
                    0.0,
                    float((algo_metrics["mean_objective"] - min_score) / score_range),
                )
            else:
                # No algorithm beats the minimum score. Dividing by a
                # non-positive range would rank worse algorithms higher (or
                # produce inf/nan), so all of them get zero quality.
                norm_obj = 0.0

            algo_metrics["norm_mean_objective"] = norm_obj
            algo_metrics["weighted_vendi_score"] = float(
                algo_metrics["median_vendi_score"] * norm_obj
            )
            algo_metrics["act_weighted_vendi_score"] = float(
                algo_metrics["median_act_vendi_score"] * norm_obj
            )
            algo_metrics["weighted_vendi_scores"] = {
                gamma: float(score * norm_obj)
                for gamma, score in algo_metrics["vendi_scores"].items()
            }

    with open("evaluation_results.json", "w") as f:
        json.dump(all_metrics, f, indent=4)

    # Environments without results have nothing to plot and lack the derived
    # metrics the plotting functions read.
    plottable = {env_id: m for env_id, m in all_metrics.items() if m}
    if not plottable:
        print("No results to plot")
        return

    plot_weighted_vendi_scores(plottable)
    plot_weighted_vendi_scores_log(plottable)
    plot_metrics(plottable)


if __name__ == "__main__":
    main()
