"""
An implementation of the independent learning scheme.
"""

import os
import time
import ray

from expground import settings
from expground.types import (
    TrainingConfig,
    RolloutConfig,
    LearningMode,
    PolicyConfig,
    Dict,
    Any,
    Union,
    AgentID,
    LambdaType,
    Sequence,
    List,
    PolicyID,
)
from expground.logger import Log, monitor
from expground.utils.sampler import SamplerInterface, get_sampler
from expground.utils.stoppers import get_stopper, DEFAULT_STOP_CONDITIONS
from expground.utils import rollout

from expground.utils.logging import write_to_tensorboard
from expground.common.policy_pool import PolicyPool
from expground.envs.agent_interface import AgentInterface
from expground.envs import agent_interface, vector_env
from expground.learner.base_learner import Learner
from expground.algorithms.base_trainer import Trainer


class IndependentLearner(Learner):
    """Independent learner treats an algorithm as an agent/trainer"""

    def __init__(
        self,
        policy_config: PolicyConfig,
        env_description: Dict[str, Any],
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = True,
        train_every: int = 1,
        ego_agents: Sequence[AgentID] = None,
        experiment: str = None,
        seed: int = None,
        agent_mapping: LambdaType = lambda agent: agent,
        exp_config=None,
        mini_epoch: int = 5,
        # ==============================================
        ray_mode: bool = False,
        enable_policy_pool: bool = False,
        summary_writer: object = None,
        turn_off_logging: bool = False,
        policy_pool_config: Dict[str, Any] = None,
        resource_config: Dict[str, Any] = None,
        custom_config: Dict[str, Any] = None,
        reset_every_learn: bool = True,
        inner_eval: bool = True,
        pretrain_mode: bool = False,
    ):
        """Create an independent learner instance.

        Args:
            policy_config (PolicyConfig): The configuration for policy building. It is
                copied per agent for policy pools, or used to instantiate one policy
                per policy id otherwise.
            env_description (Dict[str, Any]): The description of the environment. Its
                "config" entry must provide "possible_agents", "observation_spaces" and
                "action_spaces"; "observation_adapter" and "action_adapter" are optional;
                "group" is required when `multi_to_single` is enabled. Its "creator"
                entry is called to build the sampling environment when the vector
                environment is not used.
            rollout_config (RolloutConfig): The `rollout` configuration.
            training_config (TrainingConfig): The training configuration; provides
                `trainer_cls` and `hyper_params`.
            loss_func (type): Called without arguments to build the loss object handed
                to each trainer.
            learning_mode (str): Indicates learning `on_policy` or `off_policy`.
            episodic_training (bool): Forwarded to the rollout caller as `episodic` in
                `train`. Default to True.
            train_every (int): Forwarded to the rollout caller in `train`. Default to 1.
            ego_agents (Sequence[AgentID]): Trainable agents, if not specified, ego agents
                are all possible agents in environment. Default to None.
            experiment (str): The experiment tag name. When not given, a name is built
                from the policy name and the current time. Default to None.
            seed (int): Forwarded to the base learner. Default to None.
            agent_mapping (LambdaType): Not used by this constructor: the mapping handed
                to the base learner is always derived from the environment description
                (identity, or agent -> group when `multi_to_single` is enabled).
            exp_config: Forwarded to the base learner. Default to None.
            mini_epoch (int): Passed to trainers as `n_inner_loop` in episodic training.
            ray_mode (bool): Forwarded to the base learner. Default to False.
            enable_policy_pool (bool): Use policy pool for each agent or not.
            summary_writer (object): Forwarded to the base learner. Default to None.
            turn_off_logging (bool): Stored on the instance. Default to False.
            policy_pool_config (Dict[str, Any]): Options for building policy pools
                ("start_fixed_support_num", "start_active_support_num",
                "use_learnable_dist", "mixed_at_every_step",
                "distribution_training_kwargs"). None is treated as empty.
            resource_config (Dict[str, Any]): Forwarded to the base learner.
            custom_config (Dict[str, Any]): Extra switches ("multi_to_single",
                "use_vector_env", "enable_evaluation_pool"). None is treated as empty.
            reset_every_learn (bool): Whether `learn` resets the learner and rebuilds
                the sampler on every call. Default to True.
            inner_eval (bool): Whether `learn` runs evaluation. Default to True.
            pretrain_mode (bool): When truthy, every trainer is switched to pretrain
                mode right after creation. Default to False.
        """

        # Both configs default to None; use empty dicts for local lookups so that the
        # `.get` calls below do not raise AttributeError. The original objects are
        # still forwarded untouched to the base class.
        custom_options = custom_config or {}
        pool_options = policy_pool_config or {}

        multi_to_single = custom_options.get("multi_to_single", False)
        use_vector_env = custom_options.get("use_vector_env", False)
        enable_evaluation_pool = custom_options.get("enable_evaluation_pool", False)

        if multi_to_single:
            # every agent of a group shares the group's key as its policy id
            group = env_description["config"]["group"]
            agent_to_group = {}
            for k, agents in group.items():
                agent_to_group.update(dict.fromkeys(agents, k))
        else:
            agent_to_group = {
                aid: aid for aid in env_description["config"]["possible_agents"]
            }

        Log.info("Independent got recource config: {}".format(resource_config))
        Log.info("use evaluation pool: {}".format(enable_evaluation_pool))

        super(IndependentLearner, self).__init__(
            experiment or f"Independent_{policy_config.policy}_{time.time()}",
            env_desc=env_description,
            rollout_config=rollout_config,
            summary_writer=summary_writer,
            seed=seed,
            ray_mode=ray_mode,
            enable_policy_pool=enable_policy_pool,
            resource_config=resource_config,
            policy_pool_config=policy_pool_config,
            exp_config=exp_config,
            agent_mapping=lambda x: agent_to_group[x],
            enable_evaluation_pool=enable_evaluation_pool,
        )

        self._policy_config = policy_config
        self._env_desc = env_description
        self._rollout_config = rollout_config
        self._learning_mode = learning_mode
        self._episode_training = episodic_training
        self._train_every = train_every
        self._mini_epoch = mini_epoch
        self._total_episodes = 0
        self._total_timesteps = 0
        self._turn_off_logging = turn_off_logging
        self._training_config = training_config
        self._multi_to_single = multi_to_single

        self.register_env_agents(env_description["config"]["possible_agents"])
        self.register_ego_agents(ego_agents or self.agents)

        self._reset_every_learn = reset_every_learn
        self._reset_once = False
        self._inner_eval = inner_eval
        self._pretrain_mode = pretrain_mode
        self._reset_sampler = False

        # policies is the trainable policies for ego_agents
        # XXX(): considering to collect these basic initialization operators into the base learner class.
        self._policies = {}
        self._ego_ppid = []
        for agent in self._ego_agents:
            ppid = self.agent_mapping(agent)
            if self._policies.get(ppid) is None:
                self._ego_ppid.append(ppid)
                # TODO: please checkout the consistency between policy_pool mode and single_policy mode
                if enable_policy_pool:
                    # init policy pool with empty pool by defaults.
                    self._policies[ppid] = PolicyPool(
                        agent,
                        policy_config.copy(agent),
                        start_fixed_support_num=pool_options.get(
                            "start_fixed_support_num", 0
                        ),
                        start_active_support_num=pool_options.get(
                            "start_active_support_num", 0
                        ),
                        is_fixed=True,
                    )
                else:
                    self._policies[ppid] = policy_config.new_policy_instance(agent)

        # for other agents we build them with behavior policies
        for agent in self._agents:
            if agent in self._ego_agents:
                # ego agents already got their policy in the loop above
                continue
            ppid = self.agent_mapping(agent)
            # check whether it has overlap with trainable policies
            if self._policies.get(ppid) is None:
                if enable_policy_pool:
                    self._policies[ppid] = PolicyPool(
                        agent,
                        policy_config.copy(agent),
                        start_fixed_support_num=0,
                        start_active_support_num=0,
                        is_fixed=not pool_options.get("use_learnable_dist", False),
                        mixed_at_every_step=pool_options.get(
                            "mixed_at_every_step", False
                        ),
                        distribution_training_kwargs=pool_options.get(
                            "distribution_training_kwargs", {}
                        ),
                    )
                else:
                    self._policies[ppid] = policy_config.new_policy_instance(agent)

        # one trainer per trainable policy id
        self._trainer = {}
        for aid in self._ego_agents:
            ppid = self.agent_mapping(aid)
            if self._trainer.get(ppid) is None:
                policy = self._policies[ppid]
                if enable_policy_pool:
                    # XXX: currently, we support only one policy
                    active_policies = policy.get_active_policies()
                    Log.debug(
                        "detected policy pool, got active policies: %s", active_policies
                    )
                    # an empty pool yields no policy instance for the trainer yet
                    policy = (
                        list(active_policies.values())[0]
                        if len(active_policies)
                        else None
                    )
                self._trainer[ppid] = training_config.trainer_cls(
                    loss_func(),
                    policy_instance=policy,
                    training_config=training_config.hyper_params,
                )

                if self._pretrain_mode:
                    self._trainer[ppid].set_pretrain(self._pretrain_mode)

        # one agent interface per policy id, built from the first agent mapped to it
        self._agent_interfaces: Dict[AgentID, AgentInterface] = {}
        env_config = self._env_desc["config"]
        observation_spaces = env_config["observation_spaces"]
        action_spaces = env_config["action_spaces"]
        observation_adapter = env_config.get(
            "observation_adapter", agent_interface.DEFAULT_OBSERVATION_ADAPTER
        )
        action_adapter = env_config.get(
            "action_adapter", agent_interface.DEFAULT_ACTION_ADAPTER
        )
        for _aid in env_config["possible_agents"]:
            ppid = self.agent_mapping(_aid)
            # ego agent, is_active=True
            if self._agent_interfaces.get(ppid) is None:
                self._agent_interfaces[ppid] = AgentInterface(
                    policy_name="",
                    policy=self._policies[ppid],
                    observation_space=observation_spaces[_aid],
                    action_space=action_spaces[_aid],
                    observation_adapter=observation_adapter,
                    action_adapter=action_adapter,
                    is_active=_aid in self._ego_agents,
                )

        # environment used for sampling in `train`
        self.sampler_env = (
            vector_env.VectorEnv(
                env_description,
                num_envs=rollout_config.fragment_length // rollout_config.max_step,
            )
            if use_vector_env
            else env_description["creator"](**env_description["config"])
        )

    def sync_weights(self):
        """Sync weights from active policy to fixed policy, only for self-play with policy_pool_size=2.

        The policy pool of the first ego agent syncs its own weights, then its
        "policy-0" entry is added as a fixed policy to the pool of the first
        non-ego agent. Nothing is synced when policy pools are disabled.
        """
        # policy id of the first agent that is not trained by this learner, if any
        opponent_ppid = next(
            (
                self.agent_mapping(candidate)
                for candidate in self.agents
                if candidate not in self.ego_agents
            ),
            None,
        )
        if not self._enable_policy_pool:
            return

        ego_pool = self._policies[self.agent_mapping(self.ego_agents[0])]
        ego_pool.sync_weights()
        opponent_pool = self._policies[opponent_ppid]
        opponent_pool.add_policy(
            key="policy-0", policy=ego_pool._policies["policy-0"], fixed=True
        )

    def distill(self, meta_strategies: Dict[AgentID, Dict[PolicyID, float]]):
        """No-op: this learner does not implement distillation.

        Args:
            meta_strategies (Dict[AgentID, Dict[PolicyID, float]]): Ignored.
        """
        return None

    @monitor(enable_returns=False, enable_time=True)
    def train(
        self,
        env_desc: Dict[str, Any],
        agent_interfaces: Dict[AgentID, AgentInterface],
        trainer: Dict[AgentID, Trainer],
        sampler: SamplerInterface,
        rollout_config: RolloutConfig,
        behavior_policies=None,
    ) -> Sequence[Dict[AgentID, Dict]]:
        """Collect rollouts and train every given trainer independently.

        The rollout generator built from `rollout_config.caller` is consumed until it
        is exhausted. Each time it yields and the sampler reports ready, all trainers
        are invoked once. When the generator finishes, episode and timestep counters
        are updated from its return value. If episodic training is enabled and the
        sampler is ready at that point, the statistics collected so far are replaced
        by the result of one final training pass using `mini_epoch` inner loops.

        Note:
            Currently, this function doesn't support dynamic training policy reset (features for PSRO-like).

        Args:
            env_desc (Dict[str, Any]): The environment description.
            agent_interfaces: (Dict[AgentID, AgentInterface]): A dict of agent interfaces.
            trainer (Dict[AgentID, Trainer]): The trainers to run, one entry per trained key.
            sampler (SamplerInterface): The instance of sampler interface.
            rollout_config (RolloutConfig): The rollout configuration instance.
            behavior_policies: Passed to the rollout caller as `agent_policy_mapping`;
                an empty dict is used when falsy.

        Returns:
            Sequence[Dict[AgentID, Dict]]: A sequence of training statistic in agent dict.
        """

        # agent interfaces will be reset with behavior policies
        # TODO(): when learnable pool is enable, sampler should record rewards for them
        behavior_policies = behavior_policies or {}
        rollout_kwargs = dict(
            sampler=sampler,
            agent_policy_mapping=behavior_policies,
            agent_interfaces=agent_interfaces,
            env_description=env_desc,
            fragment_length=rollout_config.fragment_length,
            max_step=rollout_config.max_step,
            episodic=self._episode_training,
            train_every=self._train_every,
            evaluate=False,
            env=self.sampler_env,
            max_episode=self._rollout_config.max_episode,
            agent_filter=[self.agent_mapping(aid) for aid in self._ego_agents],
        )
        generator = rollout_config.caller(**rollout_kwargs)

        statistics = []
        try:
            # running timestep count, local to this call until the rollout finishes
            elapsed_timesteps = self._total_timesteps
            while True:
                step_info = next(generator)
                elapsed_timesteps += step_info["timesteps"]
                if not sampler.is_ready():
                    continue

                per_key: Dict[AgentID, Dict] = {}
                for key, key_trainer in trainer.items():
                    group = env_desc["config"].get("group")
                    # with a group config the trainer sees all agents of its group
                    members = group[key] if group else [key]
                    per_key[key] = key_trainer(
                        sampler, agent_filter=members, time_step=elapsed_timesteps
                    )
                    per_key[key]["walltime"] = time.time()
                    per_key[key]["timesteps"] = elapsed_timesteps
                statistics.append(per_key)
        except StopIteration as finished:
            # the generator's return value carries the rollout totals
            summary = finished.value
            self._total_episodes += summary["num_episode"]
            self._total_timesteps += summary["total_timesteps"]

        if sampler.is_ready() and self._episode_training:
            # NOTE(review): unlike the loop above, this pass always filters by the
            # trainer key itself and ignores any "group" config - confirm intended.
            final_pass = {}
            for key, key_trainer in trainer.items():
                final_pass[key] = key_trainer(
                    sampler,
                    agent_filter=[key],
                    time_step=self._total_timesteps,
                    n_inner_loop=self._mini_epoch,
                )
            statistics = [final_pass]

        return statistics

    def remove_dominated_policies(self, pids_dict: Dict[AgentID, List[PolicyID]]):
        """Ask the policy pools of ego agents to remove the given policies.

        Args:
            pids_dict (Dict[AgentID, List[PolicyID]]): Policy ids to remove, per agent.
                Entries for agents that are not ego agents are skipped.
        """
        for agent_id, dominated in pids_dict.items():
            if agent_id not in self.ego_agents:
                continue
            # NOTE(review): `_policies` is keyed by `agent_mapping(agent)` in
            # `__init__`; this direct lookup only matches when that mapping is the
            # identity - confirm for grouped (multi_to_single) setups.
            pool = self._policies[agent_id]
            assert isinstance(pool, PolicyPool)
            pool.remove_dominated_policies(dominated)

    def return_dominated_policies(self) -> Dict[AgentID, List[PolicyID]]:
        """Collect the dominated policy ids reported by the pools of non-ego agents.

        Returns:
            Dict[AgentID, List[PolicyID]]: For every agent that is not an ego agent,
                the result of its pool's `return_dominated_policies()`.
        """
        dominated: Dict[AgentID, List[PolicyID]] = {}

        for candidate in self.agents:
            if candidate in self.ego_agents:
                continue
            # NOTE(review): `_policies` is keyed by `agent_mapping(agent)` in
            # `__init__`; this direct lookup only matches when that mapping is the
            # identity - confirm for grouped (multi_to_single) setups.
            pool = self._policies[candidate]
            assert isinstance(pool, PolicyPool)
            dominated[candidate] = pool.return_dominated_policies()

        return dominated

    def reset(self):
        """Reset the learner state, then reset every trainer with its current policy.

        With policy pools enabled, each trainer is reset with the first active policy
        of its pool; otherwise with the stored policy itself.
        """

        super().reset()
        for ppid, policy_trainer in self._trainer.items():
            target = self._policies[ppid]
            if self._enable_policy_pool:
                # FIXME(): we now support only one policy
                target = list(target.get_active_policies().values())[0]
            policy_trainer.reset(target)

    def _set_pretrain(self, pmode):
        """Store the pretrain flag and forward it to the trainer of every ego agent.

        Args:
            pmode: The pretrain mode value passed to each trainer's `set_pretrain`.
        """
        self._pretrain_mode = pmode
        for ego in self._ego_agents:
            trainer_key = self.agent_mapping(ego)
            ego_trainer = self._trainer[trainer_key]
            ego_trainer.set_pretrain(pmode)

    def reset_all(self):
        """Re-arm the one-time reset in `learn` and switch trainers to pretrain mode."""
        # clearing the flag makes the next `learn` call reset and rebuild the sampler
        self._reset_once = False
        self._set_pretrain(pmode=True)

    def finish_pretrain(self):
        """Leave pretrain mode and reset the base learner state.

        Only the base class `reset` is invoked here, so the trainers themselves are
        not reset by this call.
        """
        self._set_pretrain(pmode=False)
        super().reset()

    def learn(
        self,
        sampler_config: Union[Dict, LambdaType],
        stop_conditions: Dict = None,
    ):
        """Run the main learning loop until the stopper reports termination.

        Args:
            sampler_config: (Union[Dict, LambdaType]): The configuration handed to
                `get_sampler` together with the ego agents. It is only used when the
                learner is reset in this call.
            stop_conditions (Dict, optional): The stop conditions, for stopping control.
                `DEFAULT_STOP_CONDITIONS` is used when falsy. Defaults to None.

        Returns:
            None when inner evaluation is disabled. Otherwise a dict
            `{"FinalWinRate": ...}` or `{"FinalReward": ...}` taken from the last
            evaluation statistic, or None if it holds neither "win_rate" nor "reward".
        """
        must_reset = self._reset_every_learn or not self._reset_once
        if must_reset:
            self._reset_once = True
            self.reset()
            self.sampler = get_sampler(self.ego_agents, sampler_config)

        # counters at entry; the stopper is fed progress relative to them
        timesteps_at_start = self._total_timesteps
        episodes_at_start = self._total_episodes
        sampler = self.sampler

        # build stopper with given stop conditions
        self.stopper = get_stopper(stop_conditions or DEFAULT_STOP_CONDITIONS)
        self.stopper.reset()

        def evaluate_and_log():
            # Run one evaluation round and write its statistic to tensorboard.
            config = self._rollout_config
            statistic = self.evaluation(
                policy_mappings=None,
                max_step=config.max_step,
                fragment_length=min(20, config.num_simulation) * config.max_step,
                max_episode=20,
            )[0][1]
            write_to_tensorboard(
                self.summary_writer,
                statistic,
                global_step=self._total_timesteps,
                prefix="evaluation",
            )
            return statistic

        # TODO(): require further test! we need to make sure the evaluation here is running
        #   with `active_policy` for ego agents, and all possible behavior
        #   policies for opponent agents.
        if self._inner_eval:
            latest_evaluation = evaluate_and_log()

        while not self.stopper.is_terminal():
            tick = time.time()
            training_statistics = self.train(
                env_desc=self._env_desc,
                agent_interfaces=self._agent_interfaces,
                trainer=self._trainer,
                sampler=sampler,
                rollout_config=self._rollout_config,
            )

            Log.debug("end for training: {}".format(time.time() - tick))

            for entry in training_statistics:
                write_to_tensorboard(
                    self.summary_writer,
                    entry,
                    global_step=self._total_timesteps,
                    prefix="training",
                )
                self.stopper.step(
                    None,
                    entry,
                    time_step=self._total_timesteps - timesteps_at_start,
                    episode_th=self._total_episodes - episodes_at_start,
                )

            if self._inner_eval:
                latest_evaluation = evaluate_and_log()

            if self._learning_mode == LearningMode.ON_POLICY:
                sampler.clean()

        if not self._inner_eval:
            if not self._pretrain_mode:
                first_ego_key = self.agent_mapping(self._ego_agents[0])
                Log.info(
                    "\t= Eps of current %s.", self._trainer[first_ego_key].get_eps()
                )
            return None

        # inner evaluation ran at least once before the loop, so a statistic exists
        if "win_rate" in latest_evaluation:
            return {"FinalWinRate": latest_evaluation["win_rate"]}
        if "reward" in latest_evaluation:
            return {"FinalReward": latest_evaluation["reward"]}
        return None

    def save(self, data_dir: str = None, hard: bool = False):
        """Save the policies of ego agents, one sub-directory per policy id.

        Args:
            data_dir (str): Root directory; models go to `<data_dir>/models`. When
                None, the path is taken from `self._exp_config.get_path("models")`.
            hard (bool): Forwarded to each policy's `save`. Default to False.
        """
        if data_dir is None:
            model_data_dir = self._exp_config.get_path("models")
        else:
            model_data_dir = os.path.join(data_dir, "models")
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(model_data_dir, exist_ok=True)
        for ppid, policy in self._policies.items():
            # only trainable (ego) policies are persisted
            if ppid not in self._ego_ppid:
                continue
            path = os.path.join(model_data_dir, ppid)
            os.makedirs(path, exist_ok=True)
            Log.info("\t* save mode for {} to directory: {}".format(ppid, path))
            policy.save(
                path, global_step=self.get_learner_state().global_step, hard=hard
            )

    def load(self, data_dir: str = None, global_step=0):
        """Load every policy from `<models dir>/<policy id>_<global_step>`.

        Args:
            data_dir (str): Root directory; models are read from `<data_dir>/models`.
                When None, the path is taken from `self._exp_config.get_path("models")`.
            global_step: Suffix of the per-policy path; also stored as the learner's
                global iteration afterwards. Default to 0.
        """
        if data_dir is None:
            model_data_dir = self._exp_config.get_path("models")
        else:
            model_data_dir = os.path.join(data_dir, "models")
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(model_data_dir, exist_ok=True)

        # NOTE(review): `save` writes only ego policies into `<models>/<ppid>`, while
        # this loads all policies from `<models>/<ppid>_<global_step>` - confirm the
        # path layout produced by the policy's own `save` matches.
        for _aid, policy in self._policies.items():
            path = os.path.join(model_data_dir, f"{_aid}_{global_step}")
            policy.load(path)

        self._global_iteration = global_step
