"""
This file implements centralized learning in the style of MADDPG and QMIX. With this learner, users can easily frame single-agent algorithms as multi-agent centralized algorithms.
"""

import copy
import enum
import time
import os
import torch
import torch.nn.functional as F

from torch import optim
from gym import spaces

from expground.common.policy_pool import PolicyPool
from expground.algorithms.ddpg.config import CENTRALIZED_CRITIC_NETWORK

from expground import optimizers, settings
from expground.logger import log, monitor
from expground.types import (
    Dict,
    LearningMode,
    PolicyConfig,
    RolloutConfig,
    TrainingConfig,
    Sequence,
    AgentID,
    Any,
    Union,
)
from expground.utils import data as data_utils, rollout
from expground.utils.sampler import get_sampler, SamplerInterface
from expground.utils.stoppers import get_stopper
from expground.utils.data import EpisodeKeys, default_dtype_mapping

from expground.utils.logging import write_to_tensorboard
from expground.utils.preprocessor import get_preprocessor
from expground.common.models import get_model
from expground.algorithms import misc
from expground.algorithms.base_trainer import Trainer
from expground.algorithms.base_policy import Policy
from expground.algorithms.loss_func import LossFunc
from expground.learner.base_learner import Learner
from expground.envs.agent_interface import (
    AgentInterface,
    DEFAULT_ACTION_ADAPTER,
    DEFAULT_OBSERVATION_ADAPTER,
)


def _group_validation(
    groups: Dict[str, Sequence[AgentID]], possible_agents: Sequence[AgentID]
) -> bool:
    all_agents = []
    for group in groups.values():
        all_agents += group
    assert len(all_agents) == len(
        possible_agents
    ), f"Agents should be no overlapped and no lacking: {groups} {possible_agents}"
    return True


class ppool_loss(LossFunc):
    """Loss function attached to a `PolicyPool`.

    It computes no loss itself: it only owns an SGD optimizer over the tensors
    stored in the pool's `_distribution._distribution_tensor` and steps it
    whenever the loss function is called.
    """

    def __init__(self, mute_critic_loss: bool = True):
        super().__init__(mute_critic_loss=mute_critic_loss)

    def __call__(self, batch) -> Dict[str, Any]:
        """Step the optimizer (when one exists) and report no statistics.

        Args:
            batch: Unused by this loss function.

        Returns:
            Dict[str, Any]: Always an empty dict.
        """
        optimizer = self.optimizers
        if optimizer is not None:
            optimizer.step()
        return {}

    def zero_grad(self):
        """Delegate to the parent `zero_grad` only when an optimizer exists."""
        if self.optimizers is None:
            return None
        return super().zero_grad()

    def step(self) -> Any:
        """Do nothing; the optimizer step is performed in `__call__`."""
        return None

    def setup_optimizers(self, *args, **kwargs):
        """Build an SGD optimizer (lr=1.0) over the pool's distribution tensors.

        When the pool holds no distribution tensor, the optimizer is set to None.
        """
        pool = self.policy
        assert isinstance(pool, PolicyPool), pool
        dist_tensors = pool._distribution._distribution_tensor
        if len(dist_tensors) > 0:
            self.optimizers = optim.SGD(list(dist_tensors.values()), lr=1.0)
            return
        self.optimizers = None

    def reset(self, policy: Policy, configs: Dict[str, Any]):
        """Bind a new policy, merge `configs` into the parameters and rebuild the optimizer.

        Args:
            policy (Policy): The policy (expected to be a `PolicyPool`) to bind.
            configs (Dict[str, Any]): Parameters merged into `self._params`; may be None.
        """
        overrides = configs or {}
        self._params.update(overrides)
        self._policy = policy
        # drop any previous optimizer/loss records before rebuilding
        self.optimizers = None
        self.loss = []
        self.setup_optimizers()
        self.setup_extras()


