import copy
import torch
from torch._C import Value
import torch.nn.functional as F

from torch import optim
from torch.distributions import Categorical
from gym import spaces

from expground.types import Sequence, AgentID, Dict, Any, Union, List
from expground.utils.data import EpisodeKeys, default_dtype_mapping
from expground.utils.preprocessor import get_preprocessor
from expground.utils import data as data_utils
from expground.utils.sampler import SamplerInterface

from expground.algorithms import misc
from expground.algorithms.ddpg.config import CENTRALIZED_CRITIC_NETWORK
from expground.algorithms.loss_func import LossFunc
from expground.algorithms.base_policy import Policy
from expground.algorithms.base_trainer import Trainer

from expground.common.policy_pool import PolicyPool
from expground.common.models import get_model

from .utils import _loss_wrap, ppool_loss


class CentralizedTrainer(Trainer):
    def __init__(
        self,
        loss_func_cls: type,
        training_config: Dict[str, Any],
        policies: Dict[AgentID, Union[Policy, PolicyPool]],
        model_config: Dict[str, Any] = None,
        use_global_state: bool = False,
        ego_agents: Union[AgentID, Sequence[AgentID]] = None,
        mode: str = "independent",  # accepted here: independent, minmax
    ):
        """Initialize a centralized trainer.

        Note:
            Support only homogeneous loss function class for now.

        Args:
            loss_func_cls (type): Loss function class for individual policies. Current implementation shares the loss type.
            training_config (Dict[str, Any]): A dict of training configuration.
            policies (Dict[AgentID, Union[Policy, PolicyPool]]): A dict of policies or policy pool, mapping from agent to policies/policypools.
            model_config (Dict[str, Any], optional): Model configuration for centralized critic, shared by all policies. Defaults to None, in which case `CENTRALIZED_CRITIC_NETWORK` is used.
            use_global_state (bool, optional): Determine whether there is a global state specified in current case. Defaults to False.
            ego_agents (Union[AgentID, Sequence[AgentID]], optional): Specify ego agents. Defaults to None, which selects every agent in `policies`.
            mode (str, optional): Training mode. `independent` requires exactly one ego agent, `minmax` requires exactly two. Defaults to "independent".

        Raises:
            ValueError: If `mode` is neither `independent` nor `minmax`.
            AssertionError: If the number of ego agents does not match `mode`, if an ego agent is not a key of `policies`, or if `use_global_state` is set while an agent observation space is not a `gym.spaces.Dict` containing a `global_state` entry.
        """

        # Per-agent spaces; they are combined below into the input space of the
        # centralized critic.
        observation_spaces = {
            aid: policy._observation_space for aid, policy in policies.items()
        }
        action_spaces = {aid: policy._action_space for aid, policy in policies.items()}
        self._model_config = model_config or CENTRALIZED_CRITIC_NETWORK

        self._involved_agents = list(observation_spaces.keys())
        self._use_global_state = use_global_state
        self._mode = mode

        # Normalize `ego_agents` BEFORE validating it: the documented default
        # (None) means "all involved agents" and must not reach `len()` below.
        if ego_agents is None:
            ego_agents = self._involved_agents
        elif isinstance(ego_agents, AgentID):
            ego_agents = [ego_agents]

        if mode == "minmax":
            assert len(ego_agents) == 2, ego_agents
        elif mode == "independent":
            assert (
                len(ego_agents) == 1
            ), "Independent mode support only one ego_agent for each trainer, while ego_agents={}".format(
                ego_agents
            )
        else:
            raise ValueError(
                "Unsupported mode `{}`: only `independent` and `minmax` are accepted".format(
                    mode
                )
            )

        assert all(
            item in self._involved_agents for item in ego_agents
        ), "Ego agents ({}) should be a subset of governed agents ({})".format(
            ego_agents,
            self._involved_agents,
        )

        self._ego_agents = ego_agents or self._involved_agents

        # collect loss functions of all ego agents, and wrap them as a whole
        self.agent_loss_func: LossFunc = _loss_wrap(
            {aid: policies[aid] for aid in self._ego_agents},
            training_config,
            loss_func_cls,
        )

        if use_global_state:
            # every agent observation space must expose a `global_state` entry
            for agent, obs_space in observation_spaces.items():
                assert isinstance(
                    obs_space, spaces.Dict
                ), "Agent observation space should be `gym.spaces.Dict` when `global_state` mode is ON, while {} received for agent={}".format(
                    type(obs_space), agent
                )
                assert (
                    "global_state" in obs_space.spaces
                ), "`global_state` mode is on, but no key `global_state` can be found in the agent observation space keys."
            # Read the same key that was just validated; the global state space
            # is taken from the first agent.
            global_state_space = copy.deepcopy(
                list(observation_spaces.values())[0].spaces["global_state"]
            )
        else:
            # without a global state, the critic observes every agent's observation
            global_state_space = spaces.Dict(**observation_spaces)

        # Critic input space: (global) observation plus the actions of all agents.
        self._observation_space = spaces.Dict(
            {
                "obs": global_state_space,
                "action": spaces.Dict(**action_spaces),
            }
        )

        self._state_preprocessor = get_preprocessor(global_state_space)(
            global_state_space
        )
        self._obs_preprocessor = get_preprocessor(self._observation_space)(
            self._observation_space
        )
        # `Discrete(1)` is handed to the critic model as its action space
        # (presumably yielding a single scalar output -- TODO confirm against get_model).
        self._action_space = spaces.Discrete(1)
        # Critic, target critic and optimizer are built lazily in `reset`.
        self._centralized_critic = None
        self._target_centralized_critic = None
        self._critic_optimizer = None
        self._training_stage = None

        super(CentralizedTrainer, self).__init__(
            self.agent_loss_func,
            training_config=training_config,
            policy_instance=policies,
        )

        # adv_eps: for minmax optimization; a positive value turns on the
        # adversarial perturbation of non-ego agents' actions.
        self._adversary = self._training_config.get("adv_eps", 0.0) > 0.0

    def _build_centralized_critics_and_optimizer(self):
        """Create the centralized critic, its target copy and the critic optimizer.

        Both networks are built from `self._model_config` on top of the
        critic observation/action spaces, with CUDA disabled. The target
        network starts from the critic's weights. The optimizer class is
        looked up in `torch.optim` by the `optimizer` entry of the training
        config and uses `critic_lr` as learning rate.
        """

        def _new_critic():
            # one fresh network instance per call; both critics share the config
            model_cls = get_model(self._model_config, framework="torch")
            return model_cls(
                self._observation_space, self._action_space, use_cuda=False
            )

        self._centralized_critic = _new_critic()
        self._target_centralized_critic = _new_critic()

        # start the target network from the online critic's parameters
        online_weights = self._centralized_critic.state_dict()
        self._target_centralized_critic.load_state_dict(online_weights)

        optimizer_cls = getattr(torch.optim, self._training_config["optimizer"])
        self._critic_optimizer: torch.optim.Optimizer = optimizer_cls(
            self._centralized_critic.parameters(),
            lr=self._training_config["critic_lr"],
        )

    @property
    def involved_agents(self) -> List[AgentID]:
        """Agent ids governed by this trainer.

        The list is built at construction time from the keys of the
        `policies` dict and is returned as-is (not copied).

        Returns:
            List[AgentID]: A list of agent ids.
        """

        return self._involved_agents

    def _compute_policy_loss(
        self,
        agent_obs: Dict[AgentID, torch.Tensor],
        critic_obs: torch.Tensor,
        action_dists: List[torch.Tensor],
        batch: Dict[AgentID, Dict[str, torch.Tensor]],
        adversary: bool = False,
    ) -> torch.Tensor:
        """Compute the policy loss of the ego agents through the centralized critic.

        The loss is the negated mean critic output for the joint input
        `[critic_obs, action_dists...]`, plus a small (0.001) L2-style
        regularizer on the recomputed ego actions.

        Args:
            agent_obs (Dict[AgentID, torch.Tensor]): A dict of agent observation Tensor, indexed by agent id.
            critic_obs (torch.Tensor): Critic observation, concatenated with the action distributions along the last dim.
            action_dists (List[torch.Tensor]): Action distributions, one per involved agent and in the order of `self._involved_agents`. The list is MODIFIED IN PLACE: ego entries are replaced by freshly computed outputs, non-ego entries are detached (or flagged as requiring grad in adversary mode).
            batch (Dict[AgentID, Dict[str, torch.Tensor]]): A dict of agent batches; only the optional action mask is read here.
            adversary (bool, optional): Considering minmax optimization or not. Defaults to False.

        Returns:
            torch.Tensor: Scalar policy loss.

        Raises:
            ValueError: If `self._mode` is neither `independent` nor `cooperative`.
        """

        # compute gradients from value
        for i, aid in enumerate(self._involved_agents):
            if aid in self._ego_agents:
                # replace ego agent action with on-line computing
                # we treat logits as actions here
                # NOTE(review): `actions` is rebound on every ego agent, so after
                # the loop it only holds the LAST ego agent's output.
                actions, action_dists[i], _ = self._policy[aid].compute_actions(
                    agent_obs[aid],
                    use_target=False,
                    action_mask=batch[aid].get(EpisodeKeys.ACTION_MASK.value, None),
                    explore=False,
                )
            else:
                if adversary:
                    # for adversary: gradients w.r.t. the other agents' actions
                    # are needed below to build the perturbation
                    action_dists[i].requires_grad = True
                else:
                    action_dists[i] = action_dists[i].detach()
        critic_in = torch.cat([critic_obs.detach()] + action_dists, dim=-1)
        policy_loss = -self._centralized_critic(critic_in).view(-1).mean()

        if adversary:
            # same perturbation budget (`adv_eps`) for every involved agent
            adv_rate = [
                self._training_config["adv_eps"] for _ in self._involved_agents
            ]  # from outer self._adv_rate = [some adv_eps_s and some adv_eps]
            # gradient of the loss w.r.t. each action distribution; only the
            # entries of non-ego agents are used afterwards
            raw_perturb = torch.autograd.grad(
                policy_loss, action_dists, retain_graph=False, create_graph=False
            )
            # unit-norm (L2, last dim) direction scaled by the budget
            perturb = [F.normalize(e, dim=-1, p=2).detach() for e in raw_perturb]
            perturb = [
                _perturb * _adv_rate for _perturb, _adv_rate in zip(perturb, adv_rate)
            ]
            # replace other agents actions
            new_logits = []
            for i, agent in enumerate(self.involved_agents):
                if agent not in self._ego_agents:
                    # Disabled: re-applying the action mask to perturbed logits.
                    # action_mask = batch[agent].get(EpisodeKeys.ACTION_MASK.value, None)
                    # if action_mask is not None:
                    #     tmp = misc.masked_logits(
                    #         perturb[i] + actions[i], mask=action_mask, normalize=True
                    #     )  # (perturb[i] + actions[i]) * action_mask
                    # els
                    tmp = perturb[i] + action_dists[i]
                    tmp = tmp.detach()
                    new_logits.append(tmp)
                else:
                    # ego entries keep their graph so the loss can reach the policy
                    new_logits.append(action_dists[i])
            # NOTE(review): `critic_obs` is not detached here, unlike the first
            # critic input above -- confirm whether that is intended.
            adv_critic_in = torch.cat([critic_obs] + new_logits, dim=-1)
            adv_value = self._centralized_critic(adv_critic_in).view(-1)
            # the loss evaluated on the perturbed joint action replaces the plain one
            policy_loss = -adv_value.mean()

        if self._mode == "independent":
            action_reg_loss = torch.mean(actions ** 2)
        elif self._mode == "cooperative":
            # NOTE(review): the loop index `i` is unused, so the same `actions`
            # tensor is regularized once per involved agent -- verify the intent.
            action_reg_loss = torch.sum(
                torch.cat(
                    [
                        torch.mean(actions ** 2, dim=-1, keepdim=True)
                        for i in range(len(self._involved_agents))
                    ],
                    dim=-1,
                ),
                dim=-1,
            ).mean()
        else:
            # any other mode (including `minmax`) ends up here
            raise ValueError(
                "Not implemented current training mode: {}".format(self._mode)
            )

        policy_loss += 0.001 * action_reg_loss

        return policy_loss

    def _compute_value_loss(
        self,
        critic_obs,
        actions,
        next_critic_obs,
        next_actions,
        batch,
        adversary: bool = False,
    ):
        """Compute the TD (mean squared error) loss of the centralized critic.

        Args:
            critic_obs (torch.Tensor): Critic observation at the current step.
            actions (List[torch.Tensor]): Actions of all involved agents at the current step; concatenated with `critic_obs` along the last dim.
            next_critic_obs (torch.Tensor): Critic observation at the next step.
            next_actions (List[torch.Tensor]): Next-step actions of all involved agents, in the order of `self.involved_agents`. In adversary mode they must be part of an autograd graph.
            batch (Dict[AgentID, Dict[str, torch.Tensor]]): A dict of agent batches, providing reward and done signals.
            adversary (bool, optional): If True, the next-step actions of non-ego agents are perturbed along the gradient that lowers the target critic's value before bootstrapping. Defaults to False.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: `(value_loss, target_value, value)`. `target_value` is returned un-detached, but it is detached inside the loss.

        Raises:
            NotImplementedError: If `self._mode` is neither `independent` nor `cooperative`.
        """

        gamma = self._training_config["gamma"]

        # bootstrap value of the next step, estimated by the target critic
        next_critic_in = torch.cat([next_critic_obs] + next_actions, dim=-1)
        next_value = self._target_centralized_critic(next_critic_in).view(-1)

        # current estimate of the online critic
        critic_in = torch.cat([critic_obs] + actions, dim=-1)
        value = self._centralized_critic(critic_in).view(-1)

        if adversary:
            adv_eps = self._training_config["adv_eps"]
            pg_loss = -torch.mean(next_value)
            # gradients on the next actions serve as perturbation directions
            raw_perturb = torch.autograd.grad(
                pg_loss, next_actions, retain_graph=False, create_graph=False
            )
            perturb = [
                adv_eps * F.normalize(grad, dim=-1, p=2).detach()
                for grad in raw_perturb
            ]
            # Replace other agents' actions; ego agents keep theirs. Action
            # masks are not re-applied to the perturbed actions.
            new_next_actions = [
                next_actions[i]
                if agent in self._ego_agents
                else next_actions[i] + perturb[i]
                for i, agent in enumerate(self.involved_agents)
            ]

            # The perturbed actions belong to the NEXT step, so they must be
            # paired with the next critic observation (not the current one).
            adv_critic_in = torch.cat([next_critic_obs] + new_next_actions, dim=-1)
            next_value = self._target_centralized_critic(adv_critic_in).view(-1)

        if self._mode == "independent":
            # independent mode support only one agent
            ego_agent = self._ego_agents[0]
            reward = batch[ego_agent][EpisodeKeys.REWARD.value]
            done = batch[ego_agent][EpisodeKeys.DONE.value]
            assert reward.shape == next_value.shape == done.shape, (
                reward.shape,
                next_value.shape,
                done.shape,
            )
            target_value = reward + gamma * next_value * (1.0 - done)
        elif self._mode == "cooperative":
            # Column-shape the bootstrap value so it lines up with the
            # column-shaped rewards/dones instead of broadcasting to (N, N).
            next_value_col = next_value.view((-1, 1))
            per_agent_targets = [
                batch[aid][EpisodeKeys.REWARD.value].view((-1, 1))
                + gamma
                * (1.0 - batch[aid][EpisodeKeys.DONE.value].view((-1, 1)))
                * next_value_col
                for aid in self._ego_agents
            ]
            # average the per-agent TD targets into one target per sample
            target_value = torch.mean(torch.cat(per_agent_targets, dim=-1), dim=-1)
        else:
            raise NotImplementedError(
                "Value loss is not implemented for training mode: {}".format(
                    self._mode
                )
            )
        value_loss = F.mse_loss(value, target_value.detach())
        return value_loss, target_value, value

    def __call__(
        self,
        sampler: SamplerInterface,
        stage: str,
        time_step: int,
        agent_filter: Sequence = None,
        n_inner_loop: int = 1,
    ):
        """Run one training call for the given stage.

        The stage is recorded on the trainer and forwarded to the wrapped loss
        function, then the call is delegated to the parent trainer.

        Args:
            sampler (SamplerInterface): Sampler passed through to the parent trainer.
            stage (str): Training stage; must not be None. The hooks of this class react to `critic` and `actor`.
            time_step (int): Passed through to the parent trainer.
            agent_filter (Sequence, optional): Agents to train on. A falsy value selects all involved agents. Defaults to None.
            n_inner_loop (int, optional): Passed through to the parent trainer. Defaults to 1.

        Returns:
            Whatever the parent trainer's `__call__` returns.
        """
        assert stage is not None, "Stage cannot be none"

        self._training_stage = stage
        self.loss_func.set_training_stage(stage)

        # fall back to every governed agent when no filter is given
        if not agent_filter:
            agent_filter = self._involved_agents

        parent_call = super().__call__
        return parent_call(
            sampler,
            time_step,
            agent_filter=agent_filter,
            n_inner_loop=n_inner_loop,
        )

    def _before_loss(self, policy: Dict[AgentID, Policy], batch: Dict[AgentID, Dict]):
        """Steps before loss computation, including centralized critic optimization
        and grads transmission from the centralized critic to independent actors.

        In the `critic` stage the centralized critic is optimized for one step
        (gradient norm clipped by `grad_norm_clipping`). In the `actor` stage the
        policy loss is back-propagated through the critic so that gradients
        accumulate on the ego policies; no optimizer step is taken here.

        Args:
            policy (Dict[AgentID, Policy]): A dict of agent policies. Unused here; `self._policy` is read instead.
            batch (Dict[AgentID, Dict]): A dict of agent batches.

        Returns:
            Tuple[Dict[AgentID, Dict], Dict[str, float]]: A dict mapping every ego agent to the whole tensor-converted multi-agent batch, and a dict of scalar training statistics.

        Raises:
            ValueError: If the current training stage is neither `critic` nor `actor`.
        """

        stage = self._training_stage
        # Fail early with a clear message; otherwise neither branch below would
        # run and the return statement would hit an unbound `training_info`.
        if stage not in ("critic", "actor"):
            raise ValueError(
                "Unknown training stage: {}, expected `critic` or `actor`".format(stage)
            )

        def caster(x):
            # convert numpy dataarray into tensors, tensors pass through untouched
            if isinstance(x, torch.Tensor):
                return x
            return torch.FloatTensor(x.copy()).to(
                device="cpu", dtype=default_dtype_mapping(x.dtype)
            )

        # presumably applies `caster` to every leaf of the nested batch -- see data_utils.walk
        batch = data_utils.walk(caster, batch)

        agent_obs = {
            aid: batch[aid][EpisodeKeys.OBSERVATION.value]
            for aid in self._involved_agents
        }
        action_dists = [
            batch[aid][EpisodeKeys.ACTION_DIST.value] for aid in self.involved_agents
        ]

        # Critic observation of the current step, needed by both stages. The
        # data has been transformed in the rollout stage, so no preprocessing here.
        if self._use_global_state:
            # the global state is read from the first agent's batch
            critic_obs = list(batch.values())[0][EpisodeKeys.GLOBAL_STATE.value]
        else:
            critic_obs = torch.cat(
                [agent_obs[aid] for aid in self.involved_agents], dim=-1
            )

        if stage == "critic":
            agent_next_obs = {
                aid: batch[aid][EpisodeKeys.NEXT_OBSERVATION.value]
                for aid in self._involved_agents
            }
            if self._use_global_state:
                next_critic_obs = list(batch.values())[0][
                    EpisodeKeys.NEXT_GLOBAL_STATE.value
                ]
            else:
                next_critic_obs = torch.cat(
                    [agent_next_obs[aid] for aid in self.involved_agents], dim=-1
                )

            next_action_dists = []
            for aid in self.involved_agents:
                # compute next actions with target policies
                _, next_action_dist, _ = self._policy[aid].compute_actions(
                    agent_next_obs[aid],
                    use_target=True,
                    action_mask=batch[aid].get(
                        EpisodeKeys.NEXT_ACTION_MASK.value, None
                    ),
                    explore=False,
                )
                # the adversarial perturbation differentiates w.r.t. this tensor
                if self._adversary:
                    assert (
                        next_action_dist.requires_grad
                    ), "Adversary mode is on, while next_action_dist.requires_grad={}".format(
                        next_action_dist.requires_grad
                    )
                next_action_dists.append(next_action_dist)

            self._critic_optimizer.zero_grad()
            value_loss, target_value, value = self._compute_value_loss(
                critic_obs,
                action_dists,
                next_critic_obs,
                next_action_dists,
                batch,
                adversary=self._adversary,
            )
            value_loss.backward()
            grad_cliprange = self._training_config["grad_norm_clipping"]
            torch.nn.utils.clip_grad_norm_(
                self._centralized_critic.parameters(), grad_cliprange
            )
            self._critic_optimizer.step()
            training_info = {
                "value_loss": value_loss.detach().item(),
                "target_value_mean": target_value.mean().item(),
                "target_value_max": target_value.max().item(),
                "target_value_min": target_value.min().item(),
                "eval_value_mean": value.mean().item(),
                "eval_value_max": value.max().item(),
                "eval_value_min": value.min().item(),
            }
        else:  # stage == "actor"
            self.agent_loss_func.zero_grad()
            # `action_dists` is modified in place by `_compute_policy_loss`;
            # it is not used again afterwards.
            policy_loss = self._compute_policy_loss(
                agent_obs, critic_obs, action_dists, batch, adversary=self._adversary
            )
            policy_loss.backward()
            training_info = {
                "policy_loss": policy_loss.detach().item(),
            }

        # every ego agent receives the full multi-agent batch
        return {aid: batch for aid in self._ego_agents}, training_info

    def _after_loss(self, policy, step_counter: int):
        """Post-update hook: sync the target critic after a critic-stage update.

        Outside the `critic` stage this is a no-op. Target networks of the
        individual policies are not updated here.

        Args:
            policy: Unused.
            step_counter (int): Unused.
        """
        if self._training_stage != "critic":
            return

        # soft update of the target critic towards the online critic with rate `tau`
        # (exact blending rule lives in misc.soft_update)
        misc.soft_update(
            self._target_centralized_critic,
            self._centralized_critic,
            tau=self._training_config["tau"],
        )

    def reset(self, policy_instance: Policy, configs: Dict = None):
        """Reset the trainer state, required.

        Zeroes the step counter, optionally swaps in a new policy dict, builds
        the centralized critic on first use and resets the wrapped loss function.

        Args:
            policy_instance (Policy): A dict of policies; a falsy value keeps the current one.
            configs (Dict, optional): Training configurations for the loss function. Defaults to None, which falls back to the trainer's training config.
        """

        self._step_counter = 0

        # keep the current policies unless a replacement is supplied
        self._policy = policy_instance or self._policy
        assert isinstance(
            self._policy, Dict
        ), "Centralized trainer requires a dict of policies as a whole policy"

        # the critic (with target and optimizer) is created once and then reused
        critic_not_built = self._centralized_critic is None
        if critic_not_built:
            self._build_centralized_critics_and_optimizer()

        loss_configs = configs or self._training_config
        self.agent_loss_func.reset(self._policy, loss_configs)
