from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
import functools

import numpy as np

from expground.envs.python import game_type
from expground.types import AgentID, Dict, Any, Tuple, Sequence


def _memoize_method(method):
    """Memoize a single-arg instance method using an on-object cache."""
    cache_name = "cache_" + method.__name__

    def wrap(self, arg):
        key = str(arg)
        cache = vars(self).setdefault(cache_name, {})
        if key not in cache:
            cache[key] = method(self, arg)
        return cache[key]

    return wrap


@dataclass
class GameType:
    """Static description of a kind of game.

    The classification fields (``dynamics``, ``chance_mode``, ``information``,
    ``utility``, ``reward_model``) use types declared in the ``game_type``
    module. The ``provides_*`` flags state which string/tensor views of a
    state the game offers.

    Attributes:
        provides_factored_observation_string: Can we factorize observations
            into public and private parts? This is similar to the observation
            fields, but adds an additional distinction between public and
            private observations. Defaults to ``False``.
        default_loadable: Can the game be loaded with no parameters? It is
            strongly recommended that games be loadable with default
            arguments. Defaults to ``True``.
    """

    # Names.
    short_name: str
    long_name: str

    # Classification.
    dynamics: game_type.Dynamics
    chance_mode: game_type.ChanceMode
    information: game_type.Information
    utility: game_type.Utility
    reward_model: game_type.RewardModel

    # Supported player counts.
    max_num_players: int
    min_num_players: int

    # Which state views the game provides.
    provides_information_state_string: bool
    provides_information_state_tensor: bool
    provides_observation_string: bool
    provides_observation_tensor: bool
    provides_factored_observation_string: bool = False

    default_loadable: bool = True


@dataclass
class GameInfo:
    """Numeric facts about one instantiation of a game.

    Attributes:
        num_distinct_actions: The size of the action space. See `Game` for a
            full description.
        max_chance_outcomes: Maximum number of distinct chance outcomes for
            chance nodes in the game.
        num_player: The number of players in this instantiation of the game.
            Does not include the chance-player.
        min_utility: Lower bound of the utility range.
        max_utility: Upper bound of the utility range.
        utility_sum: The total utility of all players, if this is a
            constant-sum-utility game, and zero for zero-sum games.
        max_game_length: The maximum number of player decisions in a game.
            Does not include chance events.
    """

    # Action / outcome space sizes.
    num_distinct_actions: int
    max_chance_outcomes: int

    # Player count (chance-player excluded).
    num_player: int

    # Utility range and sum.
    min_utility: float
    max_utility: float
    utility_sum: float

    # Upper bound on player decisions per game.
    max_game_length: int


@dataclass
class IIGObservationType:
    """Options selecting what an observation contains.

    NOTE(review): "IIG" presumably stands for imperfect-information game;
    confirm against the callers.

    Attributes:
        private_info: A ``game_type.PrivateInfoType`` value; defaults to
            ``SINGLE_PLAYER``.
        public_info: Defaults to ``True``.
        perfect_recall: Defaults to ``True``.
    """

    private_info: game_type.PrivateInfoType = game_type.PrivateInfoType.SINGLE_PLAYER
    public_info: bool = True
    perfect_recall: bool = True


