from typing import Callable

from gym.spaces import Discrete
import numpy as np
from gym_montezuma.envs.errors import *

class ConditionalDiscrete(Discrete):
    """
    A discrete action space whose ``sample`` only returns currently-valid actions.

    Validity is decided at sampling time by calling ``available_actions``, so the
    set of sampleable actions can change as the environment's state changes.
    Note: ``contains`` is inherited unchanged from ``Discrete`` and does not
    consult the availability mask.
    """

    def __init__(self, n_actions: int, available_actions: Callable[[], np.ndarray]):
        """
        Create a discrete action space where only valid actions can be sampled
        :param n_actions: the total number of actions
        :param available_actions: a function that, when called, returns a binary array specifying whether actions are
        runnable in the environment's current state
        """
        super().__init__(n_actions)
        self.n_actions = n_actions
        self.available_actions = available_actions

    def sample(self):
        """
        Sample a valid action uniformly at random
        :return: the index of an action whose mask entry is non-zero
        :raises NoSkillsAvailable: if the mask marks no action as available
        """
        # asarray lets the callback return a list/tuple as well as an ndarray;
        # an existing ndarray is passed through without copying.
        mask = np.asarray(self.available_actions())
        if not np.any(mask):
            raise NoSkillsAvailable()
        # Normalise the mask into a probability vector; for a binary mask this is
        # uniform over the available actions.
        # NOTE(review): this draws from NumPy's global RNG, not any generator
        # seeded on the space itself — confirm this is intended for reproducibility.
        return np.random.choice(np.arange(self.n), p=mask / mask.sum())

    def __repr__(self):
        return "ConditionalDiscrete(%d), mask_func: %r" % (self.n_actions, self.available_actions)

    def __str__(self):
        # Use %-formatting (as in __repr__); the previous ``.format`` call had no
        # ``{}`` fields and always returned the literal "ConditionalDiscrete(%d)".
        return "ConditionalDiscrete(%d)" % self.n_actions
