import os
import json
import wandb
import numpy as np
from pathlib import Path
from tqdm import trange
from collections import defaultdict
from k_level_policy_gradients.src.utils.logger import Logger
from k_level_policy_gradients.src.utils.dataset import CollectDataset
from k_level_policy_gradients.src.utils.serialization import Serializable
from k_level_policy_gradients.src.environments.smac.smac_env import SMAC
from k_level_policy_gradients.src.environments.mamujoco.mamujoco_env import Mamujoco
from k_level_policy_gradients.src.environments.toy.attraction import Attraction
from k_level_policy_gradients.src.environments.toy.plug import Plug
from k_level_policy_gradients.src.utils.dataset import (
    compute_J_all_agents,
    smac_battles_won,
    compute_action_norms_all_agents,
    compute_episode_lengths,
)
from k_level_policy_gradients.src.agents.setup_agent import setup_agent
from k_level_policy_gradients.src.core.multi_agent_core import MultiAgentCore
from k_level_policy_gradients.src.core.multi_agent_core_mixer import MultiAgentCoreMixer
from k_level_policy_gradients.src.core.multi_agent_core_shared import (
    MultiAgentCoreShared,
)
from k_level_policy_gradients.src.core.multi_agent_core_shared_mixer import (
    MultiAgentCoreSharedMixer,
)
from k_level_policy_gradients.src.core.multi_agent_core_hidden import (
    MultiAgentCoreHidden,
)
from k_level_policy_gradients.src.core.multi_agent_core_hidden_mixer import (
    MultiAgentCoreHiddenMixer,
)
from k_level_policy_gradients.src.core.multi_agent_core_hidden_shared import (
    MultiAgentCoreHiddenShared,
)
from k_level_policy_gradients.src.core.multi_agent_core_hidden_shared_mixer import (
    MultiAgentCoreHiddenSharedMixer,
)
from k_level_policy_gradients.src.utils.config_loader import load_config