class State(metaclass=ABCMeta):
    """Abstract base class for the state of a game.

    A state keeps a reference to the game that created it, counts the moves
    applied so far and records every applied ``(player, action)`` pair.
    Concrete games subclass it and implement the abstract hooks
    (``current_player``, ``_legal_actions``, ``_apply_action``, ...).
    """

    def __init__(self, game):
        """Create an initial state bound to ``game``.

        Args:
            game: The game this state belongs to. It must expose
                ``game_info`` and ``game_type`` attributes, which the
                methods of this class read.
        """
        self._game = game
        self._move_number = 0
        self._player_action_tups = []

    # ------------------------------------------------------------------
    # Legal actions
    # ------------------------------------------------------------------

    def legal_actions(self, player: "Optional[AgentID]" = None) -> Tuple:
        """Return the actions ``player`` may take in this state.

        Args:
            player: The player to query. When omitted, the current player is
                used.

        Returns:
            An empty tuple if the state is terminal or ``player`` is not the
            current player. At a chance node, the value of
            ``chance_outcomes()``. Otherwise the value of
            ``_legal_actions(player)``.
        """
        if player is None:
            player = self.current_player()
        if self.is_terminal() or player != self.current_player():
            return ()
        if self.is_chance_node():
            return self.chance_outcomes()
        return self._legal_actions(player)

    def legal_actions_mask(self):
        """Return the legal-action mask of the current player."""
        return self._legal_actions_mask(self.current_player())

    def _legal_actions_mask(self, player: "AgentID"):
        """Return a 0/1 integer vector marking the legal actions of ``player``.

        The vector has ``max_chance_outcomes`` entries for the chance player
        and ``num_distinct_actions`` entries for every other player. Entry
        ``a`` is 1 exactly when ``a`` is in ``_legal_actions(player)``.

        Subclasses may override this with a faster implementation.
        """
        length = (
            self._game.game_info.max_chance_outcomes
            if player == game_type.PlayerId.CHANCE
            else self.num_distinct_actions
        )
        # `np.int` was only an alias of the builtin `int` and no longer
        # exists in current NumPy releases; the builtin yields the same dtype.
        mask = np.zeros(length, dtype=int)
        for action in self._legal_actions(player):
            mask[action] = 1
        return mask

    @abstractmethod
    def _legal_actions(self, player: "AgentID"):
        """Returns a list of legal actions, sorted in ascending order.

        Args:
            player (AgentID): The player id.
        """

    # ------------------------------------------------------------------
    # Action <-> string conversion
    # ------------------------------------------------------------------

    def action_to_string(self, action):
        """Return the string form of ``action`` for the current player."""
        return self._action_to_string(self.current_player(), action)

    def string_to_action(self, action_str):
        """Return the current player's action whose string is ``action_str``."""
        return self._string_to_action(self.current_player(), action_str)

    def _string_to_action(self, player: "AgentID", action_str: str):
        """Find the legal action that ``player`` renders as ``action_str``.

        The search covers the legal actions of the current player.

        Raises:
            IndexError: If no legal action has that string form.
        """
        for action in self.legal_actions():
            if action_str == self._action_to_string(player, action):
                return action
        raise IndexError(f"No such action embeded as {action_str}")

    @abstractmethod
    def _action_to_string(self, player, action) -> str:
        """Action -> string."""

    # ------------------------------------------------------------------
    # Applying actions
    # ------------------------------------------------------------------

    def apply_action(self, action):
        """Apply ``action`` for the current player and record it.

        The action must be one of the current player's legal actions. The
        acting player is captured before the state changes so the recorded
        pair names the player that actually moved.
        """
        assert action in self.legal_actions()
        player = self.current_player()
        self._apply_action(action)
        self._player_action_tups.append((player, action))
        self._move_number += 1

    def apply_actions(self, actions: Dict["AgentID", "Action"]):
        """For simultaneous game mode.

        Applies the joint action through ``_apply_actions``, records one
        ``(player, action)`` pair per entry and counts the whole joint
        action as a single move.

        Args:
            actions (Dict[AgentID, Action]): The dict of actions.
        """
        self._apply_actions(actions)
        self._player_action_tups.extend(actions.items())
        self._move_number += 1

    @abstractmethod
    def _apply_action(self, action: Any):
        """Applies the specified action to the state.

        Args:
            action (Any): The action.
        """

    def _apply_actions(self, actions):
        """Apply a joint action; hook used by ``apply_actions``.

        Games that support ``apply_actions`` must override this.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def undo_action(self, player: "AgentID", action):
        """Undo ``action`` of ``player``; not implemented in the base class."""
        raise NotImplementedError

    @abstractmethod
    def child(self, action) -> "State":
        """Return the state associated with ``action`` (subclass-defined)."""

    # ------------------------------------------------------------------
    # Rewards and returns
    # ------------------------------------------------------------------

    def rewards(self):
        """Return one reward per player for this state.

        At a terminal state this is ``returns()``; otherwise it is a list of
        ``num_player`` zeros. Must not be called at a chance node.
        """
        if self.is_terminal():
            return self.returns()
        # Rewards are undefined at chance nodes.
        assert not self.is_chance_node()
        return [0.0] * self.num_player

    def player_reward(self, player: "AgentID"):
        """Return the entry of ``rewards()`` that belongs to ``player``."""
        return self.rewards()[player]

    def player_return(self, player: "AgentID"):
        """Return the entry of ``returns()`` that belongs to ``player``."""
        return self.returns()[player]

    @abstractmethod
    def returns(self):
        """Total reward for each player over the course of the game so far."""

    # ------------------------------------------------------------------
    # Game-level properties
    # ------------------------------------------------------------------

    @property
    def num_player(self) -> int:
        """Number of players, read from the game's ``game_info``."""
        return self._game.game_info.num_player

    @property
    def num_distinct_actions(self):
        """Size of the action space, read from the game's ``game_info``."""
        return self._game.game_info.num_distinct_actions

    # ------------------------------------------------------------------
    # Node kind
    # ------------------------------------------------------------------

    @abstractmethod
    def current_player(self):
        """Returns id of the next player to move, or TERMINAL if game is over."""

    @abstractmethod
    def is_terminal(self) -> bool:
        """Returns True if the game is over."""

    def is_chance_node(self) -> bool:
        """Return true if current player id equals to `PlayerID.CHANCE`

        Returns:
            bool: True when chance node, otherwise False.
        """
        return self.current_player() == game_type.PlayerId.CHANCE

    def is_mean_field_node(self) -> bool:
        """Whether current player id is mean field id.

        Returns:
            bool: True when mean-field player, otherwise False.
        """
        return self.current_player() == game_type.PlayerId.MEAN_FIELD

    def is_player_node(self) -> bool:
        """Whether current player id is at least `PlayerID.DEFAULT`.

        Returns:
            bool: True when player, otherwise False.
        """
        return self.current_player() >= game_type.PlayerId.DEFAULT

    def is_simultaneous_node(self) -> bool:
        """Whether current player id equals `PlayerID.SIMULTANEOUS`."""
        return self.current_player() == game_type.PlayerId.SIMULTANEOUS

    def is_player_acting(self, player: "game_type.PlayerId") -> bool:
        """Whether ``player`` acts in this state.

        True when ``player`` is the current player or the state is a
        simultaneous node. ``player`` must lie in
        ``[PlayerId.DEFAULT, num_player)``.
        """
        assert player >= game_type.PlayerId.DEFAULT
        assert player < self.num_player
        return self.current_player() == player or self.is_simultaneous_node()

    @property
    def player_action(self):
        """Return a ``(current player, action)`` pair.

        FIXME(review): ``action`` is not defined in this scope, so reading
        this property raises ``NameError``. Where the action should come
        from is not clear from this file; confirm the intent and fix.
        """
        return (self.current_player(), action)

    # ------------------------------------------------------------------
    # History
    # ------------------------------------------------------------------

    def history(self):
        """Return the list of applied actions, oldest first, without players."""
        return [action for _, action in self._player_action_tups]

    def full_history(self):
        """Return the recorded ``(player, action)`` pairs, oldest first.

        The internal list itself is returned, not a copy.
        """
        return self._player_action_tups

    def history_string(self) -> str:
        """Return the applied actions joined by ``", "``."""
        return ", ".join(map(str, self.history()))

    def move_number(self):
        """Return how many moves have been applied to this state."""
        return self._move_number

    def is_initial_state(self) -> bool:
        """Return True while no action has been recorded."""
        return len(self._player_action_tups) == 0

    # ------------------------------------------------------------------
    # Information states and observations
    # ------------------------------------------------------------------

    @abstractmethod
    def information_state_string(self, player: "AgentID") -> str:
        """Return the information state of ``player`` as a string."""

    @abstractmethod
    def information_state_tensor(self, player, value):
        """Information-state tensor hook; semantics are subclass-defined."""

    def observation_string(self):
        """Return the observation string of the current player."""
        return self._observation_string(self.current_player())

    @abstractmethod
    def _observation_string(self, player: "AgentID"):
        """Return the observation of ``player`` as a string."""

    def observation_tensor(self, player, value):
        """Observation tensor hook; the base implementation returns None."""

    # ------------------------------------------------------------------
    # Chance and sampling
    # ------------------------------------------------------------------

    @abstractmethod
    def chance_outcomes(self) -> Dict["Action", float]:
        """Returns the possible chance outcomes and their probabilities."""

    def legal_chance_outcomes(self):
        """Not implemented in the base class."""
        raise NotImplementedError

    def resample_info_state(self, player: "AgentID"):
        """Not implemented in the base class."""
        raise NotImplementedError

    def get_histories_consistent_with_infostate(self):
        """Return the histories consistent with the current player's infostate.

        Delegates to ``_get_histories_consistent_with_infostate``.
        """
        return self._get_histories_consistent_with_infostate(self.current_player())

    def _get_histories_consistent_with_infostate(self, player: "AgentID"):
        """Hook for ``get_histories_consistent_with_infostate``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def _get_history_consistent_with_infostate(self, player: "AgentID"):
        """Older spelling of ``_get_histories_consistent_with_infostate``.

        Kept so existing references to this name keep working.
        """
        return self._get_histories_consistent_with_infostate(player)

    def mean_field_population(self):
        """Return 0 for games whose dynamics are not mean-field.

        Raises:
            NotImplementedError: If the game's dynamics are
                ``Dynamics.MEANFIELD``; such games must override this.
        """
        if self._game.game_type.dynamics == game_type.Dynamics.MEANFIELD:
            raise NotImplementedError
        return 0


class Observer:
    """Placeholder type; it defines no attributes or behavior yet."""


class Game(metaclass=ABCMeta):
    """Abstract base class of a game: static metadata plus state factories."""

    def __init__(
        self, game_type: GameType, game_info: GameInfo, params: Dict[str, Any]
    ) -> None:
        """Store the game's description.

        Args:
            game_type: Static description of the kind of game.
            game_info: Numeric facts about this instantiation.
            params: Parameters of this instantiation; stored unchanged.
        """
        self._game_type = game_type
        self._game_info = game_info
        self._params = params
        # Player ids are 0 .. num_player - 1.
        self._players = list(range(game_info.num_player))

    @property
    def game_type(self) -> GameType:
        """The `GameType` passed at construction."""
        return self._game_type

    @property
    def game_info(self) -> GameInfo:
        """The `GameInfo` passed at construction."""
        return self._game_info

    @abstractmethod
    def new_initial_state(self) -> State:
        """Return initial state instance"""

    def new_initial_states(self) -> Sequence[State]:
        """Return the list of initial states of the game.

        A mean-field game with at least two players yields one state per
        player, built by `new_initial_state_for_population`. Every other
        game yields a single-element list holding `new_initial_state()`.
        """
        per_population = (
            self._game_type.dynamics == game_type.Dynamics.MEANFIELD
            and self._game_info.num_player >= 2
        )
        if not per_population:
            return [self.new_initial_state()]
        return [
            self.new_initial_state_for_population(pid) for pid in self._players
        ]

    def new_initial_state_for_population(self, player: AgentID):
        """Return the initial state for one population; not implemented here.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
