import gym
import numpy as np

from utils import onehot


class OneHot(gym.ObservationWrapper):
    """Observation wrapper for envs whose observation space is ``Discrete``.

    The advertised observation space is replaced by a float32 ``Box`` bounded
    by 0 and 1 with one entry per discrete value, and every observation is
    passed through the project's ``onehot`` helper (defined in ``utils``;
    presumably it returns a one-hot vector of the given shape and dtype).
    """

    def __init__(self, env):
        """Wrap ``env``; asserts that its observation space is ``Discrete``."""
        assert isinstance(env.observation_space, gym.spaces.Discrete)
        super().__init__(env)
        # Read the size before the space is overwritten with the Box below.
        num_values = self.observation_space.n
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=[num_values], dtype=np.float32
        )

    def observation(self, obs):
        """Return ``onehot(obs, shape, dtype)`` using the Box space's shape and dtype."""
        box = self.observation_space
        return onehot(obs, box.shape, box.dtype)


class OneHotCoordinate(gym.ObservationWrapper):
    """Observation wrapper that encodes a decoded pair of values as two one-hot blocks.

    The unwrapped env must expose ``_decode`` and ``_dims``. ``_decode`` is
    called on each observation and must yield exactly two values; ``_dims``
    supplies the sizes handed to the project's ``onehot`` helper (presumably
    the per-coordinate one-hot lengths; confirm against the env definition).
    The advertised observation space is a float32 ``Box`` bounded by 0 and 1
    whose length is ``np.sum(_dims)``.
    """

    def __init__(self, env):
        """Wrap ``env``; asserts that its observation space is ``Discrete``."""
        assert isinstance(env.observation_space, gym.spaces.Discrete)
        super().__init__(env)
        base_env = env.unwrapped
        self.decode = base_env._decode
        self.dims = base_env._dims
        total_size = np.sum(self.dims)
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=[total_size], dtype=np.float32
        )

    def observation(self, obs):
        """Decode ``obs`` into two values and concatenate their ``onehot`` encodings."""
        first, second = self.decode(obs)
        out_dtype = self.observation_space.dtype
        pieces = [
            onehot(first, self.dims[0], out_dtype),
            onehot(second, self.dims[1], out_dtype),
        ]
        return np.concatenate(pieces)


class NewAPI(gym.Wrapper):
    """Expose a 4-tuple-``step`` / obs-only-``reset`` env through the 5-tuple API.

    ``step`` returns ``(obs, reward, terminated, truncated, info)`` and
    ``reset`` returns ``(obs, info)``.
    """

    def step(self, action):
        """Step the wrapped env and split its ``done`` flag into two flags.

        The wrapped env must return ``(obs, reward, done, info)``. When
        ``info`` is a dict with a truthy ``"TimeLimit.truncated"`` entry (the
        key gym's ``TimeLimit`` wrapper uses under the 4-tuple API to mark a
        time-limit cutoff), the step is reported as ``terminated=False,
        truncated=True`` so callers can tell a cutoff from a real terminal
        state. In every other case ``done`` is passed through unchanged as
        ``terminated`` and ``truncated`` is ``False``.
        """
        obs, reward, done, info = self.env.step(action)
        terminated, truncated = done, False
        # Only consult dict-style infos; anything else is passed through as-is.
        if isinstance(info, dict) and info.get("TimeLimit.truncated", False):
            terminated, truncated = False, True
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped env (forwarding ``kwargs``) and return ``(obs, {})``.

        The wrapped env's ``reset`` is expected to return only the observation,
        so an empty info dict is supplied.
        """
        obs = self.env.reset(**kwargs)
        return obs, {}


class OldAPI(gym.Wrapper):
    """Expose a 5-tuple-``step`` / ``(obs, info)``-``reset`` env through the 4-tuple API.

    ``step`` returns ``(obs, reward, done, info)`` and ``reset`` returns only
    the observation.
    """

    def step(self, action):
        """Step the wrapped env and merge its two end-of-episode flags into ``done``.

        The wrapped env must return ``(obs, reward, terminated, truncated,
        info)``. ``done`` is true when either flag is set, so a truncated
        episode also ends for 4-tuple consumers. For scalar flags, a
        truncation is additionally recorded as ``info["TimeLimit.truncated"]``
        (true only if the episode was not also terminated) when ``info`` is a
        dict, so the distinction is not lost. Array-valued flags are combined
        element-wise and ``info`` is left untouched.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        if np.ndim(terminated) == 0 and np.ndim(truncated) == 0:
            done = bool(terminated or truncated)
            if truncated and isinstance(info, dict):
                info["TimeLimit.truncated"] = not terminated
        else:
            # Non-scalar flags: `or` would be ambiguous, so combine element-wise.
            done = np.logical_or(terminated, truncated)
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env (forwarding ``kwargs``) and return only the observation.

        The info element of the wrapped env's ``(obs, info)`` result is discarded.
        """
        obs, _info = self.env.reset(**kwargs)
        return obs
