import copy
import torch
import torch.nn.functional as F

from torch import optim
from torch.distributions import Categorical
from gym import spaces

from expground.types import Sequence, AgentID, Dict, Any, Union, List
from expground.logger import Log
from expground.utils.data import EpisodeKeys, default_dtype_mapping
from expground.utils.preprocessor import get_preprocessor
from expground.utils import data as data_utils
from expground.utils.sampler import SamplerInterface

from expground.algorithms import misc
from expground.algorithms.ddpg.config import CENTRALIZED_CRITIC_NETWORK
from expground.algorithms.loss_func import LossFunc
from expground.algorithms.base_policy import Policy
from expground.algorithms.base_trainer import Trainer

from expground.common.policy_pool import PolicyPool
from expground.common.models import get_model


def _group_validation(
    groups: Dict[str, List[AgentID]], possible_agents: List[AgentID]
) -> bool:
    """Validate that the groups form an exact partition of the possible agents.

    The check passes only when every possible agent is assigned to exactly one
    group and no group contains an agent outside of `possible_agents`.

    Args:
        groups (Dict[str, List[AgentID]]): A dict of groups, mapping from group id to its member agents.
        possible_agents (List[AgentID]): A list of agents that must be covered.

    Raises:
        AssertionError: If an agent is missing, unknown, or assigned more than once.

    Returns:
        bool: True for validation pass (a failure raises instead of returning False).
    """

    assigned = [agent for members in groups.values() for agent in members]
    assigned_set = set(assigned)

    is_partition = (
        # same amount of agents on both sides (the historical check)
        len(assigned) == len(possible_agents)
        # no agent shows up twice, neither inside one group nor across groups
        and len(assigned_set) == len(assigned)
        # exactly the expected agents: nothing lacking, nothing unknown
        and assigned_set == set(possible_agents)
    )
    # Raise explicitly instead of using a bare `assert`, so that the validation
    # is not stripped when Python runs with -O. The exception type is unchanged.
    if not is_partition:
        raise AssertionError(
            f"Agents should be no overlapped and no lacking: {groups} {possible_agents}"
        )
    return True


class ppool_loss(LossFunc):
    """Loss function that updates the distribution tensor of a policy with SGD.

    It computes no loss by itself: calling it returns an empty dict. Its job is
    to own an SGD optimizer over `policy._distribution.tensor` and to apply it
    when the training stage is "actor".
    """

    def __init__(self, mute_critic_loss: bool = True):
        super().__init__(mute_critic_loss=mute_critic_loss)
        # Assigned from outside (see `set_training_stage` in `_loss_wrap`);
        # only the value "actor" enables the optimizer update in `step`.
        self.training_stage = None

    def __call__(self, batch) -> Dict[str, Any]:
        """Ignore the batch and report no loss statistics.

        Args:
            batch: The training batch, unused here.

        Returns:
            Dict[str, Any]: Always an empty dict.
        """

        return {}

    def zero_grad(self):
        """Rebuild the optimizer, then clear the gradients it tracks."""

        # the optimizer is always re-created before the gradients are zeroed
        self.setup_optimizers()
        self.optimizers.zero_grad()

    def step(self) -> Any:
        """Apply one optimizer update in the actor stage, then sync the logits.

        Raises:
            AssertionError: If the stage is "actor" but the distribution tensor has no gradient.
        """

        distribution = self.policy._distribution
        if self.training_stage == "actor":
            weights = distribution.tensor
            assert weights.grad is not None, (
                weights.requires_grad,
                weights.grad,
            )
            self.optimizers.step()
        # the logits are synchronized in every stage, not only the actor one
        distribution.sync_logits()

    def setup_optimizers(self, *args, **kwargs):
        """Create a fresh SGD optimizer over the policy's distribution tensor.

        Positional and keyword arguments are accepted but not used.
        """

        # The learning rate is hard-coded to 0.05.
        # Legacy tuning note kept from the original code: "0.001: top right, 0.01: others".
        params = [self.policy._distribution.tensor]
        self.optimizers = optim.SGD(params, lr=0.05)

    def reset(self, policy: Policy, configs: Dict[str, Any]):
        """Update the parameters, optionally rebind the policy and rebuild the optimizer.

        Args:
            policy (Policy): The policy to bind, the current one is kept when it is None.
            configs (Dict[str, Any]): Entries merged into the internal parameters, None is treated as empty.
        """

        self._params.update(configs or {})
        if policy is not None:
            self._policy = policy
        # drop the previous optimizer and the recorded losses before rebuilding
        self.optimizers = None
        self.loss = []
        self.setup_optimizers()
        self.setup_extras()