class MultiAgentExperiment:
    """
    Orchestrates a multi-agent experiment: builds the environment (MDP), the
    agents and the matching core, runs the training / evaluation loops, and
    logs and saves results under ``results_dir``.

    Supported environments (``env``): "sc2", "mamujoco", "attraction", "plug".
    """

    # Agent identifiers that are built as a mixer wrapping per-agent "host"
    # agents (the host agent type is read from ``agent_params["host_agent"]``).
    _MIXER_AGENTS = (
        "maddpg",
        "kmaddpg",
        "maddpg_discrete",
        "qmix",
        "comix",
        "facmac",
        "kfacmac",
        "facmac_continuous",
        "kfacmac_continuous",
    )

    def __init__(
        self,
        env: str = "",
        agent: str = "",
        log_wandb: bool = False,
        log_best_agents: bool = False,
        log_online_data: bool = False,
        note: str = "",
        env_params: "dict | None" = None,
        env_run_params: "dict | None" = None,
        agent_params: "dict | None" = None,
        results_dir: str = "",
        seed: int = 0,
    ):
        """
        Constructor.

        env: environment identifier ("sc2", "mamujoco", "attraction" or "plug")
        agent: agent identifier passed to ``setup_agent``
        log_wandb: whether to log metrics to Weights & Biases
        log_best_agents: whether to pass the agents and their mean return to
            ``Logger.log_best_agents_overall`` after every training iteration
        log_online_data: whether to collect per-step training data and log
            losses / action norms
        note: optional text written to ``<results_dir>/note.md``
        env_params: keyword arguments for the environment constructor
        env_run_params: run settings; must contain ``n_steps_per_fit`` and
            ``n_episodes_per_fit`` (the string "None" is converted to ``None``)
        agent_params: agent hyperparameters
        results_dir: root directory for logs and saved results
        seed: run identifier used in logger, wandb and saved file names
        """
        self.env = env
        self.agent = agent
        self.log_wandb = log_wandb
        self.log_best_agents = log_best_agents
        self.log_online_data = log_online_data
        # Copy the dicts so that neither a shared default nor the caller's
        # dicts are mutated by the "None" conversion below.
        self.env_params = dict(env_params) if env_params is not None else {}
        self.env_run_params = (
            dict(env_run_params) if env_run_params is not None else {}
        )
        self.agent_params = dict(agent_params) if agent_params is not None else {}
        self.results_dir = results_dir
        self.seed = seed

        # Config files encode "not set" as the string "None".
        for key in ("n_steps_per_fit", "n_episodes_per_fit"):
            if self.env_run_params[key] == "None":
                self.env_run_params[key] = None

        # Logging
        self.exp_logger = Logger(
            log_name="Experiment",
            results_dir=Path(self.results_dir) / "Logging",
            log_console=True,
            seed=self.seed,
            console_log_level=30,
        )
        self.log_params()
        if self.log_wandb:
            self._init_wandb()

        if note:
            note_file = Path(self.results_dir) / "note.md"
            note_file.write_text(note, encoding="utf-8")

        # Bind the environment-specific evaluation / logging routines.
        # These are assigned after log_params() so they are not logged.
        env_handlers = {
            "sc2": (
                self._evaluate_performance_sc2,
                self._evaluate_performance_sc2_online,
                self._log_data_sc2,
            ),
            "mamujoco": (
                self._evaluate_performance_mamujoco,
                self._evaluate_performance_online_mamujoco,
                self._log_data_mamujoco,
            ),
            "attraction": (
                self._evaluate_performance_attraction,
                self._evaluate_performance_online_attraction,
                self._log_data_attraction,
            ),
            "plug": (
                self._evaluate_performance_plug,
                self._evaluate_performance_online_plug,
                self._log_data_plug,
            ),
        }
        if self.env in env_handlers:
            (
                self.evaluate_performance,
                self.evaluate_performance_online,
                self.log_data,
            ) = env_handlers[self.env]

    def _init_wandb(self):
        """
        Start a wandb run grouped by the 4th-from-last component of
        ``results_dir`` with an environment-specific config.

        Raises NotImplementedError for environments without a wandb config.
        """
        dir_name = self.results_dir.split("/")[-4]
        config = {"agent": self.agent, "env": self.env}
        if self.env == "sc2":
            config["map_name"] = self.env_params["map_name"]
            run_name = f"{self.seed}"
        elif self.env == "mamujoco":
            config["scenario"] = self.env_params["scenario"]
            config["partitioning"] = self.env_params["partitioning"]
            run_name = f"{dir_name}_{self.seed}"
        elif self.env in ("attraction", "plug"):
            config["n_agents"] = self.env_params["n_agents"]
            run_name = f"{dir_name}_{self.seed}"
        else:
            raise NotImplementedError(
                f"wandb logging is not implemented for env '{self.env}'"
            )
        wandb.init(
            project="k-level-policy-gradients",
            group=dir_name,
            name=run_name,
            config=config,
        )

    def log_params(self):
        """
        Put all the parameters into the Experiment_<seed> log file and as a json file
        (``<results_dir>/Logging/data.json``).
        """
        log_params_dict = {
            key: value for key, value in self.__dict__.items() if key != "exp_logger"
        }
        msg = "\n" + "".join(f"{k}: {v}\n" for (k, v) in log_params_dict.items())
        self.exp_logger.info(msg)

        # Save as json
        json_path = Path(self.results_dir) / "Logging" / "data.json"
        with open(json_path, "w") as fp:
            json.dump(log_params_dict, fp, indent=4)

    def train_agents(self):
        """
        Train the agents for ``n_total_iterations`` iterations, evaluating and
        logging after each one, then save the collected per-iteration data and
        the agents under ``<results_dir>/Training``.

        Returns the core used for training.
        """
        print("Training agents...check logs for progress")

        core, agents, _ = self.setup_experiment()
        n_agents = core.mdp.info.n_agents

        # Initial random transitions
        initial_random_transitions = self.env_run_params.get(
            "initial_random_transitions", 0
        )
        if initial_random_transitions > 0:
            core.set_random_mode()
            core.learn(
                n_steps=initial_random_transitions,
                # Fit interval equals the number of steps: no training happens here.
                n_steps_per_fit_per_agent=[initial_random_transitions] * n_agents,
                quiet=False,
                render=self.env_run_params["bool_render"],
            )

        # Log online data
        if self.log_online_data:
            core.set_callbacks_step("online", CollectDataset())

        # Train agents
        performance_data_dict_per_iteration = defaultdict(list)
        for i in trange(self.env_run_params["n_total_iterations"], leave=False):
            core.set_training_mode()
            core.learn(
                n_steps=self.env_run_params["n_steps_per_iteration"],
                n_steps_per_fit_per_agent=self.get_n_steps_per_fit_per_agent(n_agents),
                n_episodes_per_fit_per_agent=self.get_n_episodes_per_fit_per_agent(
                    n_agents
                ),
                quiet=False,
                render=self.env_run_params["bool_render"],
            )
            # Online performance (reads the "online" callback before it is cleared)
            performance_data_dict_online = self.evaluate_performance_online(core)
            # Offline performance
            performance_data_dict = self.evaluate_performance(
                core, self.env_run_params["n_evaluation_episodes"]
            )
            # Clear core buffers
            core.clear_all_callbacks_step()

            # Log data
            self.log_data(i, performance_data_dict, performance_data_dict_online)
            for k, v in performance_data_dict.items():
                performance_data_dict_per_iteration[f"{k}_per_iteration"].append(v)
            if self.log_online_data:
                for k, v in performance_data_dict_online.items():
                    performance_data_dict_per_iteration[f"{k}_per_step"].extend(v)

            # Log best agents
            if self.log_best_agents:
                self.exp_logger.log_best_agents_overall(
                    agents,
                    performance_data_dict["mean_return_total"],
                )

        # Save data
        self.save_data(
            Path(self.results_dir) / "Training", **performance_data_dict_per_iteration
        )

        # Save agents
        self.save_agents(
            agents,
            np.arange(n_agents),
            Path(self.results_dir) / "Training",
            self.seed,
            full_save=False,
        )

        return core

    def evaluate(self, core):
        """
        Evaluate the performance of the agents in the core and save the
        resulting metrics under ``<results_dir>/Evaluation``.
        """
        performance_dict = self.evaluate_performance(
            core, self.env_run_params["n_evaluation_episodes"]
        )

        # Save data
        self.save_data(Path(self.results_dir) / "Evaluation", **performance_dict)

    # ------------------------------------------------------------------ #
    # Shared evaluation helpers
    # ------------------------------------------------------------------ #

    def _collect_evaluation_data(self, core, n_episodes):
        """
        Run ``n_episodes`` evaluation episodes in testing mode and return the
        dataset collected by a temporary "offline" CollectDataset callback.
        The core is put back into training mode afterwards.
        """
        core.set_testing_mode()
        core.set_callbacks_step("offline", CollectDataset())
        core.evaluate(n_episodes=n_episodes, render=self.env_run_params["bool_render"])
        evaluation_data = core.callbacks_step["offline"].get()
        core.remove_callbacks_step("offline")
        core.set_training_mode()
        return evaluation_data

    @staticmethod
    def _mean_returns(sample_data):
        """
        Return ``(mean_return_per_agent, mean_return_total)`` where the
        per-agent mean is taken over axis 1 of ``compute_J_all_agents``.
        """
        return_per_episode_per_agent = compute_J_all_agents(sample_data)
        mean_return_per_agent = np.mean(return_per_episode_per_agent, axis=1)
        return mean_return_per_agent, np.mean(mean_return_per_agent)

    # ------------------------------------------------------------------ #
    # SMAC (sc2)
    # ------------------------------------------------------------------ #

    def _evaluate_performance_sc2(self, core, n_episodes):
        """
        Evaluate the performance of the agents

        n_episodes: number of episodes to evaluate each agent for
        core: core to use for evaluation (should already contain the agents and mdp)

        Returns mean returns and the fraction of battles won (0.0 when no
        battles were recorded).
        """
        evaluation_data = self._collect_evaluation_data(core, n_episodes)
        sample_data = [datapoint["sample"] for datapoint in evaluation_data]
        info_data = [datapoint["info"] for datapoint in evaluation_data]

        # Returns
        mean_return_per_agent, mean_return_total = self._mean_returns(sample_data)

        # Success rate (guard against division by zero on an empty evaluation)
        battles_won = smac_battles_won(sample_data, info_data)
        n_battles = len(battles_won)
        success_rate = sum(battles_won) / n_battles if n_battles > 0 else 0.0

        return {
            "mean_return_per_agent": mean_return_per_agent,
            "mean_return_total": mean_return_total,
            "success_rate": success_rate,
        }

    def _evaluate_performance_sc2_online(self, core):
        """No online metrics are collected for SMAC."""
        return {}

    def _log_data_sc2(self, i, performance_data_dict, performance_data_dict_online):
        """
        Log the total mean return and success rate of iteration ``i``.
        """
        mean_return_total = performance_data_dict["mean_return_total"]
        success_rate = performance_data_dict["success_rate"]
        msg_performance = f"It. {i} Total: {np.round(mean_return_total)} SR: {round(success_rate * 100, 2)}%"
        self.exp_logger.info(msg_performance)

        if self.log_wandb:
            wandb.log(
                {"success_rate": success_rate, "mean_return_total": mean_return_total}
            )

    # ------------------------------------------------------------------ #
    # MaMuJoCo and toy environments (identical metrics)
    # ------------------------------------------------------------------ #

    def _evaluate_performance_generic(self, core, n_episodes):
        """
        Evaluate the performance of the agents

        n_episodes: number of episodes to evaluate each agent for
        core: core to use for evaluation (should already contain the agents and mdp)

        Returns mean returns and the mean episode length.
        """
        evaluation_data = self._collect_evaluation_data(core, n_episodes)
        sample_data = [datapoint["sample"] for datapoint in evaluation_data]

        # Returns
        mean_return_per_agent, mean_return_total = self._mean_returns(sample_data)

        # Episode lengths
        mean_episode_length = np.mean(compute_episode_lengths(sample_data))

        return {
            "mean_return_per_agent": mean_return_per_agent,
            "mean_return_total": mean_return_total,
            "mean_episode_length": mean_episode_length,
        }

    def _evaluate_performance_online_generic(self, core):
        """
        Summarize the data gathered by the "online" callback during training:
        action norms (mean over axis 0) and any actor/critic losses reported in
        the step infos. Returns ``{}`` when online logging is disabled.
        """
        if not self.log_online_data:
            return {}

        online_data = core.callbacks_step["online"].get()
        sample_data = [datapoint["sample"] for datapoint in online_data]
        info_data = [datapoint["info"] for datapoint in online_data]

        action_norms = np.mean(compute_action_norms_all_agents(sample_data), axis=0)
        actor_losses = [info["actor_loss"] for info in info_data if "actor_loss" in info]
        critic_losses = [
            info["critic_loss"] for info in info_data if "critic_loss" in info
        ]
        return {
            "action norms": action_norms,
            "actor_losses": actor_losses,
            "critic_losses": critic_losses,
        }

    def _log_data_generic(self, i, performance_data_dict, performance_data_dict_online):
        """
        Log the total mean return and mean episode length of iteration ``i``,
        plus mean actor/critic losses and action norms when online logging is on.
        """
        mean_return_total = performance_data_dict["mean_return_total"]
        mean_episode_length = performance_data_dict["mean_episode_length"]
        msg_performance = f"It. {i} Total: {np.round(mean_return_total)} Avg. Ep. Length: {np.round(mean_episode_length)}"
        wandb_data = {"mean_return_total": mean_return_total}

        if self.log_online_data:
            mean_actor_loss = np.mean(performance_data_dict_online["actor_losses"])
            mean_critic_loss = np.mean(performance_data_dict_online["critic_losses"])
            mean_action_norm = np.mean(performance_data_dict_online["action norms"])
            msg_performance += f" Actor loss: {np.round(mean_actor_loss, 2)} Critic loss: {np.round(mean_critic_loss, 2)} Action norms: {np.round(mean_action_norm, 2)}"
            wandb_data["actor_losses"] = mean_actor_loss
            wandb_data["critic_losses"] = mean_critic_loss
            wandb_data["action_norms"] = mean_action_norm

        self.exp_logger.info(msg_performance)

        if self.log_wandb:
            wandb.log(wandb_data)

    # Environment-specific names kept for backward compatibility.
    _evaluate_performance_mamujoco = _evaluate_performance_generic
    _evaluate_performance_attraction = _evaluate_performance_generic
    _evaluate_performance_plug = _evaluate_performance_generic
    _evaluate_performance_online_mamujoco = _evaluate_performance_online_generic
    _evaluate_performance_online_attraction = _evaluate_performance_online_generic
    _evaluate_performance_online_plug = _evaluate_performance_online_generic
    _log_data_mamujoco = _log_data_generic
    _log_data_attraction = _log_data_generic
    _log_data_plug = _log_data_generic

    # ------------------------------------------------------------------ #
    # Experiment setup
    # ------------------------------------------------------------------ #

    def _make_agent_logger(self, idx):
        """Create the Logger of agent ``idx``."""
        return Logger(
            log_name=f"{idx}",
            results_dir=Path(self.results_dir) / "Logging",
            log_console=True,
            seed=self.seed,
            console_log_level=30,
        )

    def _build_agents(self, mdp, agent_name, agent_params, share_params):
        """
        Create one agent (and its Logger) per agent index of ``mdp``.

        With ``share_params`` truthy, agent 0 is created first and passed as
        ``primary_agent`` to every other agent.

        Returns ``(agents, loggers)``.
        """
        agents = []
        loggers = []
        if share_params:
            primary_agent = setup_agent(
                agent=agent_name,
                mdp_info=mdp.info,
                idx_agent=0,
                agent_params=agent_params,
            )
            agents.append(primary_agent)
            loggers.append(self._make_agent_logger(0))
            for idx in range(1, mdp.info.n_agents):
                agent = setup_agent(
                    agent=agent_name,
                    mdp_info=mdp.info,
                    idx_agent=idx,
                    agent_params=agent_params,
                    primary_agent=primary_agent,
                )
                agents.append(agent)
                loggers.append(self._make_agent_logger(idx))
        else:
            for idx in range(mdp.info.n_agents):
                agent = setup_agent(
                    agent=agent_name,
                    mdp_info=mdp.info,
                    idx_agent=idx,
                    agent_params=agent_params,
                )
                agents.append(agent)
                loggers.append(self._make_agent_logger(idx))
        return agents, loggers

    def _load_host_agent_params(self, host_agent):
        """
        Load the ``host_agent`` section of ``<parent>/configs/agent.yaml`` and
        override the settings that are controlled by the mixer agent params.
        """
        dir_path = os.path.dirname(os.path.realpath(__file__))
        parent_dir_path = os.path.dirname(dir_path)
        host_agent_params = load_config(f"{parent_dir_path}/configs/agent.yaml")[
            host_agent
        ]
        host_agent_params["use_mixer"] = True
        host_agent_params["obs_last_action"] = self.agent_params["obs_last_action"]
        host_agent_params["critic_obs_last_action"] = self.agent_params.get(
            "critic_obs_last_action", True
        )
        host_agent_params["critic_agent_encoding"] = self.agent_params.get(
            "critic_agent_encoding", True
        )
        host_agent_params["centralized_critic"] = self.agent_params.get(
            "centralized_critic", True
        )
        host_agent_params["use_cuda"] = self.env_run_params["use_cuda"]
        host_agent_params["share_agent_params"] = self.agent_params[
            "share_agent_params"
        ]
        return host_agent_params

    def setup_experiment(self):
        """
        Build the MDP, the agents (plus a mixer for mixer-based agents) and the
        matching core class ("gru" agents get a hidden-state core; shared
        parameters get a shared core).

        Returns ``(core, agents, loggers)``.
        """
        mdp = self.provide_mdp()

        if self.agent in self._MIXER_AGENTS:
            # Mixer agent params transfer to host agents
            host_agent = self.agent_params.get("host_agent")
            host_agent_params = self._load_host_agent_params(host_agent)
            share_params = host_agent_params.get("share_agent_params")

            agents, loggers = self._build_agents(
                mdp, host_agent, host_agent_params, share_params
            )
            mixer = setup_agent(
                agent=self.agent,
                mdp_info=mdp.info,
                idx_agent=-1,
                agent_params=self.agent_params,
                host_agents=agents,
            )
            if "gru" in host_agent:
                if share_params:
                    core = MultiAgentCoreHiddenSharedMixer(
                        agents=agents, mdp=mdp, mixer=mixer
                    )
                else:
                    core = MultiAgentCoreHiddenMixer(agents, mdp, mixer)
            else:
                if share_params:
                    core = MultiAgentCoreSharedMixer(
                        agents=agents, mdp=mdp, mixer=mixer
                    )
                else:
                    core = MultiAgentCoreMixer(agents=agents, mdp=mdp, mixer=mixer)
        else:
            # Read the flag once so a missing key consistently means "not shared".
            share_params = self.agent_params.get("share_agent_params")
            agents, loggers = self._build_agents(
                mdp, self.agent, self.agent_params, share_params
            )
            if "gru" in self.agent:
                if share_params:
                    core = MultiAgentCoreHiddenShared(agents, mdp)
                else:
                    core = MultiAgentCoreHidden(agents, mdp)
            else:
                if share_params:
                    core = MultiAgentCoreShared(agents, mdp)
                else:
                    core = MultiAgentCore(agents, mdp)
        return core, agents, loggers

    def get_n_steps_per_fit_per_agent(self, n):
        """
        Get the number of steps per fit for each of the ``n`` agents, or
        ``None`` when ``n_steps_per_fit`` is not set.
        """
        n_steps_per_fit = self.env_run_params["n_steps_per_fit"]
        return [n_steps_per_fit] * n if n_steps_per_fit is not None else None

    def get_n_episodes_per_fit_per_agent(self, n):
        """
        Get the number of episodes per fit for each of the ``n`` agents, or
        ``None`` when ``n_episodes_per_fit`` is not set.
        """
        n_episodes_per_fit = self.env_run_params["n_episodes_per_fit"]
        return [n_episodes_per_fit] * n if n_episodes_per_fit is not None else None

    def provide_mdp(self):
        """
        Instantiate the environment selected by ``self.env`` with ``env_params``.

        Raises NotImplementedError for an unknown environment.
        """
        if self.env == "sc2":
            return SMAC(**self.env_params)
        elif self.env == "mamujoco":
            return Mamujoco(**self.env_params)
        elif self.env == "plug":
            return Plug(**self.env_params)
        elif self.env == "attraction":
            return Attraction(**self.env_params)
        raise NotImplementedError(f"Unknown environment: '{self.env}'")

    @staticmethod
    def load_agent(agent_path):
        """Load a serialized agent from ``agent_path``."""
        return Serializable.load(agent_path)

    @staticmethod
    def save_agents(agents, names, results_dir, seed, full_save=False):
        """
        Save each agent as ``exp_<seed>_<name>.msh`` in ``results_dir``;
        ``names[idx]`` is used for ``agents[idx]``.
        """
        for idx, agent in enumerate(agents):
            filename = "exp_" + str(seed) + "_" + str(names[idx]) + ".msh"
            agent.save(Path(results_dir) / filename, full_save=full_save)

    @staticmethod
    def save_data(results_dir, **kwargs):
        """
        Save every keyword argument as ``<name>.npy`` in ``results_dir``,
        creating the directory if needed.
        """
        Path(results_dir).mkdir(parents=True, exist_ok=True)
        for name, data in kwargs.items():
            path = Path(results_dir) / (name + ".npy")
            np.save(str(path), data)
