import copy

import gym
from gym.spaces import Discrete


class DiscreteWrapper(gym.Wrapper):
    """Expose a ``Discrete(n_actions)`` action space over a wrapped environment.

    Actions chosen by the agent are passed through ``disc2cont`` and the result
    is forwarded to the wrapped environment's ``step``.

    If the wrapped environment provides ``get_train_envs`` and/or
    ``get_test_envs``, this wrapper provides the same methods. They return the
    wrapped environment's sub-environments, each wrapped in a
    ``DiscreteWrapper`` with the same ``n_actions`` and ``disc2cont``.

    Args:
        env: Environment to wrap.
        n_actions: Number of discrete actions exposed to the agent.
        disc2cont: Callable mapping an action from ``Discrete(n_actions)`` to
            an action accepted by ``env.step``.
    """

    def __init__(self, env, n_actions, disc2cont):
        super().__init__(env)
        self._n_actions = n_actions
        self._disc2cont = disc2cont
        self._action_space = Discrete(n_actions)
        # Work on a copy so updating the action space below does not mutate
        # the wrapped environment's own spec object.
        self._spec = copy.deepcopy(env.spec)
        # ``spec`` may be None (e.g. env not created through a registry);
        # there is then nothing to update.
        if self._spec is not None:
            self._spec.action_space = self._action_space
        # Install the sub-environment accessors only when the wrapped env has
        # them, so callers receive sub-environments wrapped like this one.
        if hasattr(env, "get_train_envs"):
            self.get_train_envs = self._get_train_envs
        if hasattr(env, "get_test_envs"):
            self.get_test_envs = self._get_test_envs

    @property
    def action_space(self):
        """The ``Discrete(n_actions)`` space exposed to the agent."""
        return self._action_space

    @property
    def spec(self):
        """Copy of the wrapped env's spec with the discrete action space (or None)."""
        return self._spec

    def step(self, action):
        """Convert ``action`` with ``disc2cont`` and step the wrapped environment.

        Returns whatever the wrapped environment's ``step`` returns.
        """
        return self.env.step(self._disc2cont(action))

    def _wrap_all(self, envs):
        """Wrap every environment in ``envs`` with this wrapper's settings."""
        return [
            DiscreteWrapper(env, self._n_actions, self._disc2cont)
            for env in envs
        ]

    def _get_train_envs(self):
        """Return the wrapped env's training environments, each wrapped likewise."""
        # The wrapped environment is stored on ``self.env``; there is no
        # ``self._env`` attribute on this wrapper.
        return self._wrap_all(self.env.get_train_envs())

    def _get_test_envs(self):
        """Return the wrapped env's test environments, each wrapped likewise."""
        return self._wrap_all(self.env.get_test_envs())
