import os
import yaml
import torch
import ray
import pickle
import numpy as np
import random
import threading

from collections import defaultdict, namedtuple

import torch.nn.functional as F

from expground.types import (
    PolicyID,
    AgentID,
    PolicyConfig,
    Dict,
    Union,
    Tuple,
    Sequence,
    List,
    Iterator,
    Any,
)
from expground.logger import Log
from expground.utils.format import POLICY_NAME_GEN
from expground.utils.path import load_class_from_str, parse_env_config
from expground.utils.preprocessor import get_preprocessor
from expground.algorithms import misc
from expground.algorithms.base_policy import (
    ActionDist,
    Logits,
    Policy,
    PolicyStatus,
    Action,
)
from expground.algorithms.meta_policy import LearnableDistribution


# Lightweight record pairing a policy id (`pid`) with a `status` value; used as the
# element type of `ActivePolicyBucket`.
ActiveTup = namedtuple("ActiveTup", "pid,status")


class ActivePolicyBucket:
    """Ordered collection of `ActiveTup` entries for one agent.

    The entry at index 0 is the "lowest active" one: `stack` tags the first
    entry pushed into an empty bucket with `PolicyStatus.LOWEST_ACTIVE` and
    every later entry with `PolicyStatus.ACTIVE`; `pop` removes the head and
    promotes its successor to `LOWEST_ACTIVE`. Mutations (`pop`, `stack`) are
    serialized with an internal lock; the read accessors do not take it.
    """

    def __init__(self, agent_id) -> None:
        """Create an empty bucket.

        Args:
            agent_id: Identifier of the owning agent; it is only stored.
        """
        self._active_tups: List[ActiveTup] = []
        self._lock = threading.Lock()
        self._agent_id = agent_id

    def __setstate__(self, d):
        """Restore pickled state and attach a fresh lock."""
        self.__dict__ = d
        self.__dict__["_lock"] = threading.Lock()

    def __getstate__(self):
        """Return the instance dict without `_lock` (lock objects cannot be pickled)."""
        return {k: v for k, v in self.__dict__.items() if k != "_lock"}

    def pop(self) -> ActiveTup:
        """Remove and return the head entry, promoting the next one.

        Returns:
            ActiveTup: The removed head entry (the whole tuple, not only its
                policy id).

        Raises:
            NotImplementedError: If the bucket is empty or the head's status
                compares greater than `PolicyStatus.LOWEST_ACTIVE`.
        """
        with self._lock:
            if (
                not self._active_tups
                or self._active_tups[0].status > PolicyStatus.LOWEST_ACTIVE
            ):
                raise NotImplementedError(
                    "cannot pop: bucket of agent {} is empty or its head is not "
                    "LOWEST_ACTIVE".format(self._agent_id)
                )
            item = self._active_tups[0]
            remaining = self._active_tups[1:]
            if remaining:
                # The successor becomes the new lowest-active entry.
                remaining[0] = ActiveTup(
                    remaining[0].pid, PolicyStatus.LOWEST_ACTIVE
                )
            self._active_tups = remaining
            return item

    def stack(self, policy_id: PolicyID):
        """Append a policy id to the bucket.

        Args:
            policy_id (PolicyID): Policy id to register. It is tagged
                `LOWEST_ACTIVE` when the bucket is empty, `ACTIVE` otherwise.
        """
        with self._lock:
            status = (
                PolicyStatus.ACTIVE
                if self._active_tups
                else PolicyStatus.LOWEST_ACTIVE
            )
            self._active_tups.append(ActiveTup(policy_id, status))

    def lowest_active_tup(self) -> Union[ActiveTup, None]:
        """Return the head entry without removing it.

        Returns:
            Union[ActiveTup, None]: The entry at index 0, or `None` when the
                bucket is empty.
        """
        return self._active_tups[0] if self._active_tups else None

    def active_tups(self) -> List[ActiveTup]:
        """Return every entry except the head.

        Returns:
            List[ActiveTup]: A new list with the entries after index 0; empty
                when the bucket holds fewer than two entries.
        """
        return self._active_tups[1:]

    def __iter__(self) -> Iterator[ActiveTup]:
        """Iterate over all entries, head first."""
        yield from self._active_tups


