from gym.spaces import *
from ARLLib.utils import *

# TODO: in BRL, Space.sample() is expected to return a random sample of the space drawn from the UNIFORM distribution.
#  The current implementation of this routine (across all spaces) relies on gym.spaces.Space.sample(), which does NOT
#  guarantee uniform sampling. This is not an urgent issue for now: in gym 0.15.4, sample() does sample uniformly for
#  two special spaces, Discrete and bounded Box, and these two spaces account for most of our applications at the
#  moment.

class Categorical(Discrete):
    """Discrete space of ``n`` categories whose elements are one-hot vectors.

    Unlike the parent ``Discrete`` space, ``sample()`` returns the one-hot
    encoding of an index (produced by ``onehot_encoding``; presumably an
    ``np.ndarray`` of shape ``(n,)``, which is the form ``onehot()`` and
    ``contains()`` accept). ``onehot()`` converts between the integer index
    and the vector representation.
    """

    def __init__(self, n):
        # Declared up front; the parent constructor assigns the real value.
        self.n = None
        super().__init__(n)

    def onehot(self, x):
        """Convert between an integer category index and its one-hot vector.

        Args:
            x: either an integral index with ``0 <= x < n``, which is encoded
                with ``onehot_encoding``, or an ``np.ndarray`` of shape
                ``(n,)``, which is decoded back to an index with
                ``onehot_decoding``. Note that the array branch does not
                itself verify that ``x`` is a valid one-hot vector; use
                ``contains()`` for that.

        Returns:
            The encoded vector for an index, or the decoded index for an array.

        Raises:
            RuntimeError: if ``x`` matches neither form (e.g. an out-of-range
                index, a plain Python list, or an array of the wrong shape).
        """
        if isinstance(x, numbers.Integral) and 0 <= x < self.n:
            return onehot_encoding(x, self.n)
        elif isinstance(x, np.ndarray) and x.shape == (self.n,):
            return onehot_decoding(x)
        else:
            raise RuntimeError('categorical space of dimension {0} cannot do one-hot conversion on {1} {2}'.format(
                self.n, type(x), 'with value {0}'.format(x)))

    def sample(self):
        """Draw an index with ``Discrete.sample()`` and return it one-hot encoded."""
        return self.onehot(super().sample())

    def contains(self, x):
        """Return True iff ``x`` is a valid one-hot vector of this space.

        ``x`` must be an ``np.ndarray`` of shape ``(n,)`` whose entries are all
        0 or 1, with exactly one 1. Any other input (ints, lists, arrays of the
        wrong shape, vectors like ``[1, 1, -1, 0, 0]``) yields False rather
        than raising.
        """
        # Reject non-arrays and wrong shapes before touching array methods;
        # otherwise ints/lists would raise AttributeError and wrong-length
        # arrays would make onehot() raise RuntimeError.
        if not isinstance(x, np.ndarray) or x.shape != (self.n,):
            return False
        # A "sum == max == 1" test alone would accept vectors with negative
        # entries such as [1, 1, -1, 0, 0], so check every entry explicitly.
        if not (np.all((x == 0) | (x == 1)) and np.count_nonzero(x) == 1):
            return False
        return bool(super().contains(self.onehot(x)))

    def __repr__(self):
        return "Categorical(dim=%d)" % self.n

    def __eq__(self, other):
        # Two categorical spaces are equal iff they have the same number of categories.
        return isinstance(other, Categorical) and self.n == other.n


"""
example usage and simple unit test
"""
if __name__ == "__main__":
    space = Categorical(5)
    print(space)
    actions = RV('categorical actions')
    x = space.sample()
    print(x)
    print(space.contains(x))
    print(space.contains(array([0,0,0,1,0])))
    print(space.contains(array([0,0,0,0,0])))
    print(space.onehot(2))
    print(space.onehot(array([1,0,0,0,0])))

    for _ in range(10000):
        x = space.sample()
        actions.append(space.onehot(x))
    print(actions)
    actions.plot()

    try:
        print(space.onehot(5))  # raise error
    except:
        print('error raised.')

    try:
        print(space.onehot([0,1,0,0,0]))  # raise error
    except:
        print('error raised.')

    try:
        print(space.onehot(array([0, 1, 0])))  # raise error
    except:
        print('error raised.')

    print('finished')
