import time
import collections
import ray
import os

import numpy as np

from expground import settings
from expground.types import (
    List,
    AgentID,
    PolicyConfig,
    PolicyID,
    TrainingConfig,
    RolloutConfig,
    EnvDescription,
    LambdaType,
    Dict,
    Union,
    Tuple,
    Any,
    Sequence,
)
from expground.logger import Log
from expground.utils.logging import write_to_tensorboard, append_to_table
from expground.gt.payoff_matrix import PayoffMatrix
from expground.gt.payoff_server import Identifier, PayoffServer
from expground.utils.stoppers import get_stopper, DEFAULT_STOP_CONDITIONS

from expground.learner.base_learner import Learner
from expground.learner.independent import IndependentLearner, SubOracle
from expground.learner.centralized import CentralizedLearner
from expground.learner.utils import (
    generate_random_from_shapes,
    dispatch_resources_for_agent,
)
from expground.gt import MetaSolver
from expground.common.exploitability import measure_exploitability
from expground.algorithms.base_policy import Policy


class PSROLearner(Learner):
    """Policy Space Response Oracle learner. This learner supports multi-agent learning with
    single-agent RL methods in the PSRO learning manner. Nested with independent learners.
    """

    NAME = "PSRO"

    def __init__(
        self,
        meta_solver: str,
        policy_config: PolicyConfig,
        env_description: EnvDescription,
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = False,
        train_every: int = 1,
        remove_dominated_policy: bool = False,
        use_learnable_dist: bool = False,
        experiment: str = None,
        ray_mode: bool = False,
        seed: int = None,
        mixed_at_every_step: bool = False,
        independent_learning: bool = True,
        distribution_training_kwargs: Dict = None,
        centralized_critic_config: Dict = None,
        rectifier_type: int = 0,
        mini_epoch: int = 1,
        exp_config=None,
        save_indep_curve_every=20,
        use_remote_payoff: bool = False,
        distill_mode: bool = False,
        resource_limit: Dict = None,
        use_cuda: bool = False,
        measure_exp: bool = True,
        evaluation_worker_total_cpu_num: int = 1,
    ):
        """Initialize a PSRO learner.

        One sub-learner is created per entry of
        ``env_description["config"]["possible_agents"]``; the fixed policies of all
        sub-learners are then synchronized and the payoff matrix is expanded with them.

        Args:
            meta_solver (str): The type of meta solver, choices={fictitious_play, alpha_rank, mrcp}.
                Passed to `MetaSolver.from_type` (and to the payoff server when one is used).
            policy_config (PolicyConfig): The configuration of policy.
            env_description (EnvDescription): The description of environment.
            rollout_config (RolloutConfig): The configuration of rollout.
            training_config (TrainingConfig): The configuration of training.
            loss_func (type): Loss function type.
            learning_mode (str): The learning mode, could be `on_policy` or `off_policy`, will affect the behavior of sampler,
                when `on_policy`, the sampler will clear its buffer at the end of each training iteration; when `off_policy`, the
                sampler will maintain a buffer to store history experiences.
            episodic_training (bool, optional): Forwarded unchanged to every sub-learner. Defaults to False.
            train_every (int, optional): Forwarded unchanged to every sub-learner. Defaults to 1.
            remove_dominated_policy (bool, optional): When True, `_sync_up` asks the sub-learners to drop dominated
                policies on every call with ``step > 0``. Defaults to False.
            use_learnable_dist (bool, optional): Fixed opponent policy distribution or not. Defaults to False.
            experiment (str, optional): Experiment tag name. Defaults to None, in which case "psro" is used.
            ray_mode (bool, optional): Enable ray mode or not. Defaults to False.
            seed: (int, optional): Random seed. Defaults to None.
            mixed_at_every_step (bool, optional): Sample policy at each timestep or the top of episode. Defaults to False.
            independent_learning (bool, optional): Discovery best response via independent learning or centralized learning. Defaults to True.
            distribution_training_kwargs (Dict, optional): Training configuration for optimizing meta strategies. Defaults to None.
            centralized_critic_config (Dict, optional): Network configuration to construct a centralized critic.
                Only placed in the sub-learner config when `independent_learning` is False. Defaults to None.
            rectifier_type (int, optional): Specify rectifier type. Defaults to 0. See `expground.gt.payoff_matrix.PayoffMatrix`.
            mini_epoch (int, optional): Forwarded unchanged to every sub-learner. Defaults to 1.
            exp_config (optional): Experiment configuration, forwarded to the base learner and to every sub-learner.
            save_indep_curve_every (int, optional): Period (in rounds) used by `_reset_indep_summary` to decide when the
                sub-learners get a fresh summary log path. Defaults to 20.
            use_remote_payoff (bool, optional): When True, a `PayoffServer` actor named `settings.PAYOFF_SERVER_ACTOR`
                is created and the payoff matrix is fetched from / pushed to it; independent sub-learners are then
                `SubOracle` instances. Defaults to False.
            distill_mode (bool, optional): Forwarded to the sub-learners through the policy pool config. Defaults to False.
            resource_limit (Dict, optional): Passed to `dispatch_resources_for_agent` and, as `resource_config`,
                to every sub-learner. Defaults to None (treated as an empty dict).
            use_cuda (bool, optional): Passed to `dispatch_resources_for_agent`. Defaults to False.
            measure_exp (bool, optional): When True, `learn` computes and reports NashConv on every round. Defaults to True.
            evaluation_worker_total_cpu_num (int, optional): Accepted for interface compatibility; not used by this
                constructor. Defaults to 1.
        """

        if experiment is None:
            experiment = "psro"
        super(PSROLearner, self).__init__(
            experiment,
            seed=seed,
            ray_mode=ray_mode,
            exp_config=exp_config,
            resource_config={"evaluation_worker_num": 0},  # always mute
            rollout_config=rollout_config,
            env_desc=env_description,
        )

        np.set_printoptions(linewidth=999)

        possible_agents = env_description["config"]["possible_agents"]
        resource_limit = resource_limit or {}
        learner_resources = dispatch_resources_for_agent(
            possible_agents, use_cuda=use_cuda, resource_limit=resource_limit
        )

        # Pick the sub-learner class first, then wrap it as a remote actor factory
        # when running under ray.
        if independent_learning:
            base_learner_cls = SubOracle if use_remote_payoff else IndependentLearner
        else:
            base_learner_cls = CentralizedLearner
        if ray_mode:
            learner_cls = base_learner_cls.as_remote(**learner_resources).remote
        else:
            learner_cls = base_learner_cls

        policy_pool_config = {
            "mixed_at_every_step": mixed_at_every_step,
            "use_learnable_dist": use_learnable_dist,
            "distribution_training_kwargs": distribution_training_kwargs,
            "start_fixed_support_num": 1,
            "start_active_support_num": 0,
            "seed": seed,
            "distill_mode": distill_mode,
        }

        if independent_learning:
            sub_learner_custom_config = {
                "use_vector_env": rollout_config.vector_mode,
                "enable_evaluation_pool": True,
            }
        else:
            sub_learner_custom_config = {
                "centralized_critic_config": centralized_critic_config,
                "groups": None,  # none-specific groups for sub learners
                "share_critic": False,
            }

        # build agent learners with given policies from policy pool
        self._agent_learners = self.init_agent_learners(
            learner_cls=learner_cls,
            policy_config=policy_config,
            env_description=env_description,
            rollout_config=rollout_config,
            training_config=training_config,
            loss_func=loss_func,
            learning_mode=learning_mode,
            episodic_training=episodic_training,
            train_every=train_every,
            seed=seed,
            possible_agents=possible_agents,
            enable_policy_pool=True,
            policy_pool_config=policy_pool_config,
            custom_config=sub_learner_custom_config,
            resource_config=resource_limit,
            mini_epoch=mini_epoch,
            exp_config=exp_config,
        )

        # create meta solver
        self._meta_solver = MetaSolver.from_type(meta_solver)
        if use_remote_payoff:
            # The payoff matrix lives in a named actor; keep a local copy of it.
            self._payoff_server = PayoffServer.options(
                name=settings.PAYOFF_SERVER_ACTOR
            ).remote(possible_agents, meta_solver, "async", rectifier_type)
            self._payoff_matrix = ray.get(
                self._payoff_server.get_payoff_matrix.remote()
            )
        else:
            self._payoff_server = None
            self._payoff_matrix = PayoffMatrix(
                tuple(possible_agents), rectifier_type=rectifier_type
            )
        self._env_description = env_description
        self._rollout_config = rollout_config
        self._use_learnable_dist = use_learnable_dist
        self._use_remote_payoff = use_remote_payoff
        self._save_indep_curve_every = save_indep_curve_every
        self._ppool_config = policy_pool_config
        self._measure_exp = measure_exp

        self.register_env_agents(possible_agents)

        # sync policies to learners; `_sync_up` reads `_remove_dominated_policy`,
        # so it must be set before the call.
        self._remove_dominated_policy = remove_dominated_policy
        fixed_policy_ids = self._sync_up(0)
        self._payoff_matrix.expand(fixed_policy_ids)
        if use_remote_payoff:
            # push the expanded matrix back so the server and this copy agree
            ray.get(self._payoff_server.set_payoff_matrix.remote(self._payoff_matrix))

    def init_agent_learners(
        self,
        learner_cls: type,
        policy_config,
        env_description,
        rollout_config,
        training_config,
        loss_func,
        learning_mode,
        episodic_training,
        train_every,
        seed: int,
        possible_agents: List[AgentID],
        enable_policy_pool: bool,
        policy_pool_config: Dict[str, Any],
        custom_config: Dict[str, Any],
        resource_config: Dict[str, Any],
        mini_epoch: int,
        exp_config: Dict[str, Any],
    ) -> Dict[AgentID, Learner]:
        """Build one sub-learner per agent, each with that agent as its only ego agent.

        Every sub-learner is constructed through ``learner_cls`` with the same
        configuration; only ``ego_agents`` and the experiment tag
        (``"<experiment_tag>/<agent>"``) differ. Logging is turned off and inner
        evaluation is disabled for all of them. The summary writer of this learner is
        shared with the sub-learners only when not running in ray mode.

        Returns:
            Dict[AgentID, Learner]: A dict mapping each agent in ``possible_agents``
                to whatever ``learner_cls`` returned for it.
        """
        learners = {}
        for agent_id in possible_agents:
            # a summary writer is not handed to remote sub-learners
            writer = None if self.ray_mode else self.summary_writer
            learners[agent_id] = learner_cls(
                policy_config=policy_config,
                env_description=env_description,
                rollout_config=rollout_config,
                training_config=training_config,
                loss_func=loss_func,
                learning_mode=learning_mode,
                episodic_training=episodic_training,
                train_every=train_every,
                mini_epoch=mini_epoch,
                seed=seed,
                exp_config=exp_config,
                ray_mode=self.ray_mode,
                # initialize policy pool for each ego agent
                ego_agents=[agent_id],
                experiment=f"{self.experiment_tag}/{agent_id}",
                enable_policy_pool=enable_policy_pool,
                policy_pool_config=policy_pool_config,
                resource_config=resource_config,
                custom_config=custom_config,
                summary_writer=writer,
                turn_off_logging=True,
                inner_eval=False,
            )
        return learners

    @property
    def payoff_matrix(self) -> PayoffMatrix:
        """Return the payoff matrix currently held by this learner.

        When a remote payoff server is used, this is the local copy most recently
        fetched from it, not a live handle.
        """
        return self._payoff_matrix

    @property
    def meta_solver(self) -> MetaSolver:
        """Return the meta solver created in ``__init__`` via ``MetaSolver.from_type``."""
        return self._meta_solver

    @property
    def num_support(self) -> Dict[AgentID, int]:
        """Return the number of all policy supports.

        The per-learner results of ``get_num_support`` are merged into one dict.
        Works in both execution modes: remote actors are queried in a single batched
        ``ray.get``, local learners are called directly.

        Returns:
            Dict[AgentID, int]: A dict maps environment agents to ints.
        """

        learners = self._agent_learners.values()
        if self.ray_mode:
            per_learner = ray.get(
                [learner.get_num_support.remote() for learner in learners]
            )
        else:
            # Local learners are plain objects: they have no `.remote` attribute
            # and their results are not ray object refs.
            per_learner = [learner.get_num_support() for learner in learners]

        res = {}
        for entry in per_learner:
            res.update(entry)
        return res

    def add_policy(self, n_support: int = 1):
        """Ask every sub-learner to add ``n_support`` policies.

        In ray mode all requests are submitted first and then awaited together;
        otherwise the sub-learners are called one after another.
        """
        # TODO: consider to integrate methods like `add_policy`
        #   into a policy manager
        pending = []
        for sub_learner in self._agent_learners.values():
            if not self.ray_mode:
                sub_learner.add_policy(n_support)
                continue
            pending.append(sub_learner.add_policy.remote(n_support))
        if self.ray_mode:
            ray.get(pending)

    def _reset_indep_summary(self, round):
        """Reset the summary writer of every sub-learner for the given round.

        On round 0 and on every round where ``round + 1`` is a multiple of
        ``self._save_indep_curve_every``, each sub-learner receives a log path obtained
        from ``self._exp_config.get_path(f"{agent}_{round}")`` and ``set_none=False``.
        On all other rounds it receives an empty path and ``set_none=True``.
        """
        keep_curve = (round + 1) % self._save_indep_curve_every == 0 or round == 0
        set_none = not keep_curve

        pending = []
        for agent, learner in self._agent_learners.items():
            if set_none:
                indep_log_path = ""
            else:
                indep_log_path = self._exp_config.get_path(f"{agent}_{round}")

            if self.ray_mode:
                pending.append(
                    learner.reset_summary_writer.remote(indep_log_path, set_none)
                )
            else:
                learner.reset_summary_writer(indep_log_path, set_none)

        if self.ray_mode:
            # wait until every remote sub-learner has switched its writer
            ray.get(pending)

    def _update_agent_matrix(self, equilibrium):
        """Append one equilibrium row and one payoff row per agent to CSV tables.

        For every agent two files under ``self._exp_config.log_path`` are appended to
        through ``append_to_table`` (with an empty title):

        * ``equilibrium_<agent>.csv``: the values of ``equilibrium[agent]`` ordered by
          sorted policy id.
        * ``payoff_<agent>.csv``: the last row of the agent's payoff table when the
          agent is ``"player_0"``, otherwise the last column flattened.

        Args:
            equilibrium: A dict mapping each agent to a dict of policy id -> probability.

        Raises:
            AssertionError: If a sub-learner does not report exactly one entry of
                active policy ids, or if the agent's payoff table does not involve
                exactly two agents.
        """
        log_dir = self._exp_config.log_path
        for agent, learner in self._agent_learners.items():
            if self.ray_mode:
                active_policy_ids = ray.get(learner.get_ego_active_policy_ids.remote())
            else:
                active_policy_ids = learner.get_ego_active_policy_ids()
            assert len(active_policy_ids) == 1

            # equilibrium table
            agent_equilibrium = equilibrium[agent]
            row = [agent_equilibrium[pid] for pid in sorted(agent_equilibrium.keys())]
            append_to_table(
                os.path.join(log_dir, f"equilibrium_{agent}.csv"), [], row
            )

            # payoff table
            agent_payoff = self._payoff_matrix.payoff_matrix[agent]
            assert len(agent_payoff.agents) == 2, "only support agents=2 for now"
            # NOTE(review): the row/column choice is keyed on the literal agent name
            # "player_0"; presumably that agent indexes the first axis of the table --
            # confirm against PayoffMatrix before using other agent names.
            if agent == "player_0":
                row = agent_payoff.table[-1].tolist()
            else:
                row = agent_payoff.table[:, -1].reshape([-1]).tolist()
            append_to_table(os.path.join(log_dir, f"payoff_{agent}.csv"), [], row)

    def _sync_up(self, step: int):
        """Share every sub-learner's fixed policies with all sub-learners.

        When dominated-policy removal is enabled and ``step > 0``, the dominated
        policies reported by all sub-learners are merged and each sub-learner is asked
        to remove them before the synchronization. With a remote payoff server the
        local payoff matrix copy is refreshed afterwards. The payoff matrix is not
        expanded here.

        Args:
            step (int): Current outer step; pruning is skipped when it is 0.

        Returns:
            A dict mapping each agent to the list of its fixed policy ids.
        """
        use_ray = self.ray_mode
        learners = list(self._agent_learners.values())

        if self._remove_dominated_policy and step > 0:
            # TODO(): remove dominated policy
            if use_ray:
                dominated = ray.get(
                    [lrn.return_dominated_policies.remote() for lrn in learners]
                )
            else:
                dominated = [lrn.return_dominated_policies() for lrn in learners]

            removed_pids: Dict[AgentID, List[PolicyID]] = {}
            for part in dominated:
                removed_pids.update(part)

            if use_ray:
                ray.get(
                    [
                        lrn.remove_dominated_policies.remote(removed_pids)
                        for lrn in learners
                    ]
                )
            else:
                for lrn in learners:
                    lrn.remove_dominated_policies(removed_pids)

            Log.info("REMOVED DOMINATED POLICIES: {}".format(removed_pids))

        # gather the fixed policies of every ego agent ...
        if use_ray:
            gathered = ray.get(
                [lrn.get_ego_fixed_policies.remote("cpu") for lrn in learners]
            )
        else:
            gathered = [lrn.get_ego_fixed_policies("cpu") for lrn in learners]
        fixed_policies = {}
        for part in gathered:
            fixed_policies.update(part)

        # ... and hand the merged set back to all of them
        if use_ray:
            ray.get([lrn.sync_policies.remote(fixed_policies) for lrn in learners])
        else:
            for lrn in learners:
                lrn.sync_policies(fixed_policies)

        fixed_policy_ids = {
            agent: list(policies.keys()) for agent, policies in fixed_policies.items()
        }
        if self._use_remote_payoff:
            self._payoff_matrix = ray.get(
                self._payoff_server.get_payoff_matrix.remote()
            )
        return fixed_policy_ids

    def merge_best_response(
        self, meta_strategies: Dict[AgentID, Dict[PolicyID, float]]
    ):
        """Call ``distill(meta_strategies)`` on the sub-learner of every agent.

        In ray mode the calls are submitted together and awaited; otherwise they run
        sequentially. The results of the calls are discarded.

        Args:
            meta_strategies: A dict mapping each agent to a dict of
                policy id -> probability.
        """
        agent_ids = list(self.agents)
        if not self.ray_mode:
            for agent_id in agent_ids:
                self._agent_learners[agent_id].distill(meta_strategies)
            return

        pending = []
        for agent_id in agent_ids:
            pending.append(
                self._agent_learners[agent_id].distill.remote(meta_strategies)
            )
        ray.get(pending)

    # TODO(): may we can abstract this block as a BR learning, so that users can rewrite it.
    def _run_decentralized_meta_solver(
        self,
    ) -> Tuple[
        Dict[AgentID, Dict[PolicyID, float]], Dict[AgentID, Dict[PolicyID, Policy]]
    ]:
        """Collect the learned meta strategies directly from the sub-learners.

        Used when the meta strategies are learnable: no centralized solver (NE, PRD,
        etc.) is run. Each sub-learner is asked for its fixed policies (on "cpu") and
        for its current distribution via ``get_dist``; both are merged across learners.

        Returns:
            Tuple[ Dict[AgentID, Dict[PolicyID, float]], Dict[AgentID, Dict[PolicyID, Policy]] ]: A tuple of equilibrium and fixed policies.
        """
        learners = list(self._agent_learners.values())
        if self.ray_mode:
            policy_parts = ray.get(
                [lrn.get_ego_fixed_policies.remote("cpu") for lrn in learners]
            )
            dist_parts = ray.get([lrn.get_dist.remote() for lrn in learners])
        else:
            policy_parts = [lrn.get_ego_fixed_policies("cpu") for lrn in learners]
            dist_parts = [lrn.get_dist() for lrn in learners]

        fixed_policies = {}
        for part in policy_parts:
            fixed_policies.update(part)

        equilibrium = {}
        for part in dist_parts:
            equilibrium.update(part)

        return equilibrium, fixed_policies

    def _run_centralized_meta_solver(
        self,
    ) -> Tuple[
        Dict[AgentID, Dict[PolicyID, float]], Dict[AgentID, Dict[PolicyID, Policy]]
    ]:
        """Evaluate pending simulations, update the payoff matrix and solve for an equilibrium.

        Steps:
            1. Generate simulations from the payoff matrix (``split=True``) and pair
               them with the sub-learners in order.
            2. Let each sub-learner evaluate its simulation and feed all results into
               ``update_payoff_and_simulation_status``.
            3. Gather the fixed policies of all sub-learners, take the matching
               sub-matrix and solve it with the meta solver.

        Returns:
            A tuple ``(equilibrium, fixed_policies)`` where ``equilibrium`` maps each
            agent to a dict of fixed policy id -> solver output for that policy.
        """
        # gen simulation with given fixed policy set
        simulations = self._payoff_matrix.gen_simulations(split=True)
        assignments = list(zip(self._agent_learners.values(), simulations))

        rollout = self._rollout_config
        eval_kwargs = {
            "max_step": rollout.max_step,
            "fragment_length": rollout.num_simulation * rollout.max_step,
            "max_episode": rollout.num_simulation,
        }

        # update payoff tables; each evaluation returns a sequence of
        # (policy_mapping, reward_dict) tuples
        started_at = time.time()
        if self.ray_mode:
            batches = ray.get(
                [lrn.evaluation.remote(sim, **eval_kwargs) for lrn, sim in assignments]
            )
        else:
            batches = [lrn.evaluation(sim, **eval_kwargs) for lrn, sim in assignments]
        results = []
        for batch in batches:
            results.extend(batch)
        Log.info(
            "\t* evaluation finished for {} simulations, time={}".format(
                len(results), time.time() - started_at
            )
        )
        self._payoff_matrix.update_payoff_and_simulation_status(results)

        # got a full of policy mapping
        learners = list(self._agent_learners.values())
        if self.ray_mode:
            policy_parts = ray.get(
                [lrn.get_ego_fixed_policies.remote("cpu") for lrn in learners]
            )
        else:
            policy_parts = [lrn.get_ego_fixed_policies("cpu") for lrn in learners]
        fixed_policies = {}
        for part in policy_parts:
            fixed_policies.update(part)

        fixed_policy_mapping = {}
        for agent_id, agent_policies in fixed_policies.items():
            fixed_policy_mapping[agent_id] = list(agent_policies.keys())

        # a dict of payoff tables
        matrix = self._payoff_matrix.get_sub_matrix(fixed_policy_mapping)
        solution = self._meta_solver.solve(matrix)

        # attach policy ids to the solver output
        equilibrium = {}
        for agent_id, _ in matrix.items():
            equilibrium[agent_id] = dict(
                zip(fixed_policy_mapping[agent_id], solution[agent_id])
            )
        return equilibrium, fixed_policies

    def get_ego_active_policies(self) -> Dict[AgentID, Dict[PolicyID, Policy]]:
        """Return the active policies of all sub-learners merged into one dict.

        The per-learner dicts are combined through a ``ChainMap``, so if two
        sub-learners report the same key, the entry of the earlier sub-learner wins.
        """
        learners = self._agent_learners.values()
        if self.ray_mode:
            per_learner = ray.get(
                [lrn.get_ego_active_policies.remote() for lrn in learners]
            )
        else:
            per_learner = [lrn.get_ego_active_policies() for lrn in learners]

        merged = {}
        merged.update(collections.ChainMap(*per_learner))
        return merged

    def learn(
        self,
        sampler_config: Union[Dict, LambdaType],
        stop_conditions: Dict,
        inner_stop_conditions: Dict,
    ):
        """Run the outer PSRO loop until the stopper reports termination.

        Each round: reset the sub-learners' summary writers, obtain an equilibrium
        (centralized solver, or the learned distributions when
        ``use_learnable_dist`` is set), log it to the CSV tables, add one new active
        policy per sub-learner, expand the payoff matrix with the new policy ids,
        train the sub-learners, optionally report NashConv, mark the trained
        policies as fixed, synchronize them and save.

        Args:
            sampler_config: Forwarded to every sub-learner's ``learn``.
            stop_conditions: Stop conditions of the outer loop;
                ``DEFAULT_STOP_CONDITIONS`` is used when this is falsy.
            inner_stop_conditions: Passed to every sub-learner's ``learn`` as its
                ``stop_conditions``.
        """

        def on_all_learners(method_name, *args, **kwargs):
            # Call `method_name` on every sub-learner (as ray tasks awaited together
            # in ray mode, sequentially otherwise); results follow learner order.
            learners = self._agent_learners.values()
            if self.ray_mode:
                return ray.get(
                    [
                        getattr(lrn, method_name).remote(*args, **kwargs)
                        for lrn in learners
                    ]
                )
            return [getattr(lrn, method_name)(*args, **kwargs) for lrn in learners]

        # training object should be an algorithm, not agent individual
        stopper = get_stopper(stop_conditions or DEFAULT_STOP_CONDITIONS)
        stopper.reset()
        round_idx = 0
        while not stopper.is_terminal():
            Log.info("Global iteration on: %s", stopper.counter)
            self._reset_indep_summary(round_idx)

            # compute the meta strategies for this round
            if self._use_learnable_dist:
                equilibrium, fixed_policies = self._run_decentralized_meta_solver()
            else:
                equilibrium, fixed_policies = self._run_centralized_meta_solver()
                Log.info("\t* init behavior dist: %s", equilibrium)
                on_all_learners("set_behavior_dist", equilibrium)

            self._update_agent_matrix(equilibrium)

            # generate trainable policies (a dict of policy dict) {agent: {pid: policy}}
            on_all_learners("add_policy", n_support=1)
            active_policy_ids = {}
            for ids in on_all_learners("get_ego_active_policy_ids"):
                active_policy_ids.update(ids)
            assert len(active_policy_ids[self.agents[0]]) > 0, active_policy_ids
            self._payoff_matrix.expand(active_policy_ids)
            Log.info("\t* generated active policies %s", active_policy_ids)

            # and sync to remote server
            if self._use_remote_payoff:
                ray.get(
                    self._payoff_server.set_payoff_matrix.remote(self._payoff_matrix)
                )

            # learn policy supports, until converge
            outcomes = on_all_learners(
                "learn", sampler_config, stop_conditions=inner_stop_conditions
            )
            # collected per agent; currently not added to the report
            learning_res = dict(zip(list(self._agent_learners.keys()), outcomes))

            report = {}
            if self._measure_exp:
                env_id = self._env_description["config"]["env_id"]
                if not self._use_learnable_dist:
                    nash_conv = measure_exploitability(
                        env_id, fixed_policies, equilibrium
                    )
                    report = {"NashConv/base": nash_conv.nash_conv}
                    Log.info("\t* nash conv: %s", nash_conv)
                else:
                    # re-read the distributions, which were trained this round
                    (
                        dpsro_equilibrium,
                        dpsro_fixed_policies,
                    ) = self._run_decentralized_meta_solver()

                    # TODO(): print as tabulate
                    Log.info("\t* computed equilibrium as: %s", dpsro_equilibrium)
                    dpsro_nash_conv = measure_exploitability(
                        env_id, dpsro_fixed_policies, dpsro_equilibrium
                    )
                    report = {"NashConv/dpsro": dpsro_nash_conv.nash_conv}
                    Log.info("\t* dpsro nash conv: %s", dpsro_nash_conv)

            write_to_tensorboard(
                self.summary_writer,
                info=report,
                global_step=stopper.counter,
                prefix="",
            )

            # set active policy to fixed
            on_all_learners("set_ego_policy_fixed", active_policy_ids)

            # TODO(): remove dominated policies
            self._sync_up(stopper.counter)
            Log.info("\t* active policies trained and synced up")
            Log.info("")
            # stopper is used to compute exploitability
            stopper.step(None, None, None, None)
            self.save()
            round_idx += 1

    def save(self, data_dir: str = None):
        """Save model by calling ``save(hard=False)`` on every sub-learner.

        Args:
            data_dir (str, optional): Target directory, defaulting to
                ``self.state_dir``. It is resolved but not forwarded to the
                sub-learners.
        """

        data_dir = data_dir or self.state_dir
        # TODO(): save also the policy distribution
        sub_learners = self._agent_learners.values()
        if not self.ray_mode:
            for sub_learner in sub_learners:
                sub_learner.save(hard=False)
            return
        ray.get([sub_learner.save.remote(hard=False) for sub_learner in sub_learners])

    def load(self, checkpoint_dir):
        """Load model.

        Currently a no-op: ``checkpoint_dir`` is ignored and nothing is restored.
        """
        # Disabled loading code, kept for reference:
        # for pool_id, pool in self._policy_pool.items():
        #     dir_path = os.path.join(checkpoint_dir, f"policy_pool/{pool_id}")
        #     pool.load.remote(dir_path)
        pass
