"""
Learner for self-play. Consider merging it into the PSRO learner in the future.
"""

import ray
import time
import random
import numpy as np

from collections import ChainMap, namedtuple

from expground.types import (
    RolloutConfig,
    TrainingConfig,
    PolicyConfig,
    EnvDescription,
    Dict,
    AgentID,
    Union,
    LambdaType,
    List,
    Any,
)
from expground.logger import Log
from expground.utils.logging import write_to_tensorboard
from expground.utils.stoppers import get_stopper, DEFAULT_STOP_CONDITIONS

from expground.learner.base_learner import Learner
from expground.learner.independent import IndependentLearner
from expground.learner.centralized import CentralizedLearner
from expground.learner.utils import dispatch_resources_for_agent
from expground.common.exploitability import measure_exploitability


# Lightweight immutable record pairing a NashConv value with the corresponding
# player improvements.
NashConv = namedtuple("NashConv", ["nash_conv", "player_improvements"])


class SelfPlay(Learner):
    """Self-play learner.

    Wraps one sub learner per agent (or per group when ``multi_to_single``) of
    type ``IndependentLearner`` or ``CentralizedLearner``; in ``ray_mode`` each
    sub learner is a Ray actor handle. Only the learners in
    ``self.select_learners`` are trained, and their weights are synced via
    ``sync_weights`` after every global iteration.
    """

    def __init__(
        self,
        policy_config: PolicyConfig,
        env_description: EnvDescription,
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = False,
        train_every: int = 1,
        experiment: str = None,
        ray_mode: bool = False,
        seed: int = None,
        independent_learning: bool = True,
        evaluation_worker_num: int = 0,
        centralized_critic_config: Dict = None,
        use_cuda: bool = False,
        multi_to_single: bool = False,
        resource_limit: Dict = None,
        single_population: bool = False,
        measure_exp: bool = False,
        exp_config: Dict[str, Any] = None,
    ):
        """Initialize a Self-Play learner.

        Args:
            policy_config (PolicyConfig): The configuration of policy.
            env_description (EnvDescription): The description of environment. Its
                ``config`` must contain ``group`` and ``possible_agents``, plus
                ``env_id`` when ``measure_exp`` is enabled.
            rollout_config (RolloutConfig): The configuration of rollout.
            training_config (TrainingConfig): The configuration of training.
            loss_func (type): Loss function type.
            learning_mode (str): The learning mode, could be `on_policy` or `off_policy`; it affects the
                behavior of the sampler. When `on_policy`, the sampler clears its buffer at the end of each
                training iteration; when `off_policy`, the sampler maintains a buffer of history experiences.
            episodic_training (bool, optional): Forwarded to sub learners. Defaults to False.
            train_every (int, optional): Forwarded to sub learners. Defaults to 1.
            experiment (str, optional): Experiment tag name. When given, the tag is
                ``"{independent|centralized}_{experiment}"``; otherwise a timestamped
                ``"SP_{policy}_{time}"`` tag is generated. Defaults to None.
            ray_mode (bool, optional): Run sub learners as Ray actors. Defaults to False.
            seed (int, optional): Random seed forwarded to sub learners. Defaults to None.
            independent_learning (bool, optional): Use ``IndependentLearner`` (True) or
                ``CentralizedLearner`` (False) for sub learners. Defaults to True.
            evaluation_worker_num (int, optional): Forwarded in the sub learners'
                resource config. Defaults to 0.
            centralized_critic_config (Dict, optional): Only used when
                ``independent_learning`` is False. Defaults to None.
            use_cuda (bool, optional): Forwarded to resource dispatching. Defaults to False.
            multi_to_single (bool, optional): Build one sub learner per group of
                ``env_description["config"]["group"]`` instead of one per agent. Defaults to False.
            resource_limit (Dict, optional): Forwarded to resource dispatching. Defaults to None.
            single_population (bool, optional): Train only one randomly chosen
                sub learner. Defaults to False.
            measure_exp (bool, optional): Measure NashConv at the start of every
                global iteration. Defaults to False.
            exp_config (Dict[str, Any], optional): Forwarded to the base learner and
                sub learners. Defaults to None.
        """

        inner_learner_type = "independent" if independent_learning else "centralized"
        # A non-empty f-string is always truthy, so the fallback tag must be
        # selected explicitly rather than with `or`.
        if experiment:
            experiment_tag = f"{inner_learner_type}_{experiment}"
        else:
            policy_name = getattr(policy_config, "human_readable", None) or getattr(
                policy_config, "policy", None
            )
            experiment_tag = f"SP_{policy_name}_{time.time()}"

        Learner.__init__(
            self,
            experiment=experiment_tag,
            env_desc=env_description,
            rollout_config=rollout_config,
            seed=seed,
            ray_mode=ray_mode,
            # NOTE(review): this key differs from the `evaluation_worker_num` key
            # passed to sub learners below -- confirm which one the base Learner reads.
            resource_config={"evaluation_worker_mode": 0},
            exp_config=exp_config,
        )

        np.set_printoptions(linewidth=999)

        groups = list(env_description["config"]["group"].keys())
        learner_resources = dispatch_resources_for_agent(
            groups, use_cuda=use_cuda, resource_limit=resource_limit or {}
        )

        # Build agent learners with given policies from the policy pool.
        learner_cls = IndependentLearner if independent_learning else CentralizedLearner
        if ray_mode:
            learner_cls = learner_cls.as_remote(**learner_resources).remote

        # Template only: `init_agent_learners` sets `start_active_support_num`
        # per learner on a private copy.
        policy_pool_config = {
            "mixed_at_every_step": False,
            "use_learnable_dist": False,
            "distribution_training_kwargs": {},
            "start_fixed_support_num": 1,
            "start_active_support_num": 0,
            "seed": seed,
            "distill_mode": False,  # always false
        }

        # Only the evaluation resource configuration for now; this is not the
        # same as computing resources.
        resource_config = {
            "evaluation_worker_num": evaluation_worker_num,
        }

        if independent_learning:
            sub_learner_custom_config = {
                "multi_to_single": multi_to_single,
                "use_vector_env": rollout_config.vector_mode,
            }
        else:
            sub_learner_custom_config = {
                "centralized_critic_config": centralized_critic_config,
                "groups": None,  # no specific groups for sub learners
                "share_critic": False,
                "multi_to_single": multi_to_single,
            }

        self._env_description = env_description
        self._rollout_config = rollout_config
        # Must be set before `init_agent_learners`, which reads it.
        self._single_population = single_population
        self._measure_exp = measure_exp

        self._agent_learners = self.init_agent_learners(
            learner_cls=learner_cls,
            policy_config=policy_config,
            env_description=env_description,
            rollout_config=rollout_config,
            training_config=training_config,
            loss_func=loss_func,
            learning_mode=learning_mode,
            episodic_training=episodic_training,
            train_every=train_every,
            seed=seed,
            possible_agents=env_description["config"]["possible_agents"],
            enable_policy_pool=True,
            policy_pool_config=policy_pool_config,
            custom_config=sub_learner_custom_config,
            resource_config=resource_config,
            multi_to_single=multi_to_single,
        )

        self.register_env_agents(env_description["config"]["possible_agents"])

        self._sync_up()

    def init_agent_learners(
        self,
        learner_cls: type,
        policy_config,
        env_description,
        rollout_config,
        training_config,
        loss_func,
        learning_mode,
        episodic_training,
        train_every,
        seed: int,
        possible_agents: List[AgentID],
        enable_policy_pool: bool,
        policy_pool_config: Dict[str, Any],
        custom_config: Dict[str, Any],
        resource_config: Dict[str, Any],
        multi_to_single: bool,
    ) -> Dict[AgentID, Learner]:
        """Create one sub learner per agent (or per group) and choose which ones train.

        Also sets ``self.select_learners``: a single randomly chosen learner id in
        single-population mode, otherwise every learner id. Selected learners start
        with one active policy, the others with none.

        Args:
            learner_cls (type): Sub learner class, or its Ray ``.remote`` constructor.
            possible_agents (List[AgentID]): Learner ids used when not ``multi_to_single``.
            policy_pool_config (Dict[str, Any]): Template pool config; it is copied,
                never mutated.
            multi_to_single (bool): Use the keys of ``env_description["config"]["group"]``
                as learner ids, each owning that group's agents.
            Other arguments are forwarded unchanged to ``learner_cls``.

        Returns:
            Dict[AgentID, Learner]: Mapping from learner id to sub learner.
        """
        if multi_to_single:
            assert "group" in env_description["config"]
            group = env_description["config"]["group"]
            learner_ids = list(group.keys())
        else:
            group = None
            learner_ids = possible_agents

        if self._single_population:
            self.select_learners = [random.choice(learner_ids)]
            Log.info(
                "Single population mode, select learner={} as training case".format(
                    self.select_learners[0]
                )
            )
        else:
            self.select_learners = learner_ids

        Log.info("selected learners: {}".format(self.select_learners))

        agent_learners = {}
        for lid in learner_ids:
            # Each learner gets its own copy: the caller's template stays intact, and
            # a local (non-Ray) learner that keeps a reference is not affected by
            # the next iteration.
            learner_pool_config = dict(
                policy_pool_config,
                start_active_support_num=1 if lid in self.select_learners else 0,
            )

            agent_learners[lid] = learner_cls(
                policy_config=policy_config,
                env_description=env_description,
                rollout_config=rollout_config,
                training_config=training_config,
                loss_func=loss_func,
                learning_mode=learning_mode,
                episodic_training=episodic_training,
                train_every=train_every,
                ego_agents=[lid] if group is None else group[lid],
                experiment=f"{self.experiment_tag}/{lid}",
                seed=seed,
                ray_mode=self.ray_mode,
                enable_policy_pool=enable_policy_pool,
                summary_writer=self.summary_writer if not self.ray_mode else None,
                turn_off_logging=True,
                policy_pool_config=learner_pool_config,
                resource_config=resource_config,
                custom_config=custom_config,
                exp_config=self._exp_config,
            )

        return agent_learners

    def _call_learners(
        self, learner_ids: List[AgentID], method: str, *args, **kwargs
    ) -> Dict[AgentID, Any]:
        """Invoke ``method`` on each listed sub learner and return ``{learner_id: result}``.

        In Ray mode all remote calls are dispatched before a single ``ray.get``, so
        the actors can execute them concurrently; otherwise they run sequentially
        in ``learner_ids`` order.
        """
        learners = [self._agent_learners[lid] for lid in learner_ids]
        if self.ray_mode:
            results = ray.get(
                [getattr(learner, method).remote(*args, **kwargs) for learner in learners]
            )
        else:
            results = [getattr(learner, method)(*args, **kwargs) for learner in learners]
        return dict(zip(learner_ids, results))

    def _sync_up(self):
        """Sync fixed policies to learners, and switch active and fixed policies' weights."""
        Log.info("\t* sync weights")
        self._call_learners(self.select_learners, "sync_weights")

    def _log_exploitability(self, global_step: int):
        """Measure NashConv of the current fixed policies and write it to tensorboard.

        The first fixed-policy entry reported by the first selected learner is
        assigned to every registered agent before measuring.
        """
        fixed_by_learner = self._call_learners(
            self.select_learners, "get_ego_fixed_policies", "cpu"
        )
        first_result = fixed_by_learner[self.select_learners[0]]
        shared_policies = next(iter(first_result.values()))
        nash_conv = measure_exploitability(
            self._env_description["config"]["env_id"],
            populations=dict.fromkeys(self.agents, shared_policies),
            # NOTE(review): only pool entry "policy-0" is evaluated (probability 1);
            # confirm this is the slot refreshed by `sync_weights`.
            policy_mixture_dict={aid: {"policy-0": 1.0} for aid in self.agents},
        )
        Log.info("\t* nash conv: %s", nash_conv)
        write_to_tensorboard(
            self.summary_writer,
            info={"NashConv": nash_conv.nash_conv},
            global_step=global_step,
            prefix="",
        )

    def learn(
        self,
        sampler_config: Union[Dict, LambdaType],
        stop_conditions: Dict,
        inner_stop_conditions: Dict,
    ):
        """Run global self-play iterations until ``stop_conditions`` is met.

        Each iteration optionally measures NashConv, logs the active policy ids of
        the selected learners, trains every selected learner under
        ``inner_stop_conditions``, logs their results, syncs weights and saves all
        sub learners.

        Args:
            sampler_config (Union[Dict, LambdaType]): Forwarded to each sub learner's ``learn``.
            stop_conditions (Dict): Outer stop conditions; ``DEFAULT_STOP_CONDITIONS``
                is used when empty.
            inner_stop_conditions (Dict): Stop conditions for each sub learner's ``learn``.
        """
        # Training object should be an algorithm, not an individual agent.
        stopper = get_stopper(stop_conditions or DEFAULT_STOP_CONDITIONS)
        stopper.reset()

        select_learners = self.select_learners

        while not stopper.is_terminal():
            Log.info("Global iteration on: %s", stopper.counter)
            if self._measure_exp:
                self._log_exploitability(stopper.counter)

            active_policy_ids = {}
            for ids in self._call_learners(
                select_learners, "get_ego_active_policy_ids"
            ).values():
                active_policy_ids.update(ids)
            Log.info("\t* current active policies %s", active_policy_ids)

            # Learn policy supports until convergence; results are collected in
            # both Ray and local mode so tensorboard logging is consistent.
            learning_res = self._call_learners(
                select_learners,
                "learn",
                sampler_config,
                stop_conditions=inner_stop_conditions,
            )

            write_to_tensorboard(
                self.summary_writer,
                info={"LearnerRes": learning_res},
                global_step=stopper.counter,
                prefix="",
            )
            self._sync_up()
            Log.info("\t* active policies trained and synced up")
            Log.info("")
            self.save(hard=True)
            stopper.step(None, None, None, None)

    def save(self, data_dir: str = None, hard: bool = False):
        """Save every sub learner, selected or not.

        Args:
            data_dir (str, optional): Currently unused; kept for interface compatibility.
            hard (bool, optional): Forwarded to each sub learner's ``save``.
        """
        self._call_learners(list(self._agent_learners), "save", hard=hard)

    def load(self, checkpoint_dir):
        """Load model. Not implemented yet: this is currently a no-op."""
        # TODO: restore sub learners / policy pools from `checkpoint_dir`.
        pass

    # def _run_decentralized_meta_solver(
    #     self,
    # ) -> Tuple[
    #     Dict[AgentID, Dict[PolicyID, float]], Dict[AgentID, Dict[PolicyID, Policy]]
    # ]:
    #     fixed_policies = {}
    #     if self.ray_mode:
    #         res = ray.get(
    #             [
    #                 learner.get_ego_fixed_policies.remote()
    #                 for learner in self._agent_learners.values()
    #             ]
    #         )
    #         for e in res:
    #             fixed_policies.update(e)
    #         res = ray.get(
    #             [learner.get_dist.remote() for learner in self._agent_learners.values()]
    #         )
    #     else:
    #         for agent, learner in self._agent_learners.items():
    #             fixed_policies.update(learner.get_ego_fixed_policies())
    #         res = [learner.get_dist() for learner in self._agent_learners.values()]

    #     equilibrium = {}
    #     for e in res:
    #         equilibrium.update(e)

    #     # share equilibrium and fixed policies
    #     policies = list(fixed_policies.values())[0]
    #     eq = list(equilibrium.values())[0]
    #     fixed_policies.update(
    #         {
    #             aid: policies
    #             for aid in self._env_description["config"]["possible_agents"]
    #         }
    #     )
    #     equilibrium.update(
    #         {aid: eq for aid in self._env_description["config"]["possible_agents"]}
    #     )
    #     return equilibrium, fixed_policies

    # def _run_centralized_meta_solver(
    #     self,
    # ) -> Tuple[
    #     Dict[AgentID, Dict[PolicyID, float]], Dict[AgentID, Dict[PolicyID, Policy]]
    # ]:
    #     simulations = self._payoff_matrix.gen_simulations(split=True)

    #     # 2. update payoff tables
    #     results = []
    #     if self.ray_mode:
    #         res = ray.get(
    #             [
    #                 learner.evaluation.remote(
    #                     simulation,
    #                     max_step=self._rollout_config.max_step,
    #                     fragment_length=self._rollout_config.num_simulation
    #                     * self._rollout_config.max_step,
    #                 )
    #                 for learner, simulation in zip(
    #                     self._agent_learners.values(), simulations
    #                 )
    #             ]
    #         )
    #         for e in res:
    #             results.extend(e)
    #     else:
    #         for learner, simulation in zip(self._agent_learners.values(), simulations):
    #             # return a sequence of (policy_mapping, reward_dict) tuples
    #             results.extend(
    #                 learner.evaluation(
    #                     simulation,
    #                     max_step=self._rollout_config.max_step,
    #                     fragment_length=self._rollout_config.num_simulation
    #                     * self._rollout_config.max_step,
    #                 )
    #             )
    #     Log.info("\t* evaluation finished for %s simulations", len(results))
    #     self._payoff_matrix.update_payoff_and_simulation_status(results)
    #     # got a full of policy mapping
    #     fixed_policies = {}
    #     if self.ray_mode:
    #         res = ray.get(
    #             [
    #                 learner.get_ego_fixed_policies.remote()
    #                 for learner in self._agent_learners.values()
    #             ]
    #         )
    #         for e in res:
    #             fixed_policies.update(e)
    #     else:
    #         for agent, learner in self._agent_learners.items():
    #             fixed_policies.update(learner.get_ego_fixed_policies())
    #     # a dict of payoff tables
    #     policies = list(fixed_policies.values())[0]
    #     pids = list(policies.keys())
    #     fixed_policies.update(
    #         {
    #             aid: policies
    #             for aid in self._env_description["config"]["possible_agents"]
    #         }
    #     )
    #     fixed_policy_mapping = {
    #         aid: pids for aid in self._env_description["config"]["possible_agents"]
    #     }

    #     matrix = self._payoff_matrix.get_sub_matrix(fixed_policy_mapping)
    #     equilibrium = self._meta_solver.solve(matrix)
    #     # merge with poicy
    #     equilibrium = {
    #         k: dict(
    #             zip(
    #                 fixed_policy_mapping[k],
    #                 equilibrium[k],
    #             )
    #         )
    #         for k, v in matrix.items()
    #     }
    #     return equilibrium, fixed_policies
