import itertools
import ray
import numpy as np

from expground.types import (
    AgentID,
    PolicyID,
    Any,
    Union,
    Tuple,
    Dict,
    LambdaType,
    Sequence,
    PolicyConfig,
    EnvDescription,
    RolloutConfig,
    TrainingConfig,
)
from expground.logger import Log
from expground.utils.stoppers import get_stopper, DEFAULT_STOP_CONDITIONS
from expground.utils.logging import write_to_tensorboard

from expground.common.exploitability import measure_exploitability
from expground.learner.psro.psro import PSROLearner
from expground.learner.utils import generate_random_from_shapes


class BPSROLearner(PSROLearner):
    """PSRO learner whose meta strategy is refined by a bi-level optimization loop.

    Each global iteration adds one active policy per agent. Instead of training it
    only against the meta-solver equilibrium, ``_bi_level_optimize`` alternates
    between training the active policies (best responses) and re-solving the meta
    game on the fixed-population payoff matrix minus the active policies'
    evaluation results, stopping early once NashConv improves. With
    ``compare_mode`` enabled, the bi-level pass is followed by restoring the ego
    policies and training them with a plain PSRO update, and both NashConv values
    are reported for comparison.
    """

    # Minimum NashConv decrease (before - after) that ends the bi-level loop early.
    _MIN_NASH_CONV_IMPROVEMENT = 1e-4

    def __init__(
        self,
        meta_solver: str,
        policy_config: PolicyConfig,
        env_description: EnvDescription,
        rollout_config: RolloutConfig,
        training_config: TrainingConfig,
        loss_func: type,
        learning_mode: str,
        episodic_training: bool = False,
        train_every: int = 1,
        use_learnable_dist: bool = False,
        experiment: str = None,
        ray_mode: bool = False,
        seed: int = None,
        mixed_at_every_step: bool = False,
        independent_learning: bool = True,
        evaluation_worker_num: int = 0,
        distribution_training_kwargs: Dict = None,
        centralized_critic_config: Dict = None,
        rectifier_type: int = 0,
        max_bilevel_step: int = 20,
        bi_solver_type: str = "linear",
        compare_mode: bool = False,
        mini_epoch: int = 1,
    ):
        """Build the learner; all but the last three bi-level options go to ``PSROLearner``.

        Args:
            max_bilevel_step (int): Upper bound on bi-level iterations per global iteration.
            bi_solver_type (str): Stored as ``self._bi_solver_type``; not read in this class.
            compare_mode (bool): If True, additionally run plain PSRO each iteration and
                log NashConv of both approaches.
        """

        # Reject unsupported configs before the parent constructor does any setup.
        assert (
            not use_learnable_dist
        ), "Bilevel optimization does not support learnable meta strategies yet"

        if experiment is None:
            experiment = "bpsro"
        super(BPSROLearner, self).__init__(
            meta_solver,
            policy_config,
            env_description,
            rollout_config,
            training_config,
            loss_func,
            learning_mode,
            episodic_training=episodic_training,
            train_every=train_every,
            use_learnable_dist=use_learnable_dist,
            experiment=experiment,
            ray_mode=ray_mode,
            seed=seed,
            mixed_at_every_step=mixed_at_every_step,
            independent_learning=independent_learning,
            evaluation_worker_num=evaluation_worker_num,
            distribution_training_kwargs=distribution_training_kwargs,
            centralized_critic_config=centralized_critic_config,
            rectifier_type=rectifier_type,
            mini_epoch=mini_epoch,
        )

        self.max_bilevel_step = max_bilevel_step
        self._bi_solver_type = bi_solver_type
        self._compare_mode = compare_mode

    def _invoke_on_learners(self, method: str, *args: Any, **kwargs: Any) -> list:
        """Call ``method(*args, **kwargs)`` on every agent learner.

        In ray mode the calls are dispatched as remote tasks and awaited together;
        otherwise they run sequentially in-process. Results are returned in the
        iteration order of ``self._agent_learners``.
        """

        learners = list(self._agent_learners.values())
        if self.ray_mode:
            return ray.get(
                [getattr(learner, method).remote(*args, **kwargs) for learner in learners]
            )
        return [getattr(learner, method)(*args, **kwargs) for learner in learners]

    def _evaluate_comb(
        self,
        active_policy_mapping: Dict[AgentID, PolicyID],
        fixed_policy_mapping: Dict[AgentID, Sequence[PolicyID]],
    ) -> Dict[AgentID, Sequence[Tuple[Dict, Dict[AgentID, float]]]]:
        """Evaluate each agent's active policy against all fixed opponent combinations.

        For every agent, its active policy is paired with each element of the
        Cartesian product of the other agents' fixed policies, and the resulting
        policy mappings are evaluated by that agent's learner. Every agent's
        feedback is also pushed into ``self._payoff_matrix``.

        Args:
            active_policy_mapping (Dict[AgentID, PolicyID]): Agent -> its single active policy id.
            fixed_policy_mapping (Dict[AgentID, Sequence[PolicyID]]): Agent -> fixed policy ids.

        Returns:
            Dict[AgentID, Sequence[Tuple[Dict, Dict[AgentID, float]]]]: Agent -> the feedback
            returned by its learner's ``evaluation``.
        """

        # Every evaluation uses the same rollout budget.
        max_step = self._rollout_config.max_step
        fragment_length = self._rollout_config.num_simulation * max_step

        res = {}
        tasks = []
        for agent in self.agents:
            active_pid = active_policy_mapping[agent]
            assert isinstance(active_pid, PolicyID), active_pid

            other_agents = [k for k in self.agents if k != agent]
            other_pids = [fixed_policy_mapping[k] for k in other_agents]
            simulation = []
            for combo in itertools.product(*other_pids):
                mapping = dict(zip(other_agents, combo))
                mapping[agent] = active_pid
                simulation.append(mapping)

            learner = self._agent_learners[agent]
            if self.ray_mode:
                tasks.append(
                    learner.evaluation.remote(
                        simulation,
                        max_step=max_step,
                        fragment_length=fragment_length,
                    )
                )
            else:
                res[agent] = learner.evaluation(
                    simulation,
                    max_step=max_step,
                    fragment_length=fragment_length,
                )
                self._payoff_matrix.update_payoff_and_simulation_status(res[agent])

        if tasks:
            # Tasks were submitted in `self.agents` order.
            for agent, feedback in zip(self.agents, ray.get(tasks)):
                res[agent] = feedback
                self._payoff_matrix.update_payoff_and_simulation_status(feedback)
        return res

    def _optimize_best_response(
        self,
        active_policy_mapping: Dict[AgentID, PolicyID],
        fixed_policy_mapping: Dict[AgentID, Sequence[PolicyID]],
        equilibrium: Dict[AgentID, Dict[PolicyID, float]],
        sampler_config: Dict[str, Any],
        inner_stop_conditions: Dict[str, Any],
    ) -> Dict[AgentID, np.ndarray]:
        """Train the active policies against ``equilibrium`` with RL, then evaluate them.

        Args:
            active_policy_mapping (Dict[AgentID, PolicyID]): Agent -> active policy id.
            fixed_policy_mapping (Dict[AgentID, Sequence[PolicyID]]): Agent -> fixed policy ids.
            equilibrium (Dict[AgentID, Dict[PolicyID, float]]): Meta strategy installed as
                every learner's behavior distribution.
            sampler_config (Dict[str, Any]): Sampler configuration passed to ``learn``.
            inner_stop_conditions (Dict[str, Any]): Stop conditions passed to ``learn``.

        Returns:
            Dict[AgentID, np.ndarray]: The evaluation feedback converted by
            ``payoff_matrix.dict_to_matrix``.
        """

        # NOTE: unreachable in practice, since __init__ rejects use_learnable_dist.
        if self._use_learnable_dist:
            equilibrium = generate_random_from_shapes(equilibrium)

        self._invoke_on_learners("set_behavior_dist", equilibrium)
        self._invoke_on_learners(
            "learn", sampler_config, stop_conditions=inner_stop_conditions
        )

        evaluation_results = self._evaluate_comb(
            active_policy_mapping, fixed_policy_mapping
        )
        Log.debug("\t* evaluation results:\n{}".format(evaluation_results))
        return self.payoff_matrix.dict_to_matrix(evaluation_results)

    def _gen_partial_matrix(
        self, evaluation_results: Dict[AgentID, np.ndarray]
    ) -> Dict[AgentID, np.ndarray]:
        """Insert a singleton axis at each agent's own payoff-matrix axis.

        The axis index comes from ``payoff_matrix.agent_axes_mapping``, so the
        result can be broadcast against that agent's sub payoff matrix.

        Args:
            evaluation_results (Dict[AgentID, np.ndarray]): Agent -> evaluation results.

        Returns:
            Dict[AgentID, np.ndarray]: Agent -> a copy of its results with the extra axis.
        """

        axes_mapping: Dict[AgentID, int] = self.payoff_matrix.agent_axes_mapping
        return {
            agent: np.expand_dims(evaluation_results[agent].copy(), axis=idx)
            for agent, idx in axes_mapping.items()
        }

    def _compute_equilibrium(
        self,
        fixed_policy_mapping: Dict[AgentID, Sequence[PolicyID]],
        evaluation_results: Dict[AgentID, np.ndarray],
        ori_payoff_matrix: Dict[AgentID, np.ndarray],
    ) -> Dict[AgentID, Dict[PolicyID, float]]:
        """Solve the meta game on ``ori_payoff_matrix - evaluation_results``.

        Args:
            fixed_policy_mapping (Dict[AgentID, Sequence[PolicyID]]): Agent -> fixed policy
                ids, in the order the solver's probabilities refer to.
            evaluation_results (Dict[AgentID, np.ndarray]): Agent -> evaluation results.
            ori_payoff_matrix (Dict[AgentID, np.ndarray]): Agent -> original sub payoff matrix.

        Returns:
            Dict[AgentID, Dict[PolicyID, float]]: Agent -> {policy id: probability}.
        """

        # Evaluation results with the agent's own axis added so they broadcast.
        partial_matrix = self._gen_partial_matrix(evaluation_results)
        matrix = {}
        for agent, sub_matrix in ori_payoff_matrix.items():
            matrix[agent] = sub_matrix - partial_matrix[agent]
            Log.debug(
                "\n* partial_matrix:\n{}\n* ori matrix:\n{}\n* equima:\n{}".format(
                    partial_matrix[agent], sub_matrix, matrix[agent]
                )
            )

        equilibrium_mat = self.meta_solver.solve(matrix)
        return {
            agent: dict(zip(fixed_policy_mapping[agent], equilibrium_mat[agent]))
            for agent in matrix
        }

    def _learn_equilibrium(self):
        """Collect one batch per agent from its learner's sampler (unfinished stub).

        Not called anywhere in this class. NOTE(review): reads ``learner._sampler``
        directly, which presumably only works for local (non-ray) learners.

        Returns:
            Dict[AgentID, Any]: Agent -> sampled batch.
        """

        return {
            agent: learner._sampler.sample(agent_filter=None)
            for agent, learner in self._agent_learners.items()
        }

    def _bi_level_optimize(
        self,
        before_nash_conv,
        sampler_config: Union[Dict, LambdaType],
        stopper,
        active_policy_ids,
        equilibrium,
        fixed_policies,
        inner_stop_conditions: Dict,
    ):
        """Alternate best-response training and meta-game re-solving.

        Each step trains the active policies against the current equilibrium and
        then re-solves the meta game via ``_compute_equilibrium``. The loop ends:

        * after ``self.max_bilevel_step`` steps;
        * after one step when ``stopper.counter == 0``;
        * once NashConv drops by at least ``_MIN_NASH_CONV_IMPROVEMENT`` relative
          to ``before_nash_conv``.

        The final equilibrium is passed to ``self._update_agent_matrix``.

        Args:
            before_nash_conv: Exploitability result of the starting equilibrium
                (must expose ``.nash_conv``).
            sampler_config (Union[Dict, LambdaType]): Sampler configuration.
            stopper: Global stopper; only ``stopper.counter`` is read.
            active_policy_ids: Agent -> sequence of active policy ids; only the first is used.
            equilibrium: Starting meta strategy, agent -> {policy id: probability}.
            fixed_policies: Agent -> dict keyed by fixed policy id.
            inner_stop_conditions (Dict): Stop conditions for each inner RL run.

        Returns:
            The last exploitability result measured, or None if none was measured
            (first global iteration, or ``max_bilevel_step <= 0``).
        """

        fixed_policy_mapping = {
            agent: sorted(pdict.keys()) for agent, pdict in fixed_policies.items()
        }
        ori_matrix = self.payoff_matrix.get_sub_matrix(fixed_policy_mapping)
        single_active_policy_ids = {k: v[0] for k, v in active_policy_ids.items()}

        nash_conv = None
        bi_step = 0
        while bi_step < self.max_bilevel_step:
            evaluation_results = self._optimize_best_response(
                single_active_policy_ids,
                fixed_policy_mapping,
                equilibrium,
                sampler_config,
                inner_stop_conditions,
            )
            equilibrium = self._compute_equilibrium(
                fixed_policy_mapping,
                evaluation_results,
                ori_payoff_matrix=ori_matrix,
            )
            # TODO(): track regret
            bi_step += 1
            Log.info(
                "\t* equilibrium after {} step bi-opt is: {}".format(
                    bi_step, equilibrium
                )
            )

            # In the first global iteration, do a single step without measuring.
            if stopper.counter == 0:
                break

            nash_conv = measure_exploitability(
                self._env_description["config"]["env_id"],
                fixed_policies,
                equilibrium,
            )
            improvement = before_nash_conv.nash_conv - nash_conv.nash_conv
            Log.info("\t* Nash diff: %s", improvement)
            if improvement >= self._MIN_NASH_CONV_IMPROVEMENT:
                break

        self._update_agent_matrix(equilibrium)
        return nash_conv

    def _psro_optimize(
        self,
        sampler_config: Union[Dict, LambdaType],
        stopper,
        active_policy_ids,
        equilibrium,
        fixed_policies,
        inner_stop_conditions: Dict,
    ):
        """Plain PSRO update: train the active policies, then mark them fixed.

        ``stopper``, ``equilibrium`` and ``fixed_policies`` are unused; the
        signature mirrors ``_bi_level_optimize``. This method does not set the
        behavior distribution; ``learn`` does so before calling it.

        Returns:
            Dict[AgentID, Any]: Agent -> return value of its learner's ``learn``.
        """

        if self.ray_mode:
            results = ray.get(
                [
                    learner.learn.remote(
                        sampler_config, stop_conditions=inner_stop_conditions
                    )
                    for learner in self._agent_learners.values()
                ]
            )
            learning_res = dict(zip(self._agent_learners.keys(), results))
            ray.get(
                [
                    learner.set_ego_policy_fixed.remote(active_policy_ids)
                    for learner in self._agent_learners.values()
                ]
            )
        else:
            learning_res = {}
            # Local mode trains and fixes one learner at a time (original order kept).
            for aid, learner in self._agent_learners.items():
                learning_res[aid] = learner.learn(
                    sampler_config, stop_conditions=inner_stop_conditions
                )
                learner.set_ego_policy_fixed(active_policy_ids)
        return learning_res

    def learn(
        self,
        sampler_config: Union[Dict, LambdaType],
        stop_conditions: Dict,
        inner_stop_conditions: Dict,
    ):
        """Main loop: PSRO with bi-level meta-strategy refinement.

        Each global iteration:

        1. Solves the meta game.
        2. Adds one active policy per agent.
        3. Measures the base NashConv and expands the payoff matrix.
        4. Refines the policies in one of two ways:
           * default: runs ``_bi_level_optimize`` and marks the active policies
             fixed;
           * ``compare_mode``: runs the bi-level pass without tensorboard
             logging, restores the ego policies, and trains them with
             ``_psro_optimize``. The comparison is reported one iteration later,
             once the PSRO result is measurable.
        5. Syncs policies and steps the stopper.

        Args:
            sampler_config (Union[Dict, LambdaType]): The sample configuration.
            stop_conditions (Dict): Outer stop conditions (defaults used if falsy).
            inner_stop_conditions (Dict): Inner stop conditions for each RL run.
        """

        stopper = get_stopper(stop_conditions or DEFAULT_STOP_CONDITIONS)
        stopper.reset()

        before_nash_conv = None
        bi_nash_conv = None
        # NOTE(review): never incremented, so `_reset_indep_summary` always gets 0;
        # confirm whether `stopper.counter` was intended.
        round_idx = 0
        while not stopper.is_terminal():
            Log.info("Global iteration on: %s", stopper.counter)
            self._reset_indep_summary(round_idx)

            # 1. solve the meta game over the current populations
            if not self._use_learnable_dist:
                equilibrium, fixed_policies = self._run_centralized_meta_solver()
            else:
                equilibrium, fixed_policies = self._run_decentralized_meta_solver()
            # TODO(): print as tabulate
            Log.info("\t* computed equilibrium as: %s", equilibrium)

            # 2. add one trainable policy per agent and gather the active ids
            self._invoke_on_learners("add_policy", n_support=1)
            active_policy_ids = {}
            for ids in self._invoke_on_learners("get_ego_active_policy_ids"):
                active_policy_ids.update(ids)
            Log.info("\t* added active policy ids: {}".format(active_policy_ids))

            nash_conv = measure_exploitability(
                self._env_description["config"]["env_id"], fixed_policies, equilibrium
            )
            Log.info("\t* base nash conv: %s", nash_conv)

            # 3. notify payoff manager
            self.payoff_matrix.expand(active_policy_ids)

            if self._compare_mode:
                # Report the previous iteration's comparison: `nash_conv` now reflects
                # the PSRO update, `bi_nash_conv` the bi-level result.
                if bi_nash_conv is not None:
                    compare_diff = nash_conv.nash_conv - bi_nash_conv.nash_conv
                    write_to_tensorboard(
                        self.summary_writer,
                        info={
                            "NashConv/base": before_nash_conv.nash_conv,
                            "NashConv/optimized": nash_conv.nash_conv,
                            "NashConv/bi_optimized": bi_nash_conv.nash_conv,
                            "NashConv/diff": compare_diff,
                        },
                        global_step=stopper.counter,
                        prefix="",
                    )
                    Log.info("\t* compared nash conv diff: %s", compare_diff)
                before_nash_conv = nash_conv

                # Bi-level pass with tensorboard logging disabled; always restore the writer.
                summary_writer = self.summary_writer
                self.summary_writer = None
                try:
                    bi_nash_conv = self._bi_level_optimize(
                        before_nash_conv,
                        sampler_config,
                        stopper,
                        active_policy_ids,
                        equilibrium,
                        fixed_policies,
                        inner_stop_conditions,
                    )
                finally:
                    self.summary_writer = summary_writer

                # Restore the ego policies' parameters, then train with plain PSRO.
                self._invoke_on_learners("recover_ego_policies")
                self._invoke_on_learners("set_behavior_dist", equilibrium)
                self._psro_optimize(
                    sampler_config,
                    stopper,
                    active_policy_ids,
                    equilibrium,
                    fixed_policies,
                    inner_stop_conditions,
                )
            else:
                before_nash_conv = nash_conv

                # `_bi_level_optimize` also calls self._update_agent_matrix(...).
                nash_conv = self._bi_level_optimize(
                    before_nash_conv,
                    sampler_config,
                    stopper,
                    active_policy_ids,
                    equilibrium,
                    fixed_policies,
                    inner_stop_conditions,
                )
                if nash_conv is None:
                    nash_conv = before_nash_conv

                write_to_tensorboard(
                    self.summary_writer,
                    info={
                        "NashConv/base": before_nash_conv.nash_conv,
                        "NashConv/optimized": nash_conv.nash_conv,
                        "NashConv/diff": before_nash_conv.nash_conv
                        - nash_conv.nash_conv,
                    },
                    global_step=stopper.counter,
                    prefix="",
                )
                self._invoke_on_learners("set_ego_policy_fixed", active_policy_ids)

            # 4. sync policies
            self._sync_up()
            Log.info("\t* active policies trained and synced up")
            Log.info("")
            # stopper is used to compute exploitability
            stopper.step(None, None, None, None)