import gym
import numpy as np
from gym import spaces
import itertools


class ActionRedundancyWrapper(gym.Wrapper):
    """
    Add saturating redundancy to a continuous action space.

    The agent acts in the normalized box [-1, 1]. Each action component ``e``
    is stretched by ``nleft`` (non-positive side) or ``nright`` (positive
    side) and saturated at the bounds, i.e. conceptually the action lives in
    [-nleft, nright] and is re-normalized into [-1, 1]:

    If e <= -1/nleft,       a = -1
    If -1/nleft < e <= 0,   a = e * nleft
    If 0 < e <= 1/nright,   a = e * nright
    If e > 1/nright,        a = 1

    If the wrapped env's action bounds are not exactly [-1, 1], ``a`` is
    afterwards mapped affinely onto the env's original [low, high] bounds.

    :param env: environment to wrap; its action space must expose ``low`` and
        ``high`` (i.e. behave like a Box)
    :param nleft: redundancy factor applied to non-positive actions; ``step``
        divides by it, so it must be non-zero
    :param nright: redundancy factor applied to positive actions; ``step``
        divides by it, so it must be non-zero
    """

    def __init__(self, env: gym.Env, nleft: float = 1.0, nright: float = 1.0):
        super().__init__(env)
        low = env.action_space.low
        high = env.action_space.high
        self.need_scale_up = not (np.all(low == -1.) and np.all(high == 1.))
        if self.need_scale_up:
            print("Action Space is not [-1, 1]")
            # Remember the real bounds so step() can map [-1, 1] back onto them.
            self.original_low = low
            self.original_high = high
            normalized_space = spaces.Box(low=-np.ones_like(low), high=np.ones_like(high))
            # Kept for backward compatibility: the wrapped env has always been
            # re-labelled with the normalized space.
            env.action_space = normalized_space
            # Also set the space on the wrapper itself so it advertises
            # [-1, 1] regardless of whether the installed gym version resolves
            # Wrapper.action_space lazily or copied it in super().__init__().
            self.action_space = normalized_space
        self.nleft = nleft
        self.nright = nright

    def step(self, action):
        """
        Re-normalize ``action`` (see class docstring) and step the wrapped env.

        :param action: array-like action in the normalized [-1, 1] box
        :return: whatever the wrapped env's ``step`` returns
        """
        # Accept lists/tuples/scalars as well as ndarrays.
        action = np.asarray(action)
        if not np.issubdtype(action.dtype, np.floating):
            # An integer buffer would silently truncate the scaled values.
            action = action.astype(np.float64)

        left_knee = -1. / self.nleft
        right_knee = 1. / self.nright

        saturated_low = action <= left_knee
        linear_left = (left_knee < action) & (action <= 0.)
        linear_right = (0. < action) & (action <= right_knee)
        saturated_high = action > right_knee

        new_action = np.zeros_like(action)
        new_action[saturated_low] = -1.
        new_action[linear_left] = action[linear_left] * self.nleft
        new_action[linear_right] = action[linear_right] * self.nright
        new_action[saturated_high] = 1.

        if self.need_scale_up:
            # Affine map from [-1, 1] onto the env's original [low, high].
            span = self.original_high - self.original_low
            new_action = (new_action + 1) * 0.5 * span + self.original_low

        return self.env.step(new_action)

