import numpy as np

from gym import spaces

from expground.algorithms.base_policy import Policy
from expground.utils import preprocessor
from expground.utils.preprocessor import get_preprocessor


class RandomPolicy(Policy):
    """Policy that samples uniformly random actions, optionally restricted by a mask.

    Only discrete action spaces are supported by ``compute_action`` (it relies on
    ``self._action_space.n``).
    """

    def __init__(self, observation_space, action_space, model_config, custom_config):
        super().__init__(observation_space, action_space, model_config, custom_config)
        self._preprocessor = get_preprocessor(observation_space)(observation_space)
        self._is_discrete = isinstance(self._action_space, spaces.Discrete)
        # Misspelled name kept for backward compatibility with existing readers.
        self._is_discreate = self._is_discrete

    def compute_action(self, observation, action_mask, explore):
        """Sample uniformly random actions, restricted to legal ones when masked.

        Supports only discrete action spaces. An ``observation`` whose shape equals
        ``self._preprocessor.shape`` is treated as a single (unbatched) sample;
        otherwise it is treated as a batch along its first axis.

        Args:
            observation: a single observation or a batch of observations; must
                expose ``.shape``.
            action_mask: ``None``, or an array-like mask over actions (one row per
                observation when batched). Entries equal to 1 mark legal actions.
            explore: unused; accepted for interface compatibility.

        Returns:
            A ``(actions, action_prob)`` tuple. ``actions`` is a single action
            index when unbatched, otherwise a sequence with one index per
            observation. ``action_prob`` describes the sampling distribution:
            with a mask it has the mask's shape and is uniform over legal actions
            (0 for illegal ones); without a mask it is a list of length
            ``action_space.n`` with equal entries.

        Raises:
            ValueError: if a mask row contains no legal action.
        """
        batched = True
        if action_mask is not None:
            # Accept lists as well as arrays; the row-wise ``== 1`` below needs
            # element-wise comparison.
            action_mask = np.asarray(action_mask)
            legal = (action_mask == 1).astype(np.float64)
            n_legal = legal.sum(axis=-1, keepdims=True)
            if np.any(n_legal == 0):
                raise ValueError("action_mask has a row with no legal action")
            # Report the distribution actually sampled from below: uniform over
            # legal actions only, zero probability for illegal ones.
            action_prob = legal / n_legal
        else:
            action_prob = [1 / self._action_space.n] * self._action_space.n

        if observation.shape == self._preprocessor.shape:
            # Single observation: wrap it (and its mask) so the sampling code
            # below can always iterate over a batch.
            observation = [observation]
            action_mask = [action_mask] if action_mask is not None else None
            batched = False

        if action_mask is not None:
            actions = [
                np.random.choice(np.argwhere(row == 1).reshape((-1,)))
                for row in action_mask
            ]
        else:
            actions = np.random.choice(
                self._action_space.n, len(observation) if batched else 1
            )
        if not batched:
            actions = actions[0]
        return actions, action_prob

    def compute_actions(self, **kwargs):
        """Batch interface; not implemented for this policy (returns ``None``)."""
        pass

    @property
    def preprocessor(self):
        """The observation preprocessor built from the observation space."""
        return self._preprocessor