class PolicyPool(Policy):
    """PolicyPool maintains a pool of policies for one agent together with the distribution used to sample from
    them. The distribution is always a `LearnableDistribution`; `is_fixed` indicates whether it is meant to be trained.
    """

    def __init__(
        self,
        agent_id: AgentID,
        policy_config: PolicyConfig,
        pool_size: int = -1,
        start_fixed_support_num: int = 0,
        start_active_support_num: int = 0,
        is_fixed: bool = True,
        mixed_at_every_step: bool = False,
        distribution_training_kwargs: Dict = None,
        distill_mode: bool = False,
    ):
        """Initialize a policy pool instance.

        Args:
            agent_id (AgentID): The related agent id.
            policy_config (PolicyConfig): The policy configuration used to generate new policies.
            pool_size (int, optional): Capacity value exposed through `pool_size`; it is only stored here.
                Defaults to -1.
            start_fixed_support_num (int, optional): Number of fixed policies created right away via
                `add_policy(fixed=True)`. Defaults to 0.
            start_active_support_num (int, optional): Number of active policies created right away via
                `add_policy(fixed=False)`. Defaults to 0.
            is_fixed (bool, optional): Indicates whether the distribution of this pool is learnable or not.
                Defaults to True, which means not learnable.
            mixed_at_every_step (bool, optional): Indicates whether a policy is sampled at each time step
                instead of once per `reset`. Defaults to False.
            distribution_training_kwargs (Dict, optional): Keyword arguments forwarded to
                `LearnableDistribution`. Required when `is_fixed` is False. Defaults to None.
            distill_mode (bool, optional): Stored as `self.distill_mode`; together with `is_fixed` it
                selects the value-aggregation branch of `compute_action`. Defaults to False.

        Raises:
            AssertionError: If `is_fixed` is False and `distribution_training_kwargs` is None.
        """

        super(PolicyPool, self).__init__(
            policy_config.observation_space,
            policy_config.action_space,
            policy_config.model_config,
            policy_config.custom_config,
            is_fixed,
        )

        self._agent_id = agent_id
        self._start_fixed_support_num = start_fixed_support_num
        self._start_active_support_num = start_active_support_num
        self._pool_size = pool_size
        # counter is used to identify the size of policy pool
        self._counter = 0
        self._pid_suffix = 0
        self._policies = dict()
        # `int` yields the same default (0) as a lambda would, but keeps the mapping
        # picklable with the standard `pickle` module.
        self._active_pids = defaultdict(int)
        self._mixed_at_every_step = mixed_at_every_step
        self._preprocessor = get_preprocessor(
            policy_config.observation_space,
            policy_config.custom_config.get("preprocess_mode", "flatten"),
        )(policy_config.observation_space)

        if not is_fixed:
            assert (
                distribution_training_kwargs is not None
            ), "PolicyPool is not fixed, so distribution training kwargs should be given!"
        distribution_training_kwargs = distribution_training_kwargs or {}
        self._distribution_training_kwargs = distribution_training_kwargs
        # The distribution receives the very same dict object that holds the policies.
        self._distribution = LearnableDistribution(
            self._policies, **distribution_training_kwargs
        )
        self._behavior_policy = None
        self._behavior_pid = None
        self._policy_config: PolicyConfig = policy_config

        # a policy pool is always active
        self.switch_status_to(PolicyStatus.ACTIVE)
        # must exist before `add_policy(fixed=False)` is called below
        self.active_bucket = ActivePolicyBucket(self._agent_id)

        # create supports
        for _ in range(start_fixed_support_num):
            self.add_policy(fixed=True)

        for _ in range(start_active_support_num):
            self.add_policy(fixed=False)

        # Snapshot written by `export` and re-applied attribute by attribute in `load`.
        # NOTE(review): "_learnable_dist" stores `is_fixed`, which reads as the opposite of
        # "learnable" — confirm the intended meaning before relying on it.
        self._state = {
            "_distribution": self._distribution,
            "_policy_config": self._policy_config,
            "_agent_id": self._agent_id,
            "_pool_size": self._pool_size,
            "_learnable_dist": self.is_fixed,
            "_observation_space": self._observation_space,
            "_action_space": self._action_space,
            "_behavior_policy": self._behavior_policy,
        }

        self.distill_mode = distill_mode

    @classmethod
    def as_remote(
        cls,
        num_cpus: int = None,
        num_gpus: int = None,
        memory: int = None,
        object_store_memory: int = None,
        resources: dict = None,
    ) -> type:
        """Wrap this class with `ray.remote` using the given resource options.

        Args:
            num_cpus (int, optional): Forwarded to `ray.remote`. Defaults to None.
            num_gpus (int, optional): Forwarded to `ray.remote`. Defaults to None.
            memory (int, optional): Forwarded to `ray.remote`. Defaults to None.
            object_store_memory (int, optional): Forwarded to `ray.remote`. Defaults to None.
            resources (dict, optional): Forwarded to `ray.remote`. Defaults to None.

        Returns:
            type: Whatever `ray.remote(...)(cls)` returns, to be used for actor creation.
        """

        actor_options = {
            "num_cpus": num_cpus,
            "num_gpus": num_gpus,
            "memory": memory,
            "object_store_memory": object_store_memory,
            "resources": resources,
        }
        make_remote = ray.remote(**actor_options)
        return make_remote(cls)

    @property
    def preprocessor(self):
        """Return the preprocessor instance built in `__init__` from the observation space."""
        return self._preprocessor

    @property
    def pool_size(self) -> int:
        """Return the capacity of policy pool

        Returns:
            int: The `pool_size` value given at construction (not the current number of policies).
        """
        return self._pool_size

    def __len__(self) -> int:
        """Return the internal policy counter.

        The counter is maintained by `add_policy` and `remove_dominated_policies` only, so it
        can differ from `len(self._policies)` after `sync_policies` or `update_pool`.
        """
        return self._counter

    def return_dominated_policies(
        self, max_num: int = 10, prob_threshold: float = 0.09
    ) -> List[PolicyID]:
        """Select low-probability policies that are candidates for pruning.

        Nothing is removed here; pass the result to `remove_dominated_policies`
        to actually drop the policies.

        Args:
            max_num (int, optional): Number of supports the distribution may keep. Only the
                `len(distribution) - max_num` least probable entries are considered. Defaults to 10.
            prob_threshold (float, optional): Upper bound for the probability cut-off; the effective
                cut-off is `min(prob_threshold, 1 / max(1, len(self._policies)))`. Defaults to 0.09.

        Returns:
            List[PolicyID]: Ids of the considered policies whose probability is at most the cut-off,
                ordered by ascending probability. Empty when the distribution has no surplus entries.
        """

        dist_dict = self._distribution.dict_values()
        threshold = min(prob_threshold, 1 / max(1, len(self._policies)))

        n_surplus = len(dist_dict) - max_num
        if n_surplus <= 0:
            return []

        # Ascending by probability, so the surplus slice holds the least likely supports.
        ranked = sorted(dist_dict.items(), key=lambda item: item[1])
        return [pid for pid, prob in ranked[:n_surplus] if prob <= threshold]

    def remove_dominated_policies(self, pids: List[PolicyID]):
        """Drop the given policies from the distribution and the pool.

        Args:
            pids (List[PolicyID]): Ids to remove. For each id the distribution entry is popped
                first, then the policy itself, then the internal counter is decremented.
        """
        distribution = self._distribution
        pool = self._policies
        for dominated_pid in pids:
            distribution.pop(dominated_pid)
            pool.pop(dominated_pid)
            self._counter = self._counter - 1

    def get_fixed_policies(self, device="cpu") -> Dict[PolicyID, Policy]:
        """Collect every policy whose `is_fixed` flag is set.

        Args:
            device (str, optional): Passed to each selected policy's `to`. Defaults to "cpu".

        Returns:
            Dict[PolicyID, Policy]: Mapping from policy id to the result of `policy.to(device)`.
        """

        return {
            name: member.to(device)
            for name, member in self._policies.items()
            if member.is_fixed
        }

    def get_fixed_policy_ids(self) -> List[PolicyID]:
        """List the ids of all policies whose `is_fixed` flag is set.

        Returns:
            List[PolicyID]: Ids in the iteration order of the pool.
        """
        fixed_ids = []
        for name, member in self._policies.items():
            if member.is_fixed:
                fixed_ids.append(name)
        return fixed_ids

    def get_active_policies(self, device=None) -> Dict[PolicyID, Policy]:
        """Collect the policies registered as active.

        Args:
            device (optional): Passed to each active policy's `to`. Defaults to None.

        Returns:
            Dict[PolicyID, Policy]: Mapping from active policy id to the result of `policy.to(device)`.

        Raises:
            AssertionError: If a policy registered as active has `is_fixed` set.
        """

        active = {}
        for pid in self._active_pids:
            member = self._policies[pid]
            active[pid] = member.to(device)
            assert not member.is_fixed
        return active

    def get_active_policy_ids(self) -> List[PolicyID]:
        """List the ids currently registered as active.

        Returns:
            List[PolicyID]: The keys of the active-id registry, as a new list.
        """

        return list(self._active_pids)

    def get_lowest_active(self) -> Union[Tuple[PolicyID], None]:
        """Return the id of the lowest active policy as a 1-tuple.

        Returns:
            Union[Tuple[PolicyID], None]: `(pid,)` of the head entry of `active_bucket`, or `None`
                when the bucket is empty.
        """
        # `lowest_active_tup` is a method; it has to be called to obtain the entry.
        tup = self.active_bucket.lowest_active_tup()
        if tup is None:
            return None
        return (tup.pid,)

    def get_active(self) -> Union[Tuple[PolicyID], None]:
        """Return the ids of all bucket entries after the lowest active one.

        Returns:
            Union[Tuple[PolicyID], None]: Tuple of policy ids taken from `active_bucket.active_tups()`,
                or `None` when there is no such entry.
        """
        # `active_tups` is a method; it has to be called to obtain the list.
        tups = self.active_bucket.active_tups()
        if not tups:
            return None
        return tuple(e.pid for e in tups)

    def get_policies(self) -> Dict[PolicyID, Policy]:
        """Return the internal policy dict itself (not a copy)."""
        return self._policies

    def get_policy_ids(self) -> Tuple[PolicyID]:
        """Return the ids of all policies in the pool as a tuple."""
        return tuple(self._policies.keys())

    def get_pool_size(self) -> int:
        """Return the size of policy pool.

        Returns:
            int: The number of policies currently stored, i.e. `len(self._policies)`. This is
                neither the configured `pool_size` nor the counter returned by `len(self)`.
        """

        return len(self._policies)

    def get_policy(self, pid: PolicyID) -> Policy:
        """Return the policy instance tagged with `pid`.

        Args:
            pid (PolicyID): The policy id of the retrieved policy instance.

        Returns:
            Policy: The stored policy instance.

        Raises:
            KeyError: If `pid` is not in the pool.
        """

        return self._policies[pid]

    def get_distribution(self) -> Dict[PolicyID, float]:
        """Return the pool's policy distribution.

        NOTE(review): the object returned is `self._distribution` itself (created as a
        `LearnableDistribution` in `__init__`), not a plain dict copy — the return annotation
        looks out of date; confirm what callers expect.

        Returns:
            Dict[PolicyID, float]: The distribution object held by this pool.
        """

        return self._distribution

    def aggregate(self, values: Dict[PolicyID, Any], mode: str = "mixed_oracle"):
        """Combine per-policy values into one, weighted by the policy distribution.

        Args:
            values (Dict[PolicyID, Any]): Mapping from policy id to a value that can be multiplied
                by a probability and stacked by numpy.
            mode (str, optional): Aggregation mode. Only "mixed_oracle" (probability-weighted sum)
                is implemented; the parameter exists so callers may name the mode explicitly.
                Defaults to "mixed_oracle".

        Returns:
            The sum over policies of `prob[pid] * values[pid]`, reduced along axis 0.

        Raises:
            ValueError: If `mode` is not "mixed_oracle".
            KeyError: If a policy id in `values` has no entry in the distribution.
        """
        if mode != "mixed_oracle":
            raise ValueError("Unsupported aggregation mode: {}".format(mode))

        prob_dict = self._distribution.dict_values()
        weighted = [prob_dict[pid] * v for pid, v in values.items()]
        return np.sum(np.asarray(weighted), axis=0)

    def compute_action(
        self, observation, action_mask, evaluate
    ) -> Tuple[Action, ActionDist, Logits]:
        """Compute an action for rollout.

        Two paths exist:

        * `distill_mode` and `is_fixed` both set: every policy's `value_function` is evaluated,
          the results are combined with `aggregate`, and the arg-max is taken greedily. The
          returned distribution and logits are both the one-hot vector of that action.
        * otherwise: the behavior policy chosen in `reset` is used. If none is set, a policy is
          picked on the fly (a random active policy if any is registered, else a sample from
          the distribution); this requires `mixed_at_every_step`. The policy's logits are
          L2-normalized along the last axis before being returned.

        Args:
            observation: Observation forwarded to the underlying policy.
            action_mask: Action mask forwarded to the underlying policy.
            evaluate: Evaluation flag forwarded to the policy's `compute_action`.

        Returns:
            Tuple[Action, ActionDist, Logits]: The action, its distribution, and the logits.

        Raises:
            AssertionError: If no behavior policy is set and `mixed_at_every_step` is off.
        """

        if self.distill_mode and self.is_fixed:
            values = {}
            for pid, policy in self._policies.items():
                values[pid] = policy.value_function(observation, action_mask)
            # aggregate with opponent mixtures; `aggregate` performs the probability-weighted
            # sum and takes the values as its only required argument
            agg_value = self.aggregate(values)
            action = np.argmax(agg_value)
            prob = np.zeros(self._action_space.n)
            prob[action] = 1.0
            logits = prob.copy()
        else:
            policy = self._behavior_policy
            if policy is None:
                assert (
                    self._mixed_at_every_step
                ), "policy needs to be specified when `mixed_at_step` mode is off."
                policy_id = (
                    self._distribution.sample()
                    if len(self._active_pids) == 0
                    else random.choice(list(self._active_pids.keys()))
                )
                policy = self._policies[policy_id]

            action, prob, logits = policy.compute_action(
                observation, action_mask, evaluate
            )
            # assumes the policy returns its logits as a numpy array — TODO confirm
            logits = F.normalize(torch.from_numpy(logits), dim=-1, p=2.0).numpy()
        return action, prob, logits

    def compute_actions(
        self,
        observation,
        use_target: bool = False,
        action_mask=None,
        explore: bool = False,
    ) -> Tuple[Action, ActionDist, Logits]:
        """Compute batched actions for training.

        When an active policy is registered (exactly one is expected), the call is delegated
        to it and only the logits are L2-normalized along the last axis. Otherwise one support
        is selected from the distribution via `misc.gumbel_softmax(..., hard=True,
        explore=False)`; that policy's detached outputs are scaled by the selected entry of the
        gumbel-softmax result.

        Args:
            observation: Batch of observations forwarded to the selected policy.
            use_target (bool, optional): Forwarded to the policy; also selects
                `target_tensor` instead of `tensor` on the distribution. Defaults to False.
            action_mask (optional): Forwarded to the selected policy. Defaults to None.
            explore (bool, optional): Forwarded to the active policy only; the distribution
                branch always passes False. Defaults to False.

        Returns:
            Tuple[Action, ActionDist, Logits]: Actions, distributions and normalized logits.

        Raises:
            AssertionError: If more than one active policy is registered.
        """
        if len(self._active_pids) > 0:
            assert len(self._active_pids) == 1, self._active_pids
            active_pid = random.choice(list(self._active_pids.keys()))
            actions, dists, logits = self._policies[active_pid].compute_actions(
                observation, use_target, action_mask, explore
            )
            return actions, dists, F.normalize(logits, p=2.0, dim=-1)

        # No active policy: pick one support from the (target) distribution parameters.
        meta_params = (
            self._distribution.target_tensor
            if use_target
            else self._distribution.tensor
        )
        meta_probs = misc.gumbel_softmax(meta_params, hard=True, explore=False)

        idx = meta_probs.argmax(-1).item()
        selected = self._policies[self._distribution._supports[idx]]
        raw_actions, raw_dists, raw_logits = selected.compute_actions(
            observation, use_target, action_mask, False
        )

        # The supports are always detached; only the meta weight stays attached to the outputs.
        weight = meta_probs[idx]
        actions = weight * raw_actions.detach()
        dists = weight * raw_dists.detach()
        logits = weight * F.normalize(raw_logits, p=2.0, dim=-1).detach()

        # A "stochastic mode" that mixes the outputs of all supports by their probabilities
        # was sketched here previously; only the hard selection above is implemented.
        return actions, dists, logits

    def update_target(self, **kwargs):
        """Forward a target update to every active policy and to the distribution.

        Args:
            **kwargs: Passed through unchanged to each `update_target` call.
        """
        # we only update active policies
        for policy in self.get_active_policies().values():
            policy.update_target(**kwargs)
            # NOTE(review): this call sits inside the loop, so the distribution's target is
            # updated once per active policy and never when no policy is active — confirm
            # that this is intended.
            self._distribution.update_target(**kwargs)

    def sync_weights(self):
        """Copy the active policy's parameters into the policy named "policy-0".

        Intended for self-play with exactly two policies in the pool, one of them active.
        The active policy is moved to CPU for the copy and moved back to CUDA afterwards if
        its `use_cuda` flag was set.

        Raises:
            AssertionError: If the pool does not hold exactly two policies, if there is not
                exactly one active id, or if the active id is "policy-0" itself.
        """
        pool = self._policies
        assert (
            len(pool) == 2
        ), "Cannot sync weights since policies size largers than 2. {}".format(
            list(pool.keys())
        )
        assert (
            len(self._active_pids) == 1
        ), "Cannot sync weights since active pids is not 1: {}".format(
            self._active_pids
        )

        target_pid = "policy-0"
        source_pid = list(self._active_pids)[0]
        assert source_pid != target_pid, (source_pid, target_pid)

        # Remember the device before moving the source, so it can be restored below.
        was_on_cuda = pool[source_pid].use_cuda
        cpu_source = pool[source_pid].to("cpu")
        pool[target_pid].update_parameters(cpu_source.parameters())
        if was_on_cuda:
            pool[source_pid] = pool[source_pid].to("cuda")

    def set_fixed(self, policy_ids: Sequence):
        """Freeze the given policies and deregister them from the active ids.

        For each id the policy is flagged fixed, moved to CPU, switched to
        `PolicyStatus.FIXED`, and its active reference count is decremented; the id leaves
        the active registry once the count reaches zero.

        Args:
            policy_ids (Sequence): Ids of the policies to freeze.

        Raises:
            AssertionError: If an id's active count would drop below zero.
        """
        for pid in policy_ids:
            member = self._policies[pid]
            member.is_fixed = True
            on_cpu = member.to("cpu")
            self._policies[pid] = on_cpu
            on_cpu.switch_status_to(PolicyStatus.FIXED)

            remaining = self._active_pids[pid] - 1
            self._active_pids[pid] = remaining
            assert remaining >= 0
            if remaining == 0:
                self._active_pids.pop(pid)

    def sync_policies(
        self,
        policies: Dict[PolicyID, Policy],
        sync_type: str = None,
        replace_old: bool = False,
        is_ego: bool = False,
        hard: bool = False,
    ):
        """Merge externally provided policies into the pool.

        Args:
            policies (Dict[PolicyID, Policy]): Incoming policies. Their `is_fixed` flag may be
                modified in place.
            sync_type (str): One of "full", "intersection", "union". "full" first drops every
                local policy whose id is not incoming; "intersection" only considers incoming ids
                already known locally; "union" considers all incoming ids. None means "full".
            replace_old (bool): If False, incoming policies whose id already exists locally are
                skipped. Defaults to False.
            is_ego (bool): If False, incoming policies are flagged as fixed. Defaults to False.
            hard (bool): Forces `sync_type` to "union". Defaults to False.
        """
        if hard:
            mode = "union"
        elif sync_type is None:
            mode = "full"
        else:
            mode = sync_type

        incoming_keys = list(policies.keys())
        local_keys = list(self._policies.keys())

        if mode == "full":
            for stale in local_keys:
                if stale not in incoming_keys:
                    self._policies.pop(stale)
                    self._distribution.pop(stale)
        elif mode == "intersection":
            incoming_keys = [nk for nk in policies.keys() if nk in local_keys]
        # "union" (and any other value) keeps all incoming keys untouched

        for k in incoming_keys:
            if not replace_old and k in self._policies:
                continue

            candidate = policies[k]
            if not is_ego:
                candidate.is_fixed = True
            # A policy that replaces a local one inherits the local fixed flag.
            local = self._policies.get(k)
            if local is not None:
                candidate.is_fixed = local.is_fixed
            self._policies[k] = candidate.to("cpu")

            no_tensor_yet = self._distribution.tensor is None
            if not candidate.is_fixed:
                weight = 0.0
            elif no_tensor_yet:
                weight = 1.0
            else:
                # random fraction of the current largest distribution entry
                weight = (
                    torch.max(self._distribution.tensor).data.item()
                    * np.random.random()
                )
            self._distribution[k] = weight

    def add_policy(
        self,
        key: PolicyID = None,
        policy: Policy = None,
        statistic=None,
        fixed: bool = True,
    ):
        """Register a policy in the pool, creating the policy and/or its id when missing.

        If `key` or `policy` is not given, a fresh policy is created from the pool's policy
        config. If no (non-empty) `key` is given, an id is generated from the policy tag name
        and an internal suffix, skipping ids that are already taken. A policy already stored
        under the resulting id is overwritten.

        Args:
            key (PolicyID, optional): Policy id. Defaults to None, which generates one.
            policy (Policy, optional): The policy instance to register. Defaults to None, which
                creates a new instance.
            statistic (Any, optional): Unused. Defaults to None.
            fixed (bool, optional): Whether the policy is registered as fixed (CPU, distribution
                weight 1.0) or active (distribution weight 0.1, stacked on the active bucket).
                Defaults to True.

        Returns:
            Tuple[PolicyID, Policy]: The id and the policy instance that was passed in or created.
        """

        if key is None or policy is None:
            # generate new policy
            policy = self._policy_config.new_policy_instance(self._agent_id)
            if key:
                # A caller-chosen id is used as is. The suffix still advances so that ids
                # generated later keep the numbering they had before.
                self._pid_suffix += 1
            else:
                # assign new policy id
                policy_tag_name = (
                    self._policy_config.human_readable or self._policy_config.policy
                )
                while True:
                    # Build a fresh candidate on every pass; reusing the previous candidate
                    # would never get past a taken id.
                    key = POLICY_NAME_GEN(policy_tag_name, self._pid_suffix)
                    self._pid_suffix += 1
                    if key not in self._policies:
                        break
                    # count uncounted policies!
                    # NOTE(review): this assumes a policy occupying a generated id was added
                    # without going through the counter (e.g. by `sync_policies`) — confirm.
                    self._counter += 1

        if key not in self._policies:
            self._counter += 1
        use_cuda = policy._custom_config.get("use_cuda", False)
        self._policies[key] = policy.to("cpu" if (fixed or not use_cuda) else "cuda")
        policy.is_fixed = fixed

        # pre-assigned distribution weight of the added policy: 1.0 when fixed, 0.1 when active
        self._distribution[key] = 1.0 if fixed else 0.1
        if not fixed:
            self.active_bucket.stack(policy_id=key)
            self._active_pids[key] += 1
        else:
            policy.switch_status_to(PolicyStatus.FIXED)
        return key, policy

    def set_distribution(self, equilibrium: Dict[PolicyID, float]):
        """Set policy distribution from equilibrium before interacting.

        The behavior policy is not re-sampled here; it is chosen in `reset`.

        Args:
            equilibrium (Dict[PolicyID, float]): Policy distribution based on some equilibrium;
                passed to the distribution's `set_states` as `probs`.
        """

        self._distribution.set_states(probs=equilibrium)
        # self._behavior_policy = self._distribution.sample()
        # update state
        # self._state["_behavior_policy"] = self._behavior_policy

    def optimize(self, batch):
        """Run one optimization step of the learnable distribution.

        Args:
            batch: Training batch handed unchanged to the distribution's `optimize`.
        """
        self._distribution.optimize(batch)

    def export(self, data_dir):
        """Export the pool to local storage.

        Every policy is saved through its own `save` under `data_dir/<policy id>`, and the
        pool state dict is pickled to `data_dir/state.pkl`.

        Args:
            data_dir: Target directory; passed to `os.path.join`.
        """

        for pid, policy in self._policies.items():
            policy.save(path=os.path.join(data_dir, pid))

        # state keep; pickle writes bytes, so the file has to be opened in binary mode
        with open(os.path.join(data_dir, "state.pkl"), "wb") as f:
            pickle.dump(self._state, f)

    def load(self, data_dir):
        """Restore the pool from a directory written by `export`.

        The pickled state dict is read from `data_dir/state.pkl` and each of its entries is
        set as an attribute on this instance. Afterwards one policy per distribution key is
        created from the policy config and loaded from `data_dir/<policy id>`.

        Args:
            data_dir: Source directory; passed to `os.path.join`.
        """
        # load state first
        # SECURITY: unpickling executes arbitrary code from the file; only load state files
        # that come from a trusted source.
        with open(os.path.join(data_dir, "state.pkl"), "rb") as f:
            self._state = pickle.load(f)

        for k, v in self._state.items():
            setattr(self, k, v)

        # set population
        for k in self._distribution.keys():
            self._policies[k] = self._policy_config.new_policy_instance(self._agent_id)
            dir_path = os.path.join(data_dir, k)
            self._policies[k].load(dir_path)

    def reset(
        self,
        is_active: bool = False,
        policy_id: PolicyID = None,
    ):
        """Reset the behavior policy used by `compute_action`.

        The previous behavior policy is always cleared first. Then:

        * with `policy_id` given, that policy becomes the behavior policy (`is_active` is
          ignored);
        * otherwise, unless the pool mixes at every step, a policy is drawn: a random active
          one when `is_active` is set, else a sample from the distribution.

        Args:
            is_active (bool, optional): Whether the drawn behavior policy should be an active
                one. Defaults to False.
            policy_id (PolicyID, optional): Id of the policy to use directly. Defaults to None.

        Raises:
            AssertionError: If `policy_id` is unknown, or if the drawn policy's `is_fixed` flag
                contradicts `is_active`.
        """

        self._behavior_policy = None
        self._behavior_pid = None

        # an explicit policy id takes precedence over everything else
        if policy_id is not None:
            assert policy_id in self._policies, (policy_id, list(self._policies.keys()))
            self._behavior_policy = self._policies[policy_id]
            self._behavior_pid = policy_id
            return

        # in mixed mode the policy is chosen per step, so nothing is selected here
        if self._mixed_at_every_step:
            return

        if is_active:
            policy_id = random.choice(list(self._active_pids.keys()))
        else:
            policy_id = self._distribution.sample()
        self._behavior_pid = policy_id
        self._behavior_policy = self._policies[policy_id]
        assert self._behavior_policy.is_fixed == (not is_active)

    def add_transition(self, acc_reward: float):
        """Report an accumulated reward for the current behavior policy.

        Does nothing when the pool is fixed. Otherwise the reward is handed to the
        distribution's `add_transition` together with the current behavior policy id.

        Args:
            acc_reward (float): Accumulated reward to record.

        Raises:
            AssertionError: If the pool is not fixed and no behavior policy id is set.
        """
        if self.is_fixed:
            return

        behavior_pid = self._behavior_pid
        assert (
            behavior_pid is not None
        ), "Do not support mix_at_every_step mode yet."
        self._distribution.add_transition(behavior_pid, acc_reward)

    def update_pool(self, policies):
        """Merge `policies` into the internal policy dict via `dict.update`.

        Only the policy dict is touched here; the counter, the distribution and the active
        ids are left as they are.

        Args:
            policies: Mapping from policy id to policy instance.
        """
        self._policies.update(policies)

    def save(self, path, global_step=0, hard: bool = False):
        """Save every policy of the pool as `<path>/<policy id>.pt`.

        Args:
            path: Target directory; passed to `os.path.join`.
            global_step (int, optional): Not used by this method. Defaults to 0.
            hard (bool, optional): If False, files that already exist are left untouched. The
                flag is also forwarded to each policy's `save`. Defaults to False.
        """
        for pid, policy in self._policies.items():
            fpath = os.path.join(path, "{}.pt".format(pid))
            if hard or not os.path.exists(fpath):
                Log.info("\t- save model as: {}".format(fpath))
                policy.save(fpath, hard=hard)

    @classmethod
    def load_from_config(cls, model_dir, yaml_path, model_keys, aid="default"):
        """Build a fixed pool from a YAML experiment config and load stored policies into it.

        Args:
            model_dir: Directory holding the stored policies; a `str` or a path object
                supporting the `/` operator. Each policy is loaded from `model_dir / key`.
            yaml_path: Path of the YAML file providing `env_config` and, optionally,
                `algorithm`, `custom_config` and `model_config`.
            model_keys: Policy ids to create and load; also determines `pool_size`.
            aid (str, optional): Agent id used to pick the observation and action space.
                Defaults to "default".

        Returns:
            PolicyPool: A fixed pool holding one loaded policy per entry of `model_keys`.
        """
        from pathlib import Path

        # parse yaml
        with open(yaml_path, "r") as f:
            config = yaml.safe_load(f)
        env_desc, _ = parse_env_config(config["env_config"])
        env_config = env_desc["config"]
        policy_cls = (
            load_class_from_str("expground.algorithms", config["algorithm"]["policy"])
            if config.get("algorithm")
            else None
        )
        policy_config = PolicyConfig(
            policy=policy_cls,
            mapping=lambda agent: agent,
            observation_space=env_config["observation_spaces"][aid],
            action_space=env_config["action_spaces"][aid],
            custom_config=config.get("custom_config", {}),
            model_config=config.get("model_config", {}),
        )
        pool = cls(
            agent_id=aid,
            policy_config=policy_config,
            pool_size=len(model_keys),
            start_fixed_support_num=0,
            start_active_support_num=0,
            is_fixed=True,
            mixed_at_every_step=False,
            distribution_training_kwargs=None,
            distill_mode=False,
        )

        # `/` is not defined for plain strings, so wrap them; path objects pass through as is.
        if isinstance(model_dir, str):
            model_dir = Path(model_dir)
        for k in model_keys:
            pool.add_policy(k)
            pool._policies[k].load(model_dir / k)
        return pool


# Public API of this module: only `PolicyPool` is exported by `from ... import *`.
__all__ = ["PolicyPool"]