def _loss_wrap(
    policies: Dict[AgentID, Policy],
    training_config: Dict[str, Any],
    loss_func_cls: type,
) -> LossFunc:
    """Wrap the loss functions of multiple policies behind one loss function.

    Args:
        policies (Dict[AgentID, Policy]): The dict of ego agents' policies, mapping from agent to policies.
        training_config (Dict[str, Any]): The dict of training configuration, some hyper parameters.
        loss_func_cls (type): The class used to create loss function instances for regular policies.

    Returns:
        LossFunc: An already reset loss function instance that accepts a dict of agent batches.
    """

    # One loss function per policy, always created with the critic loss muted.
    # Non-fixed policy pools get the pool-specific loss; everything else uses
    # the user-provided loss class.
    _loss_funcs = {}
    for k, policy in policies.items():
        if isinstance(policy, PolicyPool) and not policy.is_fixed:
            _loss_funcs[k] = ppool_loss(mute_critic_loss=True)
        else:
            _loss_funcs[k] = loss_func_cls(mute_critic_loss=True)

    class _loss(LossFunc):
        def __init__(self):
            # the wrapper itself is created with the critic loss muted as well
            super(_loss, self).__init__(mute_critic_loss=True)

        def __call__(self, batch: Dict[str, Dict]) -> Dict[str, Any]:
            """Integrate loss computation for each policy.

            Args:
                batch (Dict[str, Dict]): A dict of batch, mapping from policy ids to corresponding batches.

            Returns:
                Dict[str, Any]: A dict of loss statistics, mapping from policy ids to corresponding
                    loss feedback. Policies whose loss returns nothing are left out.
            """

            res = {}
            for k in policies:
                # here, batch[k] could be an agent dict or pure dict
                tmp = _loss_funcs[k](batch[k])
                if len(tmp) > 0:
                    res[k] = tmp
            return res

        def step(self) -> Any:
            for loss_func in _loss_funcs.values():
                loss_func.step()

        def zero_grad(self):
            for loss_func in _loss_funcs.values():
                loss_func.zero_grad()

        def setup_optimizers(self, *args, **kwargs):
            for loss_func in _loss_funcs.values():
                loss_func.setup_optimizers(*args, **kwargs)

        def reset(self, policies: Dict[AgentID, Policy], configs: Dict[str, Any]):
            """Reset every wrapped loss function with its (possibly new) policy.

            Args:
                policies (Dict[AgentID, Policy]): Mapping that contains at least the wrapped agents.
                configs (Dict[str, Any]): Configuration forwarded to each wrapped loss function.
            """
            for k, _loss_func in _loss_funcs.items():
                policy = policies[k]
                if not isinstance(policy, PolicyPool) or not policy.is_fixed:
                    # plain policies and learnable (non-fixed) pools are bound directly
                    _loss_func.reset(policy, configs)
                    continue
                # a fixed pool is trained through its first active policy, if any
                active_policies = policy.get_active_policies()
                if len(active_policies) < 1:
                    log.warning("No active policies found!")
                else:
                    _loss_func.reset(list(active_policies.values())[0], configs)

    loss = _loss()
    loss.reset(policies, training_config)
    return loss


