"""
Base class of learners. In general, a learner is responsible for training a pool of policies, although this is not mandatory.
Traditional MARL settings do not need a policy pool for each agent, but population-based learning,
such as self-play or PSRO, does.

So, if your learner needs to support both settings, its policy initialization should accept either a single policy or a policy pool.
"""

import os
import copy
import ray
import time

import numpy as np

from abc import ABCMeta, abstractmethod
from collections import defaultdict, namedtuple

from ray.util import ActorPool
from torch.utils import tensorboard

from expground import settings
from expground.envs.agent_interface import AgentInterface
from expground.types import (
    Any,
    Dict,
    PolicyID,
    RolloutConfig,
    Sequence,
    AgentID,
    Tuple,
    LambdaType,
    List,
    Union,
)
from expground.logger import Log, monitor
from expground.utils import rollout
from expground.algorithms.base_policy import Policy
from expground.common.policy_pool import PolicyPool


# Immutable snapshot of a learner's progress counters.
LearnerState = namedtuple(
    "LearnerState", ["total_timesteps", "total_episodes", "global_step"]
)


class Learner(metaclass=ABCMeta):

    NAME = "Learner"

    def __init__(
        self,
        experiment: str,
        env_desc: Dict[str, Any],
        rollout_config: RolloutConfig,
        summary_writer=None,
        seed: int = None,
        ray_mode: bool = False,
        enable_policy_pool: bool = False,
        agent_mapping: LambdaType = lambda agent: agent,
        exp_config=None,
        resource_config: Dict[str, Any] = None,
        policy_pool_config: Dict[str, Any] = None,
        enable_evaluation_pool: bool = False,
    ) -> None:
        """Learner initialization.

        Args:
            experiment (str): Given experiment name, will be used for building Log backup.
            summary_writer (tensorboard.SummaryWriter, optional): Specify the summary writer. Defaults to None.
            seed (int, optional): Random seed for rollout and numpy. Defaults to None.
            ray_mode (bool, optional): Indicate use ray or not. Defaults to False.
            agent_mapping (LambdaType, optional): The mapping function controls the relationship between agents and policies/policypools. Defaults to one to one mapping.
            resource_config (Dict[str, Any], optional): A dict describes the computing resources. Defaults to None.
        """

        np.random.seed(seed)

        resource_config = resource_config or {}

        self.summary_writer = summary_writer or tensorboard.SummaryWriter(
            log_dir=exp_config.log_path
        )
        self.state_dir = os.path.join(exp_config.log_path)
        self.experiment_tag = experiment

        self._global_step = 0
        self._total_timesteps = 0
        self._total_episodes = 0
        self._agent_mapping = agent_mapping
        self._exp_config = exp_config
        self._evaluation_pool = None

        n_env = 1
        if rollout_config.vector_mode:
            num_simulation = rollout_config.num_simulation
            fragment_length = rollout_config.fragment_length
            max_step = rollout_config.max_step
            n_env = max(num_simulation, fragment_length // max_step)

        # evaluation pool turn on in ray mode
        if enable_evaluation_pool:
            preset_worker_num = resource_config.get("evaluation_worker_num", 1)
            assert (
                preset_worker_num > 0
            ), "Remote evaluation pool is on, but worker num is smaller than 1: {}".format(
                preset_worker_num
            )
            self._evaluation_worker_num = preset_worker_num
            Log.info("Get a resource of %s", resource_config)
            if preset_worker_num > 1:
                Log.info(
                    "create remote evaluation pool detecting resource: {}".format(
                        ray.available_resources()
                    )
                )
                EvaluationActor = ray.remote(
                    num_cpus=resource_config.get("evaluation_worker_total_cpu_num", 1)
                    / preset_worker_num,
                    num_gpus=None,
                    memory=None,
                    object_store_memory=None,
                    resources=None,
                )(rollout.Evaluator)
                self._evaluation_pool = ActorPool(
                    [
                        EvaluationActor.remote(
                            env_desc,
                            n_env,
                            use_remote_env=ray_mode and rollout_config.remote_env,
                        )
                        for _ in range(self._evaluation_worker_num)
                    ]
                )
            elif preset_worker_num == 1:
                # XXX: use_remote_env always be false when ray_mode is False, considering use naive SubProc.
                self._evaluation_pool = rollout.Evaluator(
                    env_desc, n_env, use_remote_env=rollout_config.remote_env
                )
        else:
            self._evaluation_worker_num = 1
            self._evaluation_pool = rollout.Evaluator(
                env_desc, n_env, use_remote_env=ray_mode and rollout_config.remote_env
            )

        self._ray_mode = ray_mode
        self._enable_evaluation_pool = enable_evaluation_pool
        self._enable_policy_pool = enable_policy_pool
        self._policies: Dict[AgentID, Union[PolicyPool, Policy]] = {}
        self._agent_interfaces: Dict[AgentID, AgentInterface] = {}
        self._ppool_config = policy_pool_config
        self.stopper = None  # avoid save error

        if enable_policy_pool:
            assert (
                policy_pool_config is not None
            ), "Policy pool enabled, but no avaiable policy pool configs given."
            self._use_learnable_dist = policy_pool_config.get(
                "use_learnable_dist", False
            )

    def reset_summary_writer(self, path=None, set_none=False):
        if set_none:
            self.summary_writer = None
        else:
            self.summary_writer = tensorboard.SummaryWriter(log_dir=path)

    @property
    def ray_mode(self) -> bool:
        """Return the `ray_mode` flag stored on this learner."""
        return self._ray_mode

    @property
    def agents(self) -> Sequence[AgentID]:
        """Return a sequence of environment agents.

        Returns:
            Sequence[AgentID]: A sequence of agent ids.
        """

        return tuple(self._agents)

    def get_learner_state(self) -> LearnerState:
        """Return an instance of `LeanerState`.

        Returns:
            LearnerState: A named tuple `LearnerState(total_timesteps, total_episodes, global_step)`.
        """

        return LearnerState(
            self._total_timesteps, self._total_episodes, self._global_step
        )

    def update_learner_state(self, total_timesteps: int, total_episodes: int):
        """Update learner states

        Args:
            total_timesteps (int): Total timesteps.
            total_episodes (int): Total episodes
        """
        assert total_episodes >= 0
        self._total_episodes = total_episodes
        assert total_timesteps >= 0
        self._total_timesteps = total_timesteps

    def register_env_agents(self, agents: List[AgentID]):
        """Will set values to `self._agents`.

        Args:
            agents (Sequence[str]): A sequence of agent ids.
        """

        self._agents = agents

    def distill(self, meta_strategies: Dict[AgentID, Dict[PolicyID, float]]):
        """Distill policy from policy pool with given meta-strategies.

        Not implemented in this base class.

        Args:
            meta_strategies (Dict[AgentID, Dict[PolicyID, float]]): A dict of agent meta-strategies.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def register_ego_agents(self, ego_agents):
        """Register ego environment agents to the learner.

        Args:
            ego_agents (List[AgentID]): A list of ego agents.
        """

        self._ego_agents = copy.deepcopy(ego_agents)

    @classmethod
    def as_remote(
        cls,
        num_cpus: int = None,
        num_gpus: int = None,
        memory: int = None,
        object_store_memory: int = None,
        resources: dict = None,
    ) -> type:
        """Return a remote class for Actor initialization"""

        return ray.remote(
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            memory=memory,
            object_store_memory=object_store_memory,
            resources=resources,
        )(cls)

    @property
    def agent_mapping(self) -> LambdaType:
        """Return the agent mapping callable stored on this learner.

        Within this class the callable is applied to an environment agent id to
        obtain the key used to index `self._policies`.

        Returns:
            LambdaType: A lambda expression of agent mapping.
        """

        return self._agent_mapping

    @property
    def ego_agents(self) -> List[AgentID]:
        """Return a list of ego agent id. It is actually a subset of the full set of
        environment agent ids.

        Note:
            Within this class `self._ego_agents` is only assigned by
            `register_ego_agents`, and the stored list itself (not a copy) is returned.

        Returns:
            List[AgentID]: A list of ego agent id.
        """

        return self._ego_agents

    @property
    def num_support(self) -> Dict[AgentID, int]:
        return {
            agent: len(self._policies[self._agent_mapping(agent)])
            for agent in self.ego_agents
        }

    def get_num_support(self) -> Dict[AgentID, int]:
        """Return a dict of agents' support number, which maps agent ids to int.

        Method counterpart of the `num_support` property; it returns that
        property's value unchanged.

        Note:
            As implemented by `num_support` in this class, the keys are the ego
            agent ids themselves; the agent mapping is only used to pick the
            policy container whose length is reported.

        Returns:
            Dict[AgentID, int]: A dict of agents' support number.
        """

        return self.num_support

    def sync_policies(
        self,
        policies: Dict[AgentID, Policy],
        force_sync: bool = False,
        sync_type: str = None,
        hard: bool = False,
    ) -> None:
        """Sync agent policies coming from outside into the local policy containers.

        Ego agents are skipped unless `force_sync` is set.

        Args:
            policies (Dict[AgentID, Policy]): Policies to sync, keyed by agent id.
            force_sync (bool): Also sync ego agents; forwarded to the container
                as `replace_old`. Defaults to False.
            sync_type (str): Forwarded to the container's `sync_policies`. Defaults to None.
            hard (bool): Hard sync. Defaults to False.
        """

        selected = [
            (self._agent_mapping(agent), agent_policies, agent in self._ego_agents)
            for agent, agent_policies in policies.items()
            if force_sync or agent not in self._ego_agents
        ]

        Log.debug(
            "sync fixed policies for agent=%s",
            [(ppid, list(entry.keys())) for (ppid, entry, _) in selected],
        )

        for ppid, agent_policies, is_ego in selected:
            target = self._policies[ppid]
            target.sync_policies(
                agent_policies,
                hard=hard,
                sync_type=sync_type,
                replace_old=force_sync,
                is_ego=is_ego,
            )

    def get_dist(self):
        res = {}
        for aid in self._agents:
            if aid not in self._ego_agents:
                ppid = self._agent_mapping(aid)
                policy = self._policies[ppid]
                assert isinstance(policy, PolicyPool)
                res[aid] = policy._distribution.dict_values()
        return res

    def get_all_policies(self, get_fixed=True, get_active=True):
        res = {}
        for agent in self._agents:
            ppid = self._agent_mapping(agent)
            res[agent] = {}
            if get_fixed:
                res[agent].update(self._policies[ppid].get_fixed_policies())
            if get_active:
                res[agent].update(self._policies[ppid].get_active_policies())
        return res

    def get_ego_fixed_policy_ids(self):
        res = {}
        for agent in self._ego_agents:
            ppid = self._agent_mapping(agent)
            res[agent] = self._policies[ppid].get_fixed_policy_ids()
        return res

    def get_ego_fixed_policies(
        self, device="cpu"
    ) -> Dict[AgentID, Dict[PolicyID, Policy]]:
        """Returns a dict of dict of fixed policies.

        Returns:
            Dict[AgentID, Dict[PolicyID, Policy]]: A dict of mapping agent to a dict of fixied policies.
        """

        res = {}
        Log.debug("------- got ego fixed, {}".format(device))
        for agent in self._ego_agents:
            ppid = self._agent_mapping(agent)
            res[agent] = self._policies[ppid].get_fixed_policies(device)
        return res

    def get_ego_active_policies(
        self, device="cpu"
    ) -> Dict[AgentID, Dict[PolicyID, Policy]]:
        res = {}
        for agent in self._ego_agents:
            ppid = self._agent_mapping(agent)
            if res.get(ppid) is None:
                res[ppid] = self._policies[ppid].get_active_policies(device)
        return res

    def get_ego_active_policy_ids(self) -> Dict[AgentID, Sequence[PolicyID]]:
        res = {}
        for agent in self._ego_agents:
            ppid = self._agent_mapping(agent)
            if res.get(ppid) is None:
                res[ppid] = self._policies[ppid].get_active_policy_ids()
        return res

    def get_ego_policies(self) -> Dict[AgentID, Dict[PolicyID, Policy]]:
        res = {}
        for agent in self._ego_agents:
            ppid = self._agent_mapping(agent)
            res[agent] = self._policies[ppid].get_active_policies()
        return res

    def set_ego_policy_fixed(
        self, agent_policy_ids: Dict[AgentID, Sequence[PolicyID]] = None
    ):
        if agent_policy_ids is None:
            Log.warning(
                "Active policies does not be specified, will fix all active policies."
            )
        agent_policy_ids = agent_policy_ids or self.get_ego_active_policy_ids()
        for agent, pids in agent_policy_ids.items():
            if agent in self._ego_agents:
                ppid = self._agent_mapping(agent)
                self._policies[ppid].set_fixed(pids)

    def set_behavior_policies(self, agent_policy_mapping: Dict[AgentID, PolicyID]):
        """Set behavior policies for agent interfaces.

        Not implemented in this base class.

        Args:
            agent_policy_mapping (Dict[AgentID, PolicyID]): The policy id to use for each agent.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        # exclude ego agents
        raise NotImplementedError

    def set_behavior_dist(
        self, agent_policy_dist: Dict[AgentID, Dict[PolicyID, float]]
    ) -> None:
        """Set behavior distribution for fixed policy pools.

        Args:
            agent_policy_dist (Dict[AgentID, Dict[PolicyID, float]]): The dict of policy distribution.
        """

        for i, agent in enumerate(self._agents):
            if agent in self._ego_agents:
                continue
            ppid = self._agent_mapping(agent)
            if not self._ppool_config.get("distill_mode", False):
                self._policies[ppid].set_distribution(agent_policy_dist[ppid])
            else:
                opponent_dist = [
                    v for k, v in agent_policy_dist.items() if k not in self._ego_agents
                ]
                # then aggregate them
                tmp = defaultdict(lambda: 1.0)
                for _odist in opponent_dist:
                    for k, v in _odist.items():
                        tmp[k] *= v
                tmp = dict(tmp)
                self._policies[ppid].set_distribution(tmp)

    def add_policy(self, fixed: bool = False, n_support: int = 1, auto_bkup=True):
        """Add policies for ego agents.

        Args:
            agent_ids (Sequence[AgentID]): A sequence of agents which should be assign more policies to.
        """

        # call trainable agents to create new policies
        ppids = set([self._agent_mapping(agent) for agent in self._ego_agents])
        _ = [
            self._policies[ppid].add_policy(fixed=fixed)
            for _ in range(n_support)
            for ppid in ppids
        ]
        # and then reset trainer with the top policy
        #   policies in external loop, the `learn` function.
        # self._trainer[agent].reset(backups[0][1])

        if auto_bkup:
            self.bkup_ego_policies()

    def bkup_ego_policies(self):
        self.policies_bkups = []
        agent_policy_ids = self.get_ego_active_policy_ids()
        for agent, pids in agent_policy_ids.items():
            if agent in self._ego_agents:
                ppid = self._agent_mapping(agent)
                for pid in pids:
                    sd = self._policies[ppid].get_policies()[pid].state_dict()
                    self.policies_bkups.append((ppid, pid, sd))

    # def recover_ego_policies(self):
    #     for ppid, pid, sd in self.policies_bkups:
    #         self._policies[ppid].get_policies()[pid].load_state_dict(sd)

    def _task_split(
        self, task_desc: Union[List, int], min_load_length: int, task_type: str
    ):
        if task_type == "policy_mappings":
            # load balancing: at most 1+1 policy mapping simulation for each worker
            # XXX: mak sure item in task_desc is a dict
            min_load_length = 1
            policy_mappings = task_desc
            _length = len(policy_mappings) // self._evaluation_worker_num
            if _length < min_load_length:
                used_workernum = max(1, len(policy_mappings) // min_load_length)
                _length = len(policy_mappings) // used_workernum
                _tail_length = len(policy_mappings) % used_workernum
            else:
                _tail_length = len(policy_mappings) % self._evaluation_worker_num
                used_workernum = self._evaluation_worker_num

            segments = [_length] * used_workernum if _length else [0]
            segments[-1] += _tail_length

            policy_mapping_segs = []
            for e in segments:
                policy_mapping_segs.append(policy_mappings[:e])
                policy_mappings = policy_mappings[e:]
            return policy_mapping_segs
        elif task_type == "n_simulations":
            fragment_length = task_desc
            assert min_load_length > 0, min_load_length
            # split simulation num:
            _length = fragment_length // self._evaluation_worker_num
            if _length < min_load_length:
                used_workernum = max(1, fragment_length // min_load_length)
                _length = fragment_length // used_workernum
                _tail_length = fragment_length % used_workernum
            else:
                _tail_length = fragment_length % self._evaluation_worker_num
                used_workernum = self._evaluation_worker_num

            segments = [_length] * used_workernum if _length else [0]
            segments[-1] += _tail_length
            return segments
        else:
            raise ValueError(
                "Unknowe task desc: {}, task_type: {}".format(task_desc, task_type)
            )

    def evaluation(
        self,
        policy_mappings: Union[Sequence[Dict[AgentID, PolicyID]], None],
        max_step: int,
        fragment_length: int,
        max_episode: int = 10,
    ) -> Sequence[Tuple[Dict, Dict]]:
        """Run evaluation. If a sequence of poilcy mapping is given, it will be used to reset behavior policies and run evaluation of it.

        Args:
            policy_mappings (Sequence[Dict[AgentID, PolicyID]]): A sequence of policy mapping (a dict of policy ids).

        Returns:
            Sequence[Tuple[Dict, Dict]]: A sequence of (policy mapping, reward dict) tuples.
        """
        # if policy_mappings is not None:

        res = []
        if policy_mappings is not None:
            total_num = len(policy_mappings) if policy_mappings is not None else 0
            Log.debug("Start evaluation for %s policy mappings", total_num)
            task_type = "policy_mappings"
            tasks = self._task_split(policy_mappings, 1, task_type=task_type)
        else:
            task_type = "n_simulations"
            tasks = self._task_split(fragment_length, max_step, task_type=task_type)
        if (
            self._enable_evaluation_pool
            and self.ray_mode
            and self._evaluation_worker_num > 1
        ):
            rets = self._evaluation_pool.map(
                lambda a, t: a.run.remote(
                    policy_mappings=t if task_type == "policy_mappings" else None,
                    max_step=max_step,
                    fragment_length=t
                    if task_type == "n_simulations"
                    else fragment_length,
                    agent_interfaces=self._agent_interfaces,
                    rollout_caller=self._rollout_config.caller,
                    max_episode=max_episode,
                ),
                tasks,
            )
        else:
            rets = [
                self._evaluation_pool.run(
                    policy_mappings=t if task_type == "policy_mappings" else None,
                    max_step=max_step,
                    fragment_length=t
                    if task_type == "n_simulations"
                    else fragment_length,
                    agent_interfaces=self._agent_interfaces,
                    rollout_caller=self._rollout_config.caller,
                    max_episode=max_episode,
                )
                for t in tasks
            ]
        for ret in rets:
            res.extend(ret)
        Log.debug("eval res: {}".format(res))
        return res

    def reset(self):
        """Reset learner state here"""

        self._total_episodes = 0
        self._total_timesteps = 0
        self._global_step = 0

    @abstractmethod
    def learn(self, *args, **kwargs) -> Dict:
        """Run training. Must be implemented by subclasses.

        Returns:
            Dict: Contents are not specified by this base class.
        """
        pass

    @abstractmethod
    def save(self, **kwargs) -> None:
        """Persist the learner. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def load(self, **kwargs) -> None:
        """Restore the learner. Must be implemented by subclasses."""
        pass
