"""
This module implements centralized learning in the style of MADDPG and QMIX. With this learner, single-agent algorithms can easily be framed as multi-agent centralized algorithms.
"""

import time
import os


from expground.common.policy_pool import PolicyPool


from expground import optimizers, settings
from expground.logger import Log
from expground.types import (
    List,
    Dict,
    LearningMode,
    PolicyConfig,
    PolicyID,
    RolloutConfig,
    TrainingConfig,
    Sequence,
    AgentID,
    Any,
    Union,
)
from expground.utils.sampler import get_sampler
from expground.utils.stoppers import get_stopper

from expground.utils.logging import write_to_tensorboard

from expground.algorithms.base_policy import Policy
from expground.learner.base_learner import Learner
from expground.envs.agent_interface import (
    AgentInterface,
    DEFAULT_ACTION_ADAPTER,
    DEFAULT_OBSERVATION_ADAPTER,
)

from .utils import _group_validation
from .normal_critic import CentralizedTrainer as NormalCentralizedTrainer

# from .logits_critic import CentralizedTrainer as LogitCentralizedTrainer


def POLICY_NAME(x):
    """Return the policy name derived from ``x``, i.e. ``"<x>_policy"``."""
    return f"{x}_policy"


class CentralizedLearner(Learner):
    """Learner that trains agent policies with centralized critics.

    Each agent gets either a plain policy instance or a ``PolicyPool`` (when
    ``enable_policy_pool`` is set). Centralized trainers are built one per
    agent group when ``custom_config["share_critic"]`` is true, otherwise one
    per trainable agent.
    """

    def __init__(
        self,
        policy_config: PolicyConfig,
        env_description: Dict[str, Any],
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = True,
        train_every: int = 1,
        ego_agents: Sequence[AgentID] = None,
        experiment: str = None,
        seed: int = None,
        # =============================
        ray_mode: bool = False,
        enable_policy_pool: bool = False,
        summary_writer: object = None,
        turn_off_logging: bool = False,
        policy_pool_config: Dict[str, Any] = None,
        resource_config: Dict[str, Any] = None,
        custom_config: Dict[str, Any] = None,
    ):
        """Build policies, centralized trainers and agent interfaces.

        Args:
            policy_config: Config used to create each agent's policy
                (``new_policy_instance``) or policy pool (``copy``).
            env_description: Environment description. Its ``"config"`` entry
                must provide ``possible_agents``, ``observation_spaces`` and
                ``action_spaces``, and may provide ``observation_adapter`` /
                ``action_adapter``.
            rollout_config: Provides ``caller``, ``fragment_length``,
                ``max_step`` and ``num_simulation``.
            training_config: Its ``hyper_params`` are passed to every trainer
                and must contain ``"tau"`` for target updates.
            loss_func: Loss class handed to each centralized trainer.
            learning_mode: Compared with ``LearningMode.ON_POLICY`` in
                ``learn`` to decide whether the sampler is cleaned per epoch.
            episodic_training: Forwarded to the rollout caller. When true,
                ``train`` also runs a final training pass after the rollout.
            train_every: Forwarded to the rollout caller.
            ego_agents: Agents whose interfaces are active. Defaults to all
                possible agents.
            experiment: Experiment name. Defaults to
                ``"Centralized_<policy>_<timestamp>"``.
            seed: Forwarded to the base learner.
            ray_mode: Forwarded to the base learner.
            enable_policy_pool: Wrap each agent's policy in a ``PolicyPool``.
            summary_writer: Forwarded to the base learner.
            turn_off_logging: Stored as ``self._turn_off_logging``.
            policy_pool_config: Optional. Keys read: ``start_fixed_support_num``,
                ``start_active_support_num``, ``use_learnable_dist``,
                ``mixed_at_every_step`` and ``distribution_training_kwargs``.
            resource_config: Forwarded to the base learner.
            custom_config: Optional. Keys read: ``groups``, ``share_critic``
                and ``centralized_critic_config``. The last one is required
                whenever at least one trainer is built.
        """
        super().__init__(
            experiment=experiment
            or f"Centralized_{policy_config.policy}_{time.time()}",
            summary_writer=summary_writer,
            seed=seed,
            ray_mode=ray_mode,
            enable_policy_pool=enable_policy_pool,
            resource_config=resource_config,
            policy_pool_config=policy_pool_config,
        )

        # Work on a local copy so `None` is accepted and the caller's dict is not mutated.
        custom_config = dict(custom_config or {})
        pool_config = policy_pool_config or {}

        possible_agents = env_description["config"]["possible_agents"]
        # Agents in the same group share a critic if `share_critic` is True.
        # By default, all agents form a single group.
        if custom_config.get("groups", None) is None:
            custom_config["groups"] = {"default": possible_agents}
        # Check whether the grouping is legal.
        _group_validation(custom_config["groups"], possible_agents)
        groups = custom_config["groups"]

        self._env_desc = env_description
        self._policy_config = policy_config
        self._agents = possible_agents
        self._rollout_config = rollout_config
        self._share_critic = custom_config.get("share_critic", False)
        self._groups = groups

        self._sampler = None
        self._stopper = None
        self._episodic_training = episodic_training
        self._train_every = train_every
        self._learning_mode = learning_mode
        self._agent_interfaces = {}

        self._total_timesteps = 0
        self._total_episodes = 0

        self._turn_off_logging = turn_off_logging

        self._policies = {}
        self._ego_agents = ego_agents or possible_agents

        for agent in self._agents:
            if not enable_policy_pool:
                self._policies[agent] = policy_config.new_policy_instance(agent)
                continue
            # Use the resolved ego list: the raw `ego_agents` argument may be None.
            is_ego = agent in self._ego_agents
            # Ego agents take their support-number configuration from the pool config.
            self._policies[agent] = PolicyPool(
                agent,
                policy_config=policy_config.copy(agent),
                start_fixed_support_num=pool_config.get("start_fixed_support_num", 0)
                if is_ego
                else 0,
                start_active_support_num=pool_config.get(
                    "start_active_support_num", 0
                )
                if is_ego
                else 0,
                # Non-ego agents can be set to use a learnable distribution or not.
                is_fixed=True
                if is_ego
                else not pool_config.get("use_learnable_dist", False),
                mixed_at_every_step=pool_config.get("mixed_at_every_step", False),
                # If `is_fixed` is False, distribution_training_kwargs cannot be None.
                distribution_training_kwargs=pool_config.get(
                    "distribution_training_kwargs", {}
                ),
            )

        # `_use_learnable_dist` is not assigned in this class. Honour a value
        # provided by the base learner if present; otherwise derive it from the
        # policy-pool config.
        use_learnable_dist = getattr(self, "_use_learnable_dist", None)
        if use_learnable_dist is None:
            use_learnable_dist = bool(
                enable_policy_pool and pool_config.get("use_learnable_dist", False)
            )
            self._use_learnable_dist = use_learnable_dist

        self._trainer = {}
        self._training_config = training_config

        # Only the "normal" centralized trainer is available. The logits-critic
        # variant (`centralized_trainer_type == "logits"`) is currently disabled.
        trainer_cls = NormalCentralizedTrainer
        if self._share_critic:
            Log.debug("Construct centralized critic in sharing mode.")
            for g, agents in groups.items():
                selected_policies = {aid: self._policies[aid] for aid in agents}
                self._trainer[g] = trainer_cls(
                    loss_func,
                    training_config.hyper_params,
                    selected_policies,
                    model_config=custom_config["centralized_critic_config"],
                    ego_agents=agents,
                    mode="minmax",
                )
        else:
            Log.debug("Construct centralized critic in decentralized way.")
            for g, agents in groups.items():
                selected_policies: Dict[AgentID, Union[Policy, PolicyPool]] = {
                    aid: self._policies[aid] for aid in agents
                }
                # Build one trainer per agent. Non-ego agents are trained only
                # when they use a learnable distribution.
                for agent in agents:
                    if agent not in self._ego_agents and not use_learnable_dist:
                        continue
                    self._trainer[agent] = trainer_cls(
                        loss_func,
                        training_config.hyper_params,
                        selected_policies,
                        model_config=custom_config["centralized_critic_config"],
                        ego_agents=agent,
                    )

        env_config = self._env_desc["config"]
        observation_spaces = env_config["observation_spaces"]
        action_spaces = env_config["action_spaces"]
        observation_adapter = env_config.get(
            "observation_adapter", DEFAULT_OBSERVATION_ADAPTER
        )
        action_adapter = env_config.get("action_adapter", DEFAULT_ACTION_ADAPTER)
        for _aid in env_config["possible_agents"]:
            self._agent_interfaces[_aid] = AgentInterface(
                policy_name="",
                policy=self._policies[_aid],
                observation_space=observation_spaces[_aid],
                action_space=action_spaces[_aid],
                observation_adapter=observation_adapter,
                action_adapter=action_adapter,
                is_active=_aid in self._ego_agents,
            )

    def get_dist(self):
        """Return ``{agent_id: pool._distribution.dict_values()}`` for every non-ego agent.

        Raises:
            AssertionError: If a non-ego agent's policy is not a ``PolicyPool``.
        """
        res = {}
        for aid in self._agents:
            if aid not in self._ego_agents:
                policy = self._policies[aid]
                assert isinstance(policy, PolicyPool)
                res[aid] = policy._distribution.dict_values()
        return res

    def train(self, **kwargs) -> List[Dict[str, Dict[str, Any]]]:
        """Run one epoch of rollout and training.

        While the rollout generator yields, every trainer is called with
        ``stage="critic"`` and then ``stage="actor"`` whenever the sampler is
        ready. All policy targets are then updated with
        ``hyper_params["tau"]``. In episodic mode, one more training pass runs
        after the rollout finishes.

        Returns:
            List[Dict[str, Dict[str, Any]]]: One entry per training step,
            mapping trainer id to its training feedback.
        """
        generator = self._rollout_config.caller(
            sampler=self._sampler,
            agent_policy_mapping={},
            agent_interfaces=self._agent_interfaces,
            env_description=self._env_desc,
            fragment_length=self._rollout_config.fragment_length,
            max_step=self._rollout_config.max_step,
            episodic=self._episodic_training,
            train_every=self._train_every,
            evaluate=False,
        )

        epoch_training_statistics = []
        start_timesteps = self._total_timesteps
        while True:
            # Only `next` is guarded, so a StopIteration escaping a trainer is
            # not mistaken for the end of the rollout.
            try:
                info = next(generator)
            except StopIteration as stop:
                # The rollout generator's return value carries the epoch summary.
                summary = stop.value
                break
            if not self._sampler.is_ready():
                continue

            # Update all critics first, then all actors.
            tmp = {}
            for tid, _trainer in self._trainer.items():
                tmp[tid] = _trainer(
                    self._sampler,
                    time_step=start_timesteps,
                    stage="critic",
                    n_inner_loop=1,
                )
            for tid, _trainer in self._trainer.items():
                tmp[tid].update(
                    _trainer(
                        self._sampler,
                        time_step=start_timesteps,
                        stage="actor",
                        n_inner_loop=1,
                    )
                )
            # NOTE(review): timesteps are only accumulated on steps where
            # training happened; confirm this is the intended time_step.
            start_timesteps += info["timesteps"]
            epoch_training_statistics.append(tmp)
            # Update all policy targets after training.
            for policy in self._policies.values():
                policy.update_target(tau=self._training_config.hyper_params["tau"])

        self._total_timesteps += summary["total_timesteps"]
        self._total_episodes += summary["num_episode"]

        if self._sampler.is_ready() and self._episodic_training:
            epoch_training_statistics = [
                {
                    _id: _trainer(self._sampler, time_step=self._total_timesteps)
                    for _id, _trainer in self._trainer.items()
                }
            ]
            # Update all policy targets after training.
            for policy in self._policies.values():
                policy.update_target(tau=self._training_config.hyper_params["tau"])

        return epoch_training_statistics

    def remove_dominated_policies(self, pids_dict: Dict[AgentID, List[PolicyID]]):
        """Remove the given policy ids from the pools of ego agents.

        Entries for non-ego agents are ignored.

        Args:
            pids_dict: Mapping from agent id to the policy ids to remove.

        Raises:
            AssertionError: If an ego agent's policy is not a ``PolicyPool``.
        """
        # NOTE(review): `return_dominated_policies` reports ids for *non-ego*
        # agents, while this method only acts on *ego* agents. Confirm that
        # this asymmetry is intended.
        for aid, pids in pids_dict.items():
            if aid not in self._ego_agents:
                continue
            policy = self._policies[aid]
            assert isinstance(policy, PolicyPool)
            policy.remove_dominated_policies(pids)

    def return_dominated_policies(self) -> Dict[AgentID, List[PolicyID]]:
        """Collect dominated policy ids from the pool of every non-ego agent.

        Returns:
            Dict[AgentID, List[PolicyID]]: Dominated policy ids per non-ego agent.

        Raises:
            AssertionError: If a non-ego agent's policy is not a ``PolicyPool``.
        """
        policy_ids: Dict[AgentID, List[PolicyID]] = {}
        for agent in self._agents:
            if agent in self._ego_agents:
                continue
            policy = self._policies[agent]
            assert isinstance(policy, PolicyPool)
            policy_ids[agent] = policy.return_dominated_policies()
        return policy_ids

    def _evaluate_and_log(self) -> Dict[str, Any]:
        """Run one evaluation, write it to tensorboard under ``"evaluation"`` and return it."""
        # The fragment covers min(20, num_simulation) * max_step steps.
        # `[0][1]` selects the statistics from the evaluation result; the
        # result's layout is defined by the base learner.
        statistics = self.evaluation(
            policy_mappings=None,
            max_step=self._rollout_config.max_step,
            fragment_length=min(20, self._rollout_config.num_simulation)
            * self._rollout_config.max_step,
        )[0][1]
        write_to_tensorboard(
            self.summary_writer,
            statistics,
            global_step=self._total_timesteps,
            prefix="evaluation",
        )
        return statistics

    def learn(self, sampler_config: Dict[str, Any], stop_conditions: Dict[str, Any]):
        """Run the main loop of centralized learning, alternating training and evaluation.

        Args:
            sampler_config (Dict[str, Any]): The configuration to build a sampler.
            stop_conditions (Dict[str, Any]): The dict of configuration to build a stopper.

        Returns:
            ``{"FinalWinRate": ...}`` or ``{"FinalReward": ...}`` taken from the
            last evaluation, or ``None`` if neither key is reported.
        """
        self.reset()

        self._sampler = get_sampler(self._agents, sampler_config)
        self._stopper = get_stopper(stop_conditions)
        self._stopper.reset()

        for trainer in self._trainer.values():
            trainer.reset(None)

        epoch_evaluation_statistic = self._evaluate_and_log()

        while not self._stopper.is_terminal():
            for x in self.train():
                write_to_tensorboard(
                    self.summary_writer,
                    x,
                    global_step=self._total_timesteps,
                    prefix="training",
                )
                self._stopper.step(
                    None,
                    x,
                    time_step=self._total_timesteps,
                    episode_th=self._total_episodes,
                )

            epoch_evaluation_statistic = self._evaluate_and_log()

            # On-policy data must not be reused across epochs.
            if self._learning_mode == LearningMode.ON_POLICY:
                self._sampler.clean()

        if "win_rate" in epoch_evaluation_statistic:
            return {"FinalWinRate": epoch_evaluation_statistic["win_rate"]}
        if "reward" in epoch_evaluation_statistic:
            return {"FinalReward": epoch_evaluation_statistic["reward"]}
        return None

    def save(self, data_dir: str = None):
        """Save every agent's policy to ``<data_dir>/<MODEL_SUB_DIR_NAME>/<agent>_<iteration>``.

        Also makes sure ``<data_dir>/<DATA_SUB_DIR_NAME>`` exists.

        Args:
            data_dir: Root directory. Defaults to ``self.state_dir``.
        """
        if self._stopper is not None:
            self._global_iteration = self._stopper.counter
        else:
            # `learn` has not built a stopper yet; keep the last known iteration.
            self._global_iteration = getattr(self, "_global_iteration", 0)
        data_dir = data_dir or self.state_dir
        os.makedirs(os.path.join(data_dir, settings.DATA_SUB_DIR_NAME), exist_ok=True)

        model_data_dir = os.path.join(data_dir, settings.MODEL_SUB_DIR_NAME)
        os.makedirs(model_data_dir, exist_ok=True)
        for _aid, policy in self._policies.items():
            path = os.path.join(model_data_dir, f"{_aid}_{self._global_iteration}")
            policy.save(path, global_step=self._global_iteration)

    def load(self, data_dir: str, global_step=0):
        """Load every agent's policy from ``<data_dir>/<MODEL_SUB_DIR_NAME>/<agent>_<global_step>``.

        Only policies are restored; no dataset is loaded.

        Args:
            data_dir: Root directory used by ``save``.
            global_step: Iteration suffix of the checkpoint to load.
        """
        model_data_dir = os.path.join(data_dir, settings.MODEL_SUB_DIR_NAME)
        for _aid, policy in self._policies.items():
            path = os.path.join(model_data_dir, f"{_aid}_{global_step}")
            policy.load(path)

        self._global_iteration = global_step


# Public API of this module: only the learner class is exported.
__all__ = ["CentralizedLearner"]