class CentralizedTrainer(Trainer):
    """Trainer that updates ego policies against one centralized critic.

    The critic input is the concatenation of the critic observation (a global
    state, or all governed agents' observations) and the actions of all
    governed agents. The critic, its target copy and the critic optimizer are
    built in `reset`, not in the constructor.
    """

    def __init__(
        self,
        loss_func_cls: type,
        training_config: Dict[str, Any],
        policies: Dict[AgentID, Union[Policy, PolicyPool]],
        model_config: Dict[str, Any] = None,
        use_cuda: bool = False,
        use_global_state: bool = False,
        ego_agents: Union[AgentID, Sequence[AgentID]] = None,
        exp_config=None,
    ):
        """Initialize a centralized trainer.

        Args:
            loss_func_cls (type): Loss function class for individual policies. Current implementation shares the loss type.
            training_config (Dict[str, Any]): A dict of training configuration.
            policies (Dict[AgentID, Union[Policy, PolicyPool]]): A dict of policies or policy pool, mapping from agent to policies/policypools.
            model_config (Dict[str, Any], optional): Model configuration, shared by all policies. Defaults to None, which selects `CENTRALIZED_CRITIC_NETWORK`.
            use_cuda (bool, optional): Turn on CUDA device for training or not. Defaults to False. Note that `reset` currently builds the critic without forwarding this flag.
            use_global_state (bool, optional): Determine whether there is a global state specified in current case. Defaults to False.
            ego_agents (Union[AgentID, Sequence[AgentID]], optional): Specify ego agents. Defaults to None, means all agents will be updated in this trainer.
            exp_config (optional): Forwarded untouched to the parent trainer.

        Raises:
            AssertionError: If an ego agent is not governed by this trainer, or if
                `use_global_state` is set and an observation space lacks `global_state`.
        """

        super(CentralizedTrainer, self).__init__(
            None,
            training_config=training_config,
            policy_instance=policies,
            exp_config=exp_config,
        )
        # then build a centralized critic with given configs
        observation_spaces = {
            aid: policy._observation_space for aid, policy in policies.items()
        }
        action_spaces = {aid: policy._action_space for aid, policy in policies.items()}
        self._model_config = model_config or CENTRALIZED_CRITIC_NETWORK

        # check ego agents in possible agents
        self._governed_agents = list(observation_spaces.keys())
        self._use_global_state = use_global_state
        # adv_eps: for minmax optimization
        self._adversary = self._training_config.get("adv_eps", 0) > 0

        if isinstance(ego_agents, AgentID):
            ego_agents = [ego_agents]
        # Resolve the default BEFORE validating: iterating over `None` in the
        # assertion below would raise TypeError for the documented default.
        ego_agents = ego_agents or self._governed_agents

        assert all(
            item in self._governed_agents for item in ego_agents
        ), "Ego agents (%s) should be contained in governed agents (%s)" % (
            ego_agents,
            self._governed_agents,
        )
        self._ego_agents = ego_agents
        self._loss_func = _loss_wrap(
            {aid: policies[aid] for aid in self._ego_agents},
            training_config,
            loss_func_cls,
        )

        # TODO: abstract this process as a module in the future, so that users can apply their
        #   customized centralized critic in this centralized trainer without extra work for
        #   defining a new trainer.
        if use_global_state:
            # detect whether there is a `global_state` key in each agent observation space, if not, throw error
            for obs_space in observation_spaces.values():
                assert "global_state" in obs_space.spaces
            first_spaces = list(observation_spaces.values())[0].spaces
            # The lookup historically used the key "global_state_space" although the
            # validation above only guarantees "global_state". Keep honoring the legacy
            # key when it exists, otherwise fall back to the validated one.
            state_key = (
                "global_state_space"
                if "global_state_space" in first_spaces
                else "global_state"
            )
            global_state_space = copy.deepcopy(first_spaces[state_key])
        else:
            global_state_space = spaces.Dict(**observation_spaces)

        self._observation_space = spaces.Dict(
            {
                "obs": global_state_space,
                "action": spaces.Dict(**action_spaces),
            }
        )
        self._obs_preprocessor = get_preprocessor(global_state_space)(
            global_state_space
        )
        self._preprocessor = get_preprocessor(self._observation_space)(
            self._observation_space
        )
        self._action_space = spaces.Discrete(1)
        # critic related members are created lazily in `reset`
        self._centralized_critic = None
        self._target_centralized_critic = None
        self._optimizer = None

    def _build_centralized_critics_and_optimizer(
        self, model_config: Dict[str, Any], use_cuda: bool = False
    ):
        """Create the critic, a target critic with identical weights and the critic optimizer.

        The optimizer class is looked up in `torch.optim` by the name stored under
        `training_config["optimizer"]` and uses `training_config["critic_lr"]`.

        Args:
            model_config (Dict[str, Any]): Configuration handed to `get_model`.
            use_cuda (bool, optional): Forwarded to the model constructor. Defaults to False.
        """
        model_cls = get_model(model_config, framework="torch")
        self._centralized_critic = model_cls(
            self._observation_space, self._action_space, use_cuda=use_cuda
        )
        self._target_centralized_critic = get_model(model_config, framework="torch")(
            self._observation_space, self._action_space, use_cuda=use_cuda
        )
        self._target_centralized_critic.load_state_dict(
            self._centralized_critic.state_dict()
        )
        optimizer_cls = getattr(torch.optim, self._training_config["optimizer"])
        self._optimizer = optimizer_cls(
            self._centralized_critic.parameters(), lr=self._training_config["critic_lr"]
        )

    @property
    def governed_agents(self):
        """The agents whose observations and actions feed the centralized critic."""
        return self._governed_agents

    def _compute_policy_loss(
        self, agent_obs, critic_obs, actions, batch, adversary: bool = False
    ):
        """Compute the policy loss as the negated critic value plus an action regularizer.

        Args:
            agent_obs: Mapping from agent to its observation tensor.
            critic_obs: Observation tensor fed to the critic.
            actions: List of action tensors ordered like `governed_agents`. Entries of
                ego agents are replaced in place by freshly computed actions.
            batch: Mapping from agent to its batch dict (used for action masks).
            adversary (bool, optional): Perturb non-ego actions along the loss gradient
                (scaled by `adv_eps`) before evaluating the critic. Defaults to False.

        Returns:
            The scalar policy loss tensor.
        """
        for i, aid in enumerate(self._governed_agents):
            if aid in self._ego_agents:
                # replace ego agent action with on-line computing
                actions[i], _ = self._policy[aid].compute_actions(
                    agent_obs[aid],
                    use_target=False,
                    action_mask=batch[aid].get(EpisodeKeys.ACTION_MASK.value, None),
                    explore=True,
                )
            elif adversary:
                # gradients w.r.t. the other agents' actions are needed for the perturbation
                actions[i].requires_grad = True
        critic_in = torch.cat([critic_obs] + actions, dim=-1)
        policy_loss = -self._centralized_critic(critic_in).view(-1).mean()

        if adversary:
            adv_eps = self._training_config["adv_eps"]
            raw_perturb = torch.autograd.grad(policy_loss, actions)
            perturb = [
                F.normalize(e, dim=-1, p=2).detach() * adv_eps for e in raw_perturb
            ]
            # replace other agents actions, ego actions stay untouched
            new_actions = []
            for i, agent in enumerate(self._governed_agents):
                if agent in self._ego_agents:
                    new_actions.append(actions[i])
                    continue
                action_mask = batch[agent].get(EpisodeKeys.ACTION_MASK.value, None)
                if action_mask is not None:
                    new_actions.append((perturb[i] + actions[i]) * action_mask)
                else:
                    new_actions.append(perturb[i] + actions[i])
            adv_critic_in = torch.cat([critic_obs] + new_actions, dim=-1)
            adv_value = self._centralized_critic(adv_critic_in).view(-1)
            policy_loss = -adv_value.mean()

        # regularize with the per-sample sum of the mean squared (unperturbed) actions
        action_reg_loss = torch.sum(
            torch.cat(
                [torch.mean(action ** 2, dim=-1, keepdim=True) for action in actions],
                dim=-1,
            ),
            dim=-1,
        )

        policy_loss += 0.001 * action_reg_loss.mean()

        return policy_loss

    def _compute_value_loss(
        self,
        critic_obs,
        actions,
        next_critic_obs,
        next_actions,
        batch,
        adversary: bool = False,
    ):
        """Compute the MSE between the critic value and the bootstrapped target.

        The target averages `reward + gamma * (1 - done) * next_value` over the ego agents.

        Args:
            critic_obs: Critic observation tensor of the current step.
            actions: List of action tensors of the current step.
            next_critic_obs: Critic observation tensor of the next step.
            next_actions: List of next-step action tensors (from the target policies).
            batch: Mapping from agent to its batch dict (rewards, dones, masks).
            adversary (bool, optional): Perturb non-ego next actions along the gradient
                of the target value (scaled by `adv_eps`). Defaults to False.

        Returns:
            Tuple of (value loss, target value, predicted value).
        """
        next_critic_in = torch.cat([next_critic_obs] + next_actions, dim=-1)
        next_value = self._target_centralized_critic(next_critic_in)
        gamma = self._training_config["gamma"]

        critic_in = torch.cat([critic_obs] + actions, dim=-1)
        value = self._centralized_critic(critic_in).view(-1)

        if adversary:
            adv_eps = self._training_config["adv_eps"]
            pg_loss = -torch.mean(next_value)
            # get gradients on the next actions as perturbation direction
            raw_perturb = torch.autograd.grad(pg_loss, next_actions)
            perturb = [
                adv_eps * F.normalize(e, dim=-1, p=2).detach() for e in raw_perturb
            ]
            # replace other agents actions
            new_next_actions = []
            for i, agent in enumerate(self._governed_agents):
                if agent in self._ego_agents:
                    new_next_actions.append(next_actions[i])
                    continue
                perturbed = next_actions[i] + perturb[i]
                next_action_mask = batch[agent].get(
                    EpisodeKeys.NEXT_ACTION_MASK.value, None
                )
                if next_action_mask is not None:
                    perturbed = perturbed * next_action_mask
                new_next_actions.append(perturbed)

            # The perturbed NEXT actions must be paired with the NEXT critic
            # observation, exactly like the unperturbed target above.
            adv_critic_in = torch.cat([next_critic_obs] + new_next_actions, dim=-1)
            next_value = self._target_centralized_critic(adv_critic_in)

        target_value = torch.mean(
            torch.cat(
                [
                    batch[aid][EpisodeKeys.REWARD.value].view((-1, 1))
                    + gamma
                    * (1.0 - batch[aid][EpisodeKeys.DONE.value].view((-1, 1)))
                    * next_value
                    for aid in self._ego_agents
                ],
                dim=-1,
            ),
            dim=-1,
        )
        value_loss = F.mse_loss(value, target_value.detach())
        return value_loss, target_value, value

    def __call__(
        self, sampler: SamplerInterface, time_step: int, agent_filter: Sequence = None
    ):
        """Run the parent training call, defaulting `agent_filter` to the governed agents."""
        agent_filter = agent_filter or self._governed_agents
        return super(CentralizedTrainer, self).__call__(
            sampler, time_step, agent_filter=agent_filter
        )

    def _before_loss(self, policy: Dict[AgentID, Policy], batch: Dict[AgentID, Dict]):
        """Update the critic and back-propagate the policy loss.

        The optimizer step of the policies is left to the wrapped loss function.

        Args:
            policy (Dict[AgentID, Policy]): Unused here; the policies bound in `reset` are used.
            batch (Dict[AgentID, Dict]): Mapping from agent to its sampled batch.

        Returns:
            Tuple of (dict mapping every ego agent to the whole tensor batch, dict of training statistics).
        """

        def caster(x):
            # tensors pass through untouched, everything else becomes a CPU tensor
            if isinstance(x, torch.Tensor):
                return x
            return torch.FloatTensor(x.copy()).to(
                device="cpu", dtype=default_dtype_mapping(x.dtype)
            )

        batch = data_utils.walk(caster, batch)
        cliprange = self._training_config["grad_norm_clipping"]

        agent_next_obs = {
            aid: batch[aid][EpisodeKeys.NEXT_OBSERVATION.value]
            for aid in self._governed_agents
        }
        agent_obs = {
            aid: batch[aid][EpisodeKeys.OBSERVATION.value]
            for aid in self._governed_agents
        }
        if self._use_global_state:
            # the global state is read from the first agent batch only
            first_batch = list(batch.values())[0]
            critic_obs = first_batch[EpisodeKeys.GLOBAL_STATE.value]
            next_critic_obs = first_batch[EpisodeKeys.NEXT_GLOBAL_STATE.value]
        else:
            critic_obs = torch.cat(
                [agent_obs[aid] for aid in self._governed_agents], dim=-1
            )
            next_critic_obs = torch.cat(
                [agent_next_obs[aid] for aid in self._governed_agents], dim=-1
            )
        actions = [
            batch[aid][EpisodeKeys.ACTION_DIST.value] for aid in self._governed_agents
        ]
        next_actions = []
        for aid in self._governed_agents:
            a, _ = self._policy[aid].compute_actions(
                agent_next_obs[aid],
                use_target=True,
                action_mask=batch[aid].get(EpisodeKeys.NEXT_ACTION_MASK.value, None),
                explore=True,
            )
            next_actions.append(a)

        # ======================== update critic ======================
        self._optimizer.zero_grad()
        value_loss, target_value, value = self._compute_value_loss(
            critic_obs,
            actions,
            next_critic_obs,
            next_actions,
            batch,
            adversary=self._adversary,
        )
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(self._centralized_critic.parameters(), cliprange)
        self._optimizer.step()
        # =============================================================

        # ======================== update actor =======================
        self._loss_func.zero_grad()
        # rebuild the list: `_compute_policy_loss` replaces ego entries in place
        actions = [
            batch[aid][EpisodeKeys.ACTION_DIST.value] for aid in self._governed_agents
        ]
        policy_loss = self._compute_policy_loss(
            agent_obs, critic_obs, actions, batch, adversary=self._adversary
        )
        policy_loss.backward()
        # the optimizer step leave to loss func
        # =============================================================

        training_info = {
            "value_loss": value_loss.detach().item(),
            "target_value_mean": target_value.mean().item(),
            "target_value_max": target_value.max().item(),
            "target_value_min": target_value.min().item(),
            "eval_value_mean": value.mean().item(),
            "eval_value_max": value.max().item(),
            "eval_value_min": value.min().item(),
            "policy_loss": policy_loss.detach().item(),
        }
        return {aid: batch for aid in self._ego_agents}, training_info

    def _after_loss(self, policy, step_counter: int):
        """Soft-update the target critic and every policy's target with `tau`.

        Args:
            policy: Unused here; the policies bound in `reset` are updated.
            step_counter (int): Unused.
        """
        tau = self._training_config["tau"]
        misc.soft_update(
            self._target_centralized_critic,
            self._centralized_critic,
            tau=tau,
        )
        for governed_policy in self._policy.values():
            governed_policy.update_target(tau=tau)

    def reset(self, policy_instance, configs: Dict = None):
        """Bind new policies, rebuild critic/optimizer and reset the wrapped loss function.

        Args:
            policy_instance: Mapping from agent to policy used from now on.
            configs (Dict, optional): Loss configuration; the training config is used when None.
        """
        self._step_counter = 0
        self._policy = policy_instance
        self._build_centralized_critics_and_optimizer(self._model_config)
        self._loss_func.reset(self._policy, configs or self._training_config)