class ActionRedundancyWrapper2nd(gym.Wrapper):
    """
    Add dead-zone redundancy to a continuous action space.

    The agent acts in the normalized box [-1, 1]. Each action component ``e``
    is mapped so that a band around zero collapses to 0 and only the outer
    parts of the range remain effective:

    If e <= -1 + 1/nleft,        a = e * nleft + nleft - 1
    If -1 + 1/nleft < e <= 0,    a = 0
    If 0 < e <= 1 - 1/nright,    a = 0
    If e > 1 - 1/nright,         a = e * nright - nright + 1

    If the wrapped env's action bounds are not exactly [-1, 1], ``a`` is
    afterwards mapped affinely onto the env's original [low, high] bounds.

    :param env: environment to wrap; its action space must expose ``low`` and
        ``high`` (i.e. behave like a Box)
    :param nleft: redundancy factor for the negative side; ``step`` divides by
        it, so it must be non-zero (the formulas above describe a dead zone
        only for values >= 1 -- NOTE(review): confirm intended range)
    :param nright: redundancy factor for the positive side; same constraints
        as ``nleft``
    """

    def __init__(self, env: gym.Env, nleft: float = 1.0, nright: float = 1.0):
        super().__init__(env)
        low = env.action_space.low
        high = env.action_space.high
        self.need_scale_up = not (np.all(low == -1.) and np.all(high == 1.))
        if self.need_scale_up:
            print("Action Space is not [-1, 1]")
            # Remember the real bounds so step() can map [-1, 1] back onto them.
            self.original_low = low
            self.original_high = high
            normalized_space = spaces.Box(low=-np.ones_like(low), high=np.ones_like(high))
            # Kept for backward compatibility: the wrapped env has always been
            # re-labelled with the normalized space.
            env.action_space = normalized_space
            # Also set the space on the wrapper itself so it advertises
            # [-1, 1] regardless of whether the installed gym version resolves
            # Wrapper.action_space lazily or copied it in super().__init__().
            self.action_space = normalized_space

        self.nleft = nleft
        self.nright = nright

    def step(self, action):
        """
        Re-normalize ``action`` (see class docstring) and step the wrapped env.

        :param action: array-like action in the normalized [-1, 1] box
        :return: whatever the wrapped env's ``step`` returns
        """
        # Accept lists/tuples/scalars as well as ndarrays.
        action = np.asarray(action)
        if not np.issubdtype(action.dtype, np.floating):
            # An integer buffer would silently truncate the scaled values.
            action = action.astype(np.float64)

        left_knee = -1. + 1. / self.nleft
        right_knee = 1. - 1. / self.nright

        active_left = action <= left_knee
        dead_left = (left_knee < action) & (action <= 0.)
        dead_right = (0. < action) & (action <= right_knee)
        active_right = action > right_knee

        # Assignment order matters when the masks overlap (factors below 1),
        # so all four writes are kept in the original sequence.
        new_action = np.zeros_like(action)
        new_action[active_left] = action[active_left] * self.nleft + self.nleft - 1
        new_action[dead_left] = 0.
        new_action[dead_right] = 0.
        new_action[active_right] = action[active_right] * self.nright - self.nright + 1

        if self.need_scale_up:
            # Affine map from [-1, 1] onto the env's original [low, high].
            span = self.original_high - self.original_low
            new_action = (new_action + 1) * 0.5 * span + self.original_low

        return self.env.step(new_action)