def _loss_wrap(
    policies: Dict[AgentID, Policy],
    training_config: Dict[str, Any],
    loss_func_cls: type,
) -> type:
    """Bundle one loss function per policy behind a single loss-function object.

    Args:
        policies (Dict[AgentID, Policy]): The dict of ego agents' policies, mapping from agent to policies.
        training_config (Dict[str, Any]): The dict of training configuration. It is not read by this
            function, the wrapped loss functions receive their configs through `reset`.
        loss_func_cls (type): The class used to create a loss function for every policy that is not
            a non-fixed `PolicyPool`.

    Returns:
        type: A loss function instance (not a class) which accepts a dict of agent batches.
    """

    # One loss function per agent, each created with the critic loss muted.
    # A policy pool which is not fixed is handled by `ppool_loss`, everything
    # else by the given loss function class.
    _loss_funcs = {}
    for agent_id, agent_policy in policies.items():
        is_trainable_pool = (
            isinstance(agent_policy, PolicyPool) and not agent_policy.is_fixed
        )
        selected_cls = ppool_loss if is_trainable_pool else loss_func_cls
        _loss_funcs[agent_id] = selected_cls(mute_critic_loss=True)

    class _loss(LossFunc):
        """Fan every loss function call out to the per-agent loss functions."""

        def __init__(self):
            # NOTE(review): an earlier comment here stated that the critic loss
            # mute flag must always be False, while the code passes True. The
            # code behavior is kept, confirm which one is intended.
            super().__init__(mute_critic_loss=True)

        def set_training_stage(self, stage: str):
            """Write the training stage to every wrapped loss function.

            Args:
                stage (str): The stage name to assign.
            """

            for wrapped in _loss_funcs.values():
                wrapped.training_stage = stage

        def __call__(self, batch: Dict[str, Dict]) -> Dict[str, Any]:
            """Compute the loss of every policy on its own batch.

            Args:
                batch (Dict[str, Dict]): A dict of batch, mapping from policy ids to corresponding batches.

            Returns:
                Dict[str, Any]: A dict of loss statistic, mapping from policy ids to corresponding loss
                    feedback. Policies with empty feedback are left out.
            """

            res = {}
            for k in policies:
                # here, batch[k] could be an agent dict or pure dict
                feedback = _loss_funcs[k](batch[k])
                if len(feedback) > 0:
                    res[k] = feedback
            return res

        def step(self) -> Any:
            """Call `step` on every wrapped loss function."""

            for wrapped in _loss_funcs.values():
                wrapped.step()

        def zero_grad(self):
            """Call `zero_grad` on every wrapped loss function."""

            for wrapped in _loss_funcs.values():
                wrapped.zero_grad()

        def setup_optimizers(self, *args, **kwargs):
            """Forward the optimizer setup, with all arguments, to every wrapped loss function."""

            for wrapped in _loss_funcs.values():
                wrapped.setup_optimizers(*args, **kwargs)

        def reset(self, policies: Dict[AgentID, Policy], configs: Dict[str, Any]):
            """Reset every wrapped loss function with the policy it should train.

            Args:
                policies (Dict[AgentID, Policy]): A dict of policies, it must contain every wrapped agent.
                configs (Dict[str, Any]): The configs passed to each wrapped loss function.
            """

            for k, wrapped in _loss_funcs.items():
                policy = policies[k]

                # plain policies and non-fixed policy pools are reset directly
                if not isinstance(policy, PolicyPool) or not policy.is_fixed:
                    wrapped.reset(policy, configs)
                    continue

                # a fixed pool is trained through its first active policy;
                # without any active policy only a warning is logged
                active_policies = policy.get_active_policies()
                if len(active_policies) < 1:
                    Log.warning("No active policies found!")
                    continue
                first_active = next(iter(active_policies.values()))
                wrapped.reset(first_active, configs)

    # the returned object is not reset here, the caller has to call `reset`
    return _loss()