class CentralizedLearner(Learner):
    """Learner that trains agent policies with `CentralizedTrainer` instances.

    Depending on `share_critic`, one trainer is built per group, or one trainer
    per trainable agent (MADDPG-like).
    """

    def __init__(
        self,
        policy_config: PolicyConfig,
        env_description: Dict[str, Any],
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        centralized_critic_config: Dict[str, Any],
        learning_mode: str,
        groups: Dict[str, Sequence[AgentID]] = None,
        share_critic: bool = False,
        episodic_training: bool = True,
        train_every: int = 1,
        ego_agents: Sequence[AgentID] = None,
        enable_policy_pool: bool = False,
        experiment: str = None,
        seed: int = None,
        **kwargs,
    ):
        """Initialize policies, trainers and agent interfaces.

        Args:
            policy_config (PolicyConfig): Used to create a policy (or policy pool) per agent.
            env_description (Dict[str, Any]): Environment description; its `config` entry must
                provide `possible_agents`, `observation_spaces` and `action_spaces`.
            rollout_config (RolloutConfig): Rollout configuration.
            training_config (TrainingConfig): Its `hyper_params` are handed to every trainer.
            loss_func (type): Loss function class shared by all policies.
            centralized_critic_config (Dict[str, Any]): Model configuration of the centralized critic.
            learning_mode (str): Compared against `LearningMode.ON_POLICY` to decide whether the
                sampler is cleaned after every epoch.
            groups (Dict[str, Sequence[AgentID]], optional): Agent groups. Defaults to None, which
                puts all agents into one group named `default`.
            share_critic (bool, optional): Build one trainer per group instead of one per agent. Defaults to False.
            episodic_training (bool, optional): Forwarded to the rollout caller. Defaults to True.
            train_every (int, optional): Forwarded to the rollout caller. Defaults to 1.
            ego_agents (Sequence[AgentID], optional): Agents to train. Defaults to None, means all possible agents.
            enable_policy_pool (bool, optional): Wrap every agent's policy in a `PolicyPool`. Defaults to False.
            experiment (str, optional): Experiment name. Defaults to None, which generates a time-stamped name.
            seed (int, optional): Random seed forwarded to the parent learner. Defaults to None.
            **kwargs: Optional extras read here: `summary_writer`, `evaluation_worker_num`,
                `use_learnable_dist`, `turn_off_logging`, `start_fixed_support_num`,
                `start_active_support_num`, `mixed_at_every_step`, `distribution_training_kwargs`.
        """
        super(CentralizedLearner, self).__init__(
            experiment=experiment
            or f"Centralized_{policy_config.policy}_{time.time()}",
            summary_writer=kwargs.get("summary_writer", None),
            seed=seed,
            evaluation_worker_num=kwargs.get("evaluation_worker_num", 0),
        )

        possible_agents = env_description["config"]["possible_agents"]
        if groups is None:
            groups = {"default": possible_agents}
        # check whether the group is legal
        _group_validation(groups, possible_agents)

        self._env_desc = env_description
        self._policy_config = policy_config
        self._agents = possible_agents
        self._rollout_config = rollout_config
        self._share_critic = share_critic
        self._groups = groups

        # sampler and stopper are created in `learn`
        self._sampler = None
        self._stopper = None
        self._episodic_training = episodic_training
        self._train_every = train_every
        self._learning_mode = learning_mode
        self._agent_interfaces = {}

        self._total_timesteps = 0
        self._total_episodes = 0

        self._use_learnable_dist = kwargs.get("use_learnable_dist", False)
        self._turn_off_logging = kwargs.get("turn_off_logging", False)

        self._policy_mapping_fn = lambda aid: aid
        self._policies = {}
        self._ego_agents = ego_agents or possible_agents

        for agent in self._agents:
            pid = self._policy_mapping_fn(agent)
            if pid in self._policies:
                continue
            if not enable_policy_pool:
                self._policies[pid] = policy_config.new_policy_instance(agent)
                continue
            # Membership must be tested against the RESOLVED ego agents: the raw
            # `ego_agents` argument may be None, which is not iterable.
            is_ego = agent in self._ego_agents
            self._policies[pid] = PolicyPool(
                agent,
                policy_config=policy_config.copy(agent),
                start_fixed_support_num=kwargs.get("start_fixed_support_num", 0)
                if is_ego
                else 0,
                start_active_support_num=kwargs.get("start_active_support_num", 0)
                if is_ego
                else 0,
                # ego pools are always fixed; others are fixed unless a learnable distribution is requested
                is_fixed=True
                if is_ego
                else not kwargs.get("use_learnable_dist", False),
                mixed_at_every_step=kwargs.get("mixed_at_every_step", True),
                distribution_training_kwargs=kwargs.get(
                    "distribution_training_kwargs", {}
                ),
            )

        #   it will work in two modes:
        #       1) fully centralized for a group of agents, then the dict of trainer will be a mapping from
        #           groups to trainers;
        #       2) decentralized centralized for a group of agents like MADDPG, then the dict of trainer
        #           will be a mapping from agents to trainers.
        self._trainer = {}
        if share_critic:
            log.debug("Construct centralized critic in sharing mode.")
            for g, agents in groups.items():
                selected_policies = {aid: self._policies[aid] for aid in agents}
                self._trainer[g] = CentralizedTrainer(
                    loss_func,
                    training_config.hyper_params,
                    selected_policies,
                    model_config=centralized_critic_config,
                )
        else:
            log.debug("Construct centralized critic in decentralized way.")
            for g, agents in groups.items():
                selected_policies: Dict[AgentID, Union[Policy, PolicyPool]] = {
                    aid: self._policies[aid] for aid in agents
                }
                # build trainer in agent-wise
                for agent in agents:
                    # non-ego agents only get a trainer when their distribution is learnable
                    if agent not in self._ego_agents and not self._use_learnable_dist:
                        continue
                    self._trainer[agent] = CentralizedTrainer(
                        loss_func,
                        training_config.hyper_params,
                        selected_policies,
                        model_config=centralized_critic_config,
                        ego_agents=agent,
                    )

        env_config = self._env_desc["config"]
        observation_spaces = env_config["observation_spaces"]
        action_spaces = env_config["action_spaces"]
        observation_adapter = env_config.get(
            "observation_adapter", DEFAULT_OBSERVATION_ADAPTER
        )
        action_adapter = env_config.get("action_adapter", DEFAULT_ACTION_ADAPTER)
        for _aid in env_config["possible_agents"]:
            self._agent_interfaces[_aid] = AgentInterface(
                policy_name="",
                policy=self._policies[_aid],
                observation_space=observation_spaces[_aid],
                action_space=action_spaces[_aid],
                observation_adapter=observation_adapter,
                action_adapter=action_adapter,
                is_active=_aid in self._ego_agents,
            )

    def get_dist(self):
        """Return the distribution values of every non-ego agent's policy pool.

        Returns:
            dict: Mapping from non-ego agent to `policy._distribution.dict_values()`.

        Raises:
            AssertionError: If a non-ego agent's policy is not a `PolicyPool`.
        """
        res = {}
        for aid in self._agents:
            if aid in self._ego_agents:
                continue
            policy = self._policies[aid]
            assert isinstance(policy, PolicyPool)
            res[aid] = policy._distribution.dict_values()
        return res

    def _evaluate_and_log(self):
        """Run one evaluation pass and write its statistics to tensorboard."""
        epoch_evaluation_statistic = self.evaluation(
            policy_mappings=None,
            max_step=self._rollout_config.max_step,
            fragment_length=min(20, self._rollout_config.num_simulation)
            * self._rollout_config.max_step,
        )
        write_to_tensorboard(
            self.summary_writer,
            epoch_evaluation_statistic,
            global_step=self._total_timesteps,
            prefix="evaluation",
        )

    # @monitor(enable_returns=True, enable_time=True)
    def train(self, **kwargs) -> Dict[str, Dict[str, Any]]:
        """The main logic of training policies, which is composed of two stages, i.e. Rollout and Training.

        Returns:
            Dict[str, Dict[str, Any]]: A list of dicts of training feedback, one per training
                round, each mapping from trainer id to that trainer's feedback.
        """

        generator = self._rollout_config.caller(
            sampler=self._sampler,
            agent_policy_mapping={},
            agent_interfaces=self._agent_interfaces,
            env_description=self._env_desc,
            fragment_length=self._rollout_config.fragment_length,
            max_step=self._rollout_config.max_step,
            episodic=self._episodic_training,
            train_every=self._train_every,
            evaluate=False,
        )

        epoch_training_statistics = []
        try:
            start_timesteps = self._total_timesteps
            # every yielded item triggers one training round for all trainers
            while True:
                info = next(generator)
                tmp = {}
                for tid, _trainer in self._trainer.items():
                    tmp[tid] = _trainer(self._sampler, time_step=start_timesteps)
                start_timesteps += info["timesteps"]
                epoch_training_statistics.append(tmp)
        except StopIteration as e:
            # the generator's return value carries the rollout summary
            info = e.value
            self._total_timesteps += info["total_timesteps"]
            self._total_episodes += info["num_episode"]

        if self._sampler.is_ready() and self._episodic_training:
            # episodic mode: a single training round after the rollout replaces the statistics
            epoch_training_statistics = [
                {
                    _id: _trainer(self._sampler, time_step=self._total_timesteps)
                    for _id, _trainer in self._trainer.items()
                }
            ]

        return epoch_training_statistics

    def learn(self, sampler_config: Dict[str, Any], stop_conditions: Dict[str, Any]):
        """The main loop of centralized learning. It will execute training and evaluation periodically.

        Args:
            sampler_config (Dict[str, Any]): The configuration to build a sampler.
            stop_conditions (Dict[str, Any]): The dict of configuration to build a stopper.

        Returns:
            dict: `{"global_step": <stopper counter>, "final": None}`.
        """

        self._total_timesteps = 0
        self._total_episodes = 0

        self._sampler = get_sampler(self._agents, sampler_config)
        self._stopper = get_stopper(stop_conditions)

        self._stopper.reset()
        # evaluate once before any training happens
        self._evaluate_and_log()

        if self._share_critic:
            for group, trainer in self._trainer.items():
                # filter policies with group
                policies = {
                    agent: self._policies[agent] for agent in self._groups[group]
                }
                trainer.reset(policies, None)
        else:
            for trainer in self._trainer.values():
                trainer.reset(self._policies, None)

        while not self._stopper.is_terminal():
            epoch_training_statistics = self.train()
            for x in epoch_training_statistics:
                write_to_tensorboard(
                    self.summary_writer,
                    x,
                    global_step=self._total_timesteps,
                    prefix="training",
                )
                self._stopper.step(
                    None,
                    x,
                    time_step=self._total_timesteps,
                    episode_th=self._total_episodes,
                )

            self._evaluate_and_log()

            if self._learning_mode == LearningMode.ON_POLICY:
                self._sampler.clean()

        return {"global_step": self._stopper.counter, "final": None}

    def save(self, data_dir: str = None):
        """Save every policy under `<models dir>/<agent>_<stopper counter>`.

        Args:
            data_dir (str, optional): Root directory; models go to `<data_dir>/models`, which is
                created when missing. Defaults to None, which uses the experiment's `models` path.
        """
        if data_dir is None:
            model_data_dir = self._exp_config.get_path("models")
        else:
            model_data_dir = os.path.join(data_dir, "models")
            os.makedirs(model_data_dir, exist_ok=True)

        # use the stopper created in `learn`, consistent with the rest of this class
        counter = self._stopper.counter
        for _aid, policy in self._policies.items():
            path = os.path.join(model_data_dir, f"{_aid}_{counter}")
            policy.save(path, global_step=counter)

    def load(self, data_dir: str = None, global_step=0):
        """Load every policy from `<models dir>/<agent>_<global_step>`.

        Args:
            data_dir (str, optional): Root directory; models are read from `<data_dir>/models`.
                Defaults to None, which uses the experiment's `models` path.
            global_step (int, optional): Step suffix of the files to load. Defaults to 0.
        """
        if data_dir is None:
            model_data_dir = self._exp_config.get_path("models")
        else:
            # honor the caller-provided directory instead of overriding it with
            # the experiment path afterwards
            model_data_dir = os.path.join(data_dir, "models")
            if not os.path.exists(model_data_dir):
                os.makedirs(model_data_dir)

        for _aid, policy in self._policies.items():
            path = os.path.join(model_data_dir, f"{_aid}_{global_step}")
            policy.load(path)

        self._global_iteration = global_step
