import functools

import gym

from expground.types import DataArray, Dict, Union, Sequence, Tuple
from expground.algorithms.base_policy import Policy
from expground.common.distributions import Distribution, make_proba_distribution


def _caster(observation):
    raise NotImplementedError


def info_caster(func=None, *, caster=None):
    """Wrap a method so it receives an info string instead of a raw observation.

    The wrapped method is called as ``method(self, caster(observation),
    action_mask, evaluate)`` and its return value is passed through unchanged.

    Supports both decorator forms:
    ``@info_caster``, ``@info_caster()`` and ``@info_caster(caster=fn)``.

    Args:
        func: The method to wrap when used as a bare decorator. ``None`` when
            ``info_caster`` is called to build a decorator.
        caster: Callable converting an observation to an info string. When
            ``None``, the module-level ``_caster`` is used. It is looked up
            at call time, so a later reassignment of ``_caster`` is honoured.

    Returns:
        The wrapped method if ``func`` is given, otherwise a decorator.
    """

    def decorator(method):
        @functools.wraps(method)
        def wrap(self, observation, action_mask, evaluate):
            # Resolved per call (not captured at decoration time) so the
            # module-level ``_caster`` can be replaced after classes are built.
            convert = caster if caster is not None else _caster
            info_str = convert(observation)
            return method(self, info_str, action_mask, evaluate)

        return wrap

    if func is None:
        # Called form: ``@info_caster()`` / ``@info_caster(caster=...)``.
        return decorator
    # Bare form: ``@info_caster`` passes the method directly.
    return decorator(func)


def _policy_holder(action_space: gym.Space):
    """Placeholder hook taking an action space; not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    del action_space  # unused until a concrete implementation exists
    raise NotImplementedError()


class TabularPolicy(Policy):
    """Policy that keeps a dict from info strings to action distributions.

    ``compute_action`` accepts an observation, converts it to an info string
    through the ``info_caster`` decorator, and delegates to
    ``_compute_action``, which subclasses must implement.
    """

    def __init__(
        self, observation_space: gym.Space, action_space: gym.Space, is_fixed: bool
    ):
        """Initialise the base policy, an empty table and the action handler.

        Args:
            observation_space: Observation space forwarded to ``Policy``.
            action_space: Action space forwarded to ``Policy`` and used to
                build the action distribution handler.
            is_fixed: Forwarded to ``Policy`` as ``is_fixed``.
        """
        # NOTE(review): the two positional ``None`` arguments are interpreted
        # by ``Policy.__init__`` (not visible here) — confirm their meaning.
        super().__init__(
            observation_space, action_space, None, None, is_fixed=is_fixed
        )

        # Table keyed by info string; starts empty.
        self._table: Dict[str, Distribution] = {}
        self._action_dist_handler = make_proba_distribution(action_space)

    @property
    def table(self) -> Dict[str, Distribution]:
        """The info-string -> distribution dict (returned by reference)."""
        return self._table

    @property
    def action_dist_handler(self) -> Distribution:
        """Handler built by ``make_proba_distribution(action_space)``."""
        return self._action_dist_handler

    # Called form so class creation works whether ``info_caster`` takes an
    # optional function argument or none at all.
    @info_caster()
    def compute_action(
        self, info_str, action_mask, evaluate
    ) -> Tuple[int, Sequence[float]]:
        """Compute an action for one observation.

        Callers pass ``(observation, action_mask, evaluate)``; the decorator
        replaces the observation with its info string before this body runs.
        """
        return self._compute_action(info_str, action_mask, evaluate)

    def _compute_action(
        self, info_str: str, action_mask: DataArray, evaluate: bool
    ) -> Tuple[int, Sequence[float]]:
        """Look up an action for ``info_str``; must be implemented by subclasses."""
        raise NotImplementedError

    def compute_actions(self, observation, action_mask):
        """Batched action computation; not implemented."""
        raise NotImplementedError