class DiscreteWrapper(gym.Wrapper):
    """
    Continuous to discrete wrapper.
    refer to https://github.com/tseyde/decqn/blob/master/decqn/wrappers.py

    Every action dimension is split into ``num_bins`` evenly spaced values
    between its low and high bound (both included). The discrete action set is
    the Cartesian product of the per-dimension values, so it contains
    ``num_bins ** n_dims`` actions and is exposed as ``spaces.Discrete``.

    :param env: environment to wrap; its action space must be a ``spaces.Box``
    :param config: dict with key ``"num_bins"`` (values per dimension);
        defaults to ``{"num_bins": 2}`` when omitted
    """

    def __init__(self, env: gym.Env, config=None):
        super().__init__(env)
        assert isinstance(self.action_space, spaces.Box),"Original Action Space is NOT Box"

        # Build the default per call instead of sharing one dict across instances.
        if config is None:
            config = {"num_bins": 2}
        self._num_bins = config["num_bins"]

        # Read the continuous bounds before the space is replaced below.
        self._action_min = self.env.action_space.low
        self._action_max = self.env.action_space.high
        self._action_all = self._get_action_list()

        discrete_space = spaces.Discrete(len(self._action_all))
        self.env.action_space = discrete_space
        self.action_space = discrete_space

    def _get_action_list(self):
        """
        Enumerate every combination of per-dimension bin values.

        :return: list of arrays shaped like the original Box action, ordered
            as ``itertools.product`` yields them (last dimension varies fastest)
        """
        action_shape = np.shape(self._action_min)
        # Flatten so Box shapes other than 1-D are handled too; for a 1-D Box
        # ravel/reshape are no-ops.
        flat_low = np.ravel(self._action_min)
        flat_high = np.ravel(self._action_max)
        # linspace gives (num_bins, n_dims); transpose to one row per dimension.
        per_dim_values = np.linspace(flat_low, flat_high, num=self._num_bins).transpose()
        return [
            np.array(combo).reshape(action_shape)
            for combo in itertools.product(*per_dim_values)
        ]

    def step(self, action):
        """
        Translate a discrete action index into its continuous action and step.

        :param action: integer index (Python int, NumPy integer, or size-1
            integer array) into the enumerated action list
        :return: whatever the wrapped env's ``step`` returns
        """
        # .item() unwraps NumPy scalars and size-1 arrays to a Python scalar.
        index = np.asarray(action).item()
        # Hand out a copy so an env that edits the action in place cannot
        # corrupt the lookup table.
        new_action = self._action_all[index].copy()
        return self.env.step(new_action)


class DecoupledDiscreteWrapper(gym.Wrapper):
    """
    Continuous to decoupled discrete wrapper.
    refer to https://github.com/tseyde/decqn/blob/master/decqn/wrappers.py

    Every action dimension is discretized independently into ``num_bins``
    evenly spaced values between its low and high bound (both included). The
    resulting space is ``spaces.MultiDiscrete([num_bins] * n_dims)``: the
    agent picks one bin index per dimension.

    :param env: environment to wrap; its action space must be a ``spaces.Box``
        (the code reads ``low.shape[0]``, so a 1-D Box is assumed)
    :param config: dict with key ``"num_bins"`` (values per dimension);
        defaults to ``{"num_bins": 2}`` when omitted
    """

    def __init__(self, env: gym.Env, config=None):
        super().__init__(env)
        assert isinstance(self.action_space, spaces.Box),"Original Action Space is NOT Box"

        # Build the default per call instead of sharing one dict across instances.
        if config is None:
            config = {"num_bins": 2}
        self._num_bins = config["num_bins"]

        # Read the continuous bounds before the space is replaced below.
        self._action_min = self.env.action_space.low
        self._action_max = self.env.action_space.high
        self._action_all = self._get_action_list()

        n_dims = self._action_min.shape[0]
        # e.g. a 3-dimensional action space with 2 bins -> [2, 2, 2]
        bins_per_dim = [self._num_bins] * n_dims

        multi_discrete_space = spaces.MultiDiscrete(bins_per_dim)
        self.env.action_space = multi_discrete_space
        self.action_space = multi_discrete_space

    def _get_action_list(self):
        """
        Build the per-dimension bin values.

        :return: list with one array per action dimension, each holding that
            dimension's ``num_bins`` evenly spaced values
        """
        # linspace gives (num_bins, n_dims); transpose to one row per dimension.
        per_dim_values = np.linspace(self._action_min, self._action_max, num=self._num_bins)
        return list(per_dim_values.transpose())

    def step(self, action):
        """
        Translate per-dimension bin indices into a continuous action and step.

        :param action: integer array-like with one bin index per action
            dimension
        :return: whatever the wrapped env's ``step`` returns
        """
        bin_table = np.asarray(self._action_all)  # (n_dims, num_bins)
        bin_index = np.asarray(action)
        # Pick bin_index[d] from row d. A plain np.take without an axis works
        # on the flattened table, so every dimension would read dimension 0's
        # values, which is wrong when the per-dimension bounds differ.
        new_action = bin_table[np.arange(bin_table.shape[0]), bin_index]
        return self.env.step(new_action)