import gym
from gym.envs.toy_text.frozen_lake import FrozenLakeEnv
import numpy as np


class FrozenLakeWrapper(FrozenLakeEnv):
  """FrozenLake environment that exchanges vectors instead of integer indices.

  Observations are returned as float32 one-hot vectors of length
  ``num_states`` rather than integer state indices. Actions are accepted as
  vectors (e.g. one-hot or logits) and decoded with ``argmax`` over the last
  axis before being passed to ``FrozenLakeEnv.step``.

  Works with both the legacy gym API (``reset`` -> obs,
  ``step`` -> 4-tuple) and the gym >= 0.26 API (``reset`` -> (obs, info),
  ``step`` -> 5-tuple): only the observation element is converted, and the
  rest of the parent's return value is passed through unchanged.

  NOTE(review): ``self.observation_space`` is inherited unchanged from
  ``FrozenLakeEnv`` and so still describes integer observations, not the
  one-hot vectors this class returns — confirm no consumer relies on it.
  """

  def __init__(self, num_states=None, **kwargs):
    """Create the environment.

    Args:
      num_states: Length of the one-hot observation vector. When ``None``,
        defaults to ``self.observation_space.n`` of the parent environment.
      **kwargs: Forwarded to ``FrozenLakeEnv.__init__`` (e.g. ``map_name``,
        ``is_slippery``), so the underlying map can be configured.
    """
    super().__init__(**kwargs)
    if num_states is None:
      num_states = self.observation_space.n
    self.num_states = num_states

  def _one_hot(self, obs):
    """Return a float32 vector of length ``num_states`` with 1 at ``obs``."""
    encoded = np.zeros(self.num_states, np.float32)
    encoded[obs] = 1.
    return encoded

  def step(self, action):
    """Take one environment step with a vector-encoded action.

    Args:
      action: Array-like whose last axis scores each discrete action; the
        index of the maximum entry is the action taken.

    Returns:
      The parent's step result with the observation replaced by its one-hot
      encoding: ``(obs, reward, done, info)`` on the legacy gym API, or
      ``(obs, reward, terminated, truncated, info)`` on gym >= 0.26.
    """
    # Convert to a plain int: the parent uses it as a transition-table key.
    discrete_action = int(np.argmax(action, axis=-1))
    obs, *rest = super().step(discrete_action)
    return (self._one_hot(obs), *rest)

  def reset(self, **kwargs):
    """Reset the environment and return the one-hot initial observation.

    Args:
      **kwargs: Forwarded to ``FrozenLakeEnv.reset`` (e.g. ``seed`` and
        ``options`` on gym >= 0.26).

    Returns:
      The one-hot initial observation on the legacy gym API, or
      ``(one_hot_obs, info)`` when the parent returns an ``(obs, info)`` pair.
    """
    result = super().reset(**kwargs)
    if isinstance(result, tuple):
      # gym >= 0.26: reset returns (obs, info).
      obs, info = result
      return self._one_hot(obs), info
    return self._one_hot(result)
