from typing import Tuple
import nashpy as nash
import numpy as np

from open_spiel.python.egt import alpharank, utils as alpharank_utils

from expground.types import AgentID, DataArray, Dict, Sequence, List
from expground.logger import Log


class MetaSolver:
    """Meta-strategy solver for empirical meta-games. Supports only two players.

    Given one payoff table per agent, computes a meta-strategy for every agent
    with one of the supported methods: ``"fictitious_play"``, ``"alpharank"``
    or ``"mrcp"`` (the latter is not implemented yet).
    """

    def __init__(self, solve_method="fictitious_play"):
        """Create a solver.

        Args:
            solve_method (str): Method used by ``solve`` when no method is
                passed explicitly.
        """
        self._solve_method = solve_method
        Log.debug("created Meta solver with method=%s", solve_method)

    @classmethod
    def from_type(cls, type_name: str):
        """Build a meta solver from type name.

        Args:
            type_name (str): Type name, currently the same as solve_method.

        Returns:
            MetaSolver: An instance of MetaSolver
        """

        return cls(type_name)

    def fictitious_play(self, payoffs_seq, iterations: int = 2000):
        """Approximate an equilibrium of a two-player game with fictitious play.

        Args:
            payoffs_seq: Iterable of two payoff matrices, unpacked positionally
                into ``nash.Game``.
            iterations (int): Number of fictitious-play iterations (default
                2000, the previously hard-coded value). Must be positive;
                otherwise every play count stays zero and normalizing them
                would divide by zero.

        Returns:
            Tuple[np.ndarray, ...]: One array per player holding that player's
            final play counts normalized to sum to one.

        Raises:
            ValueError: If ``iterations`` is not positive.
        """
        if iterations <= 0:
            raise ValueError(
                "iterations must be positive, got {}".format(iterations)
            )
        game = nash.Game(*payoffs_seq)

        # The nashpy generator yields the play counts after every iteration;
        # only the last yield is needed, so drain it without storing the rest.
        play_counts = None
        for play_counts in game.fictitious_play(iterations=iterations):
            pass
        return tuple(counts / np.sum(counts) for counts in play_counts)

    def mrcp(self, payoffs_seq):
        """Minimum-regret constrained profile solver (not implemented)."""
        raise NotImplementedError

    def alpharank(self, payoffs_seq):
        """Compute per-player alpha-rank marginals via OpenSpiel.

        Args:
            payoffs_seq: Iterable of per-player payoff tables. It is
                materialized into a list because OpenSpiel indexes the tables
                (e.g. ``payoff_tables[0]``), which ``dict.values()`` views do
                not support.

        Returns:
            List[np.ndarray]: One marginal distribution per payoff table.
        """

        def remove_epsilon_negative_probs(probs, epsilon=1e-9):
            """Removes negative probabilities that occur due to precision errors."""
            negatives = probs[probs < 0]
            if negatives.size > 0:
                # Ensures these negative probabilities aren't large in magnitude, as that is
                # unexpected and likely not due to numerical precision issues
                Log.debug("Probabilities received were: %s", negatives)
                # Plain scalar comparison: np.alltrue was removed in NumPy 2.0.
                assert (
                    np.min(negatives) > -1.0 * epsilon
                ), "Negative Probabilities received were: {}".format(negatives)

                probs[probs < 0] = 0
                probs = probs / np.sum(probs)
            return probs

        def get_alpharank_marginals(payoff_tables, pi):
            """Returns marginal strategy rankings for each player given joint rankings pi.

            Args:
              payoff_tables: List of meta-game payoff tables for a K-player game, where
                each table has dim [n_strategies_player_1 x ... x n_strategies_player_K].
                These payoff tables may be asymmetric.
              pi: The vector of joint rankings as computed by alpharank. Each element i
                corresponds to a unique integer ID representing a given strategy profile,
                with profile_to_id mappings provided by
                alpharank_utils.get_id_from_strat_profile().

            Returns:
              pi_marginals: List of np.arrays of player-wise marginal strategy masses,
                where the k-th player's np.array has shape [n_strategies_player_k].
            """
            num_populations = len(payoff_tables)

            if num_populations == 1:
                # Wrap in a list so the documented "one array per player"
                # contract also holds for a single population.
                return [pi]

            num_strats_per_population = alpharank_utils.get_num_strats_per_population(
                payoff_tables, payoffs_are_hpt_format=False
            )
            num_profiles = alpharank_utils.get_num_profiles(num_strats_per_population)
            pi_marginals = [np.zeros(n) for n in num_strats_per_population]
            for i_strat in range(num_profiles):
                strat_profile = alpharank_utils.get_strat_profile_from_id(
                    num_strats_per_population, i_strat
                )
                for i_player in range(num_populations):
                    pi_marginals[i_player][strat_profile[i_player]] += pi[i_strat]
            return pi_marginals

        payoff_tables = list(payoffs_seq)
        joint_distr = alpharank.sweep_pi_vs_epsilon(payoff_tables)
        joint_distr = remove_epsilon_negative_probs(joint_distr)
        return get_alpharank_marginals(payoff_tables, joint_distr)

    def solve(self, payoff_dict: Dict[AgentID, DataArray], solve_method: str = None):
        """Compute a meta-strategy for every agent.

        Args:
            payoff_dict (Dict[AgentID, DataArray]): Mapping from agent to its
                payoff table. Dict insertion order decides which table is
                passed to the solver first.
            solve_method (str, optional): Overrides the solver chosen at
                construction time.

        Returns:
            Dict[AgentID, ...]: Mapping from agent to the solver's output for
            that agent, paired positionally with ``payoff_dict`` keys.

        Raises:
            ValueError: If the solve method is not recognised.
        """
        solve_method = solve_method or self._solve_method
        solvers = {
            "fictitious_play": self.fictitious_play,
            "alpharank": self.alpharank,
            "mrcp": self.mrcp,
        }
        try:
            solver = solvers[solve_method]
        except KeyError:
            raise ValueError(
                "Unknown solve method {!r}; expected one of {}".format(
                    solve_method, sorted(solvers)
                )
            ) from None
        # Materialize the values: some solvers index the payoff tables, which a
        # dict_values view does not support.
        res = solver(list(payoff_dict.values()))
        return dict(zip(payoff_dict.keys(), res))

    def batched_solve(
        self, payoff_dict: Dict[AgentID, DataArray], solver_types: List[str]
    ) -> List[Dict[AgentID, Tuple]]:
        """Compute multiple groups of meta strategies with different solvers.

        Args:
            payoff_dict (Dict[AgentID, DataArray]): A dict of payoffs, mapping from agent to a data array
            solver_types (List[str]): A list of solver types

        Returns:
            List[Dict[AgentID, Tuple]]: A list of meta strategies, corresponding to different solver types.
        """

        res = []
        for _type in solver_types:
            Log.debug("solve meta strategy with solver=%s ...", _type)
            res.append(self.solve(payoff_dict, _type))
        return res
