import gym
import gym.spaces as spaces
from collections import deque
import copy

from gym.envs.registration import register
from .pong_gym import *

class LazyFrames(object):
    """Deferred concatenation of a list of frames along their last axis.

    Storing the frame list instead of a concatenated copy allows callers that
    build overlapping stacks to share frame objects, which can matter for
    memory-hungry consumers such as large replay buffers. The concatenation
    runs on first access (``np.asarray``/``np.array``, ``len``, indexing), is
    cached, and the frame list is released at that point. Convert to an
    ndarray only right before the data is actually needed.
    """

    def __init__(self, frames):
        """
        Args:
            frames: sequence of array-likes to join along axis -1. As required
                by ``np.concatenate``, their shapes must match on every axis
                except the last.
        """
        self._frames = frames
        self._out = None

    def _force(self):
        """Concatenate the frames on first call and return the cached ndarray."""
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=-1)
            # Drop the frame references so only the concatenated array is kept alive.
            self._frames = None
        return self._out

    def __array__(self, dtype=None, copy=None):
        """NumPy array protocol hook used by ``np.asarray`` / ``np.array``.

        Args:
            dtype: optional target dtype. When given, a converted (new) array
                is returned.
            copy: the ``copy`` flag NumPy >= 2 passes to ``__array__``. When
                truthy (and no dtype is given), a copy of the cached array is
                returned so the caller can mutate it without altering this
                object. Otherwise the cached array itself is returned.

        Returns:
            The concatenated frames as an ndarray.
        """
        out = self._force()
        if dtype is not None:
            # astype always allocates, so the cached buffer is never exposed here.
            return out.astype(dtype)
        if copy:
            return out.copy()
        return out

    def __len__(self):
        """Length of the first axis of the concatenated array (forces concatenation)."""
        return len(self._force())

    def __getitem__(self, i):
        """Index into the concatenated array (forces concatenation)."""
        return self._force()[i]


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack the ``k`` most recent observations along the last axis.

        Every observation returned by ``reset``/``step`` is a plain ndarray:
        the last ``k`` frames concatenated along axis -1, oldest first. On
        ``reset`` the initial observation is repeated ``k`` times.

        Args:
            env: the environment to wrap. Its observation space is read via
                ``low``/``high``/``dtype``, so it is assumed to be a Box.
            k: number of frames to stack; must be at least 1.

        Raises:
            ValueError: if ``k`` is less than 1.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        super().__init__(env)
        self.k = k
        # Bounded deque: appending a new frame automatically evicts the oldest.
        self.frames = deque([], maxlen=k)
        old_space = copy.deepcopy(env.observation_space)
        # Tile (rather than np.repeat) the per-frame bounds. This way element j
        # of the stacked observation gets the bound of the same element within
        # a single frame, matching the whole-frame concatenation in _get_ob.
        self.observation_space = spaces.Box(
            np.concatenate([old_space.low] * k, axis=-1),
            np.concatenate([old_space.high] * k, axis=-1),
            dtype=old_space.dtype,
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with its first observation."""
        ob = self.env.reset(**kwargs)
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        """Step the wrapped env (4-tuple API) and push the new frame onto the stack."""
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        """Return the current stack as one ndarray, frames joined along axis -1."""
        assert len(self.frames) == self.k
        # Concatenate the frames themselves. Passing a LazyFrames object to
        # np.concatenate would make NumPy iterate over the first axis of the
        # already-stacked array instead of over the frames.
        return np.concatenate(list(self.frames), axis=-1)

class FrameStackEnv(gym.Wrapper):
    """
    An environment that stacks frames.

    Observations are ndarrays: the selected frames are concatenated along
    the inner-most (last) axis, ordered from oldest to newest.

    At the beginning of an episode, the first observation
    is repeated in order to complete the stack.
    """

    def __init__(self, env, k=4, stride=1):
        """
        Create a frame stacking environment.

        Args:
          env: the environment to wrap. Its observation space is read via
            ``low``/``high``/``dtype``, so it is assumed to be a Box.
          k: the number of frames to stack. This includes the current
            observation. Must be at least 1.
          stride: the temporal stride. A value larger than one indicates
            that frames should be skipped. Must be at least 1.

        Raises:
          ValueError: if ``k`` or ``stride`` is less than 1.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        super().__init__(env)
        old_space = env.observation_space
        # Tile (rather than np.repeat) the per-frame bounds. This way element j
        # of the stacked observation gets the bound of the same element within
        # a single frame, matching the whole-frame concatenation in _cur_obs.
        self.observation_space = spaces.Box(
            np.concatenate([old_space.low] * k, axis=-1),
            np.concatenate([old_space.high] * k, axis=-1),
            dtype=old_space.dtype,
        )
        self._k = k
        self._stride = stride
        # Enough history to reach back (k - 1) strides from the newest frame.
        self._history_size = 1 + (k - 1) * stride
        # Bounded deque: appending a frame evicts the oldest in O(1).
        self._history = deque(maxlen=self._history_size)

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the history with its first observation."""
        obs = self.env.reset(**kwargs)
        self._history = deque([obs] * self._history_size, maxlen=self._history_size)
        return self._cur_obs()

    def step(self, action):
        """Step the wrapped env (4-tuple API) and record the new observation."""
        obs, rew, done, info = self.env.step(action)
        self._history.append(obs)  # maxlen drops the oldest frame
        return self._cur_obs(), rew, done, info

    def _cur_obs(self):
        """Concatenate every ``stride``-th frame of the history along axis -1."""
        # History indices 0, stride, ..., (k - 1) * stride: exactly k frames,
        # oldest first and the newest frame last.
        frames = list(self._history)[::self._stride]
        return np.concatenate(frames, axis=-1)


# Maps each supported ``wrapper`` name to a callable that wraps a bare PongGym.
# Lambdas defer name lookup, so wrapper classes are resolved only when used.
_PONG_WRAPPERS = {
    'Unwrapped': lambda env: env,
    'Basic': lambda env: PongGym_Basic(env),
    'Defensive': lambda env: PongGym_Defensive(env),
    'ReturnToOpponent': lambda env: PongGym_ReturnToOpponent(env),
    'MultiStyleBinary': lambda env: PongGym_BinaryStyles(env),
    'StackedMultiStyleBinary': lambda env: FrameStackEnv(PongGym_BinaryStyles(env), 4),
    'MultiStyleLinear': lambda env: PongGym_LinearStyles(env),
    'SetStyle0': lambda env: PongGym_SetStyle(env, 0),
    'SetStyle1': lambda env: PongGym_SetStyle(env, 1),
    'SetStyle0.5': lambda env: PongGym_SetStyle(env, 0.5),
    'StackedSetStyle0': lambda env: FrameStackEnv(PongGym_SetStyle(env, 0), 4),
    'StackedSetStyle1': lambda env: FrameStackEnv(PongGym_SetStyle(env, 1), 4),
    'StackedSetStyle0.5': lambda env: FrameStackEnv(PongGym_SetStyle(env, 0.5), 4),
}


def wrappedPong(wrapper='Unwrapped', max_score=1.0, render=False, p2_speed=0.4):
    """Build a PongGym environment and apply the named wrapper.

    Args:
        wrapper: key selecting the wrapper; ``'Unwrapped'`` returns the bare
            PongGym. ``Stacked*`` variants additionally wrap the result in a
            4-frame ``FrameStackEnv``.
        max_score: forwarded to PongGym as ``max_score``.
        render: forwarded to PongGym as ``display_screen``.
        p2_speed: forwarded to PongGym as ``player2_speed_ratio``.

    Returns:
        The (possibly wrapped) environment.

    Raises:
        ValueError: if ``wrapper`` is not a known name. This is checked before
            PongGym is constructed, so no display is created on a bad name.
    """
    try:
        wrap = _PONG_WRAPPERS[wrapper]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which can never be valid names.
        raise ValueError(
            f"Unknown Pong wrapper {wrapper!r}; expected one of "
            f"{sorted(_PONG_WRAPPERS)}. To use no wrapper, designate 'Unwrapped'."
        ) from None
    env = PongGym(display_screen=render, max_score=max_score, player2_speed_ratio=p2_speed)
    return wrap(env)

def CatcherReg(render=False):
    """Return a CatcherGym with a single life; ``render`` is passed as ``display_screen``."""
    return CatcherGym(display_screen=render, lives=1)

def PixelcopterReg(render=False):
    """Return a PixelcopterGym with a single life; ``render`` is passed as ``display_screen``."""
    return PixelcopterGym(display_screen=render, lives=1)

#### Gym registrations for the PLE Pong variants.
# Every id routes to wrappedPong with a 2000-step episode limit. Rows are
# (gym id, wrapper name, render flag) and are registered in table order.
_PONG_ENTRY_POINT = 'pleMod_gym:wrappedPong'
_PONG_MAX_EPISODE_STEPS = 2000

_PONG_REGISTRATIONS = (
    ('PLEPongRotational-norender-v0', 'Unwrapped', False),
    ('PLEPongRotational-render-v0', 'Unwrapped', True),
    ('PLEPongBasic-norender-v0', 'Basic', False),
    ('PLEPongBasic-render-v0', 'Basic', True),
    ('PLEPongRotationalReturn-norender-v0', 'ReturnToOpponent', False),
    ('PLEPongRotationalReturn-render-v0', 'ReturnToOpponent', True),
    ('PLEPongRotationalDefensive-norender-v0', 'Defensive', False),
    ('PLEPongRotationalDefensive-render-v0', 'Defensive', True),
    ('PLEPongRotationalBinary-norender-v0', 'MultiStyleBinary', False),
    ('PLEPongRotationalBinary-render-v0', 'MultiStyleBinary', True),
    ('PLEPongRotationalBinaryStacked-norender-v0', 'StackedMultiStyleBinary', False),
    ('PLEPongRotationalBinaryStacked-render-v0', 'StackedMultiStyleBinary', True),
    ('PLEPongRotationalLinear-norender-v0', 'MultiStyleLinear', False),
    ('PLEPongRotationalLinear-render-v0', 'MultiStyleLinear', True),
    # NOTE(review): the SetStyle ids below exist only as rendering variants.
    ('PLEPongRotationalSetStyle0-v0', 'SetStyle0', True),
    ('PLEPongRotationalSetStyle0Stacked-v0', 'StackedSetStyle0', True),
    ('PLEPongRotationalSetStyle1-v0', 'SetStyle1', True),
    ('PLEPongRotationalSetStyle1Stacked-v0', 'StackedSetStyle1', True),
    ('PLEPongRotationalSetStyle0_5-v0', 'SetStyle0.5', True),
    ('PLEPongRotationalSetStyle0_5Stacked-v0', 'StackedSetStyle0.5', True),
)

for _env_id, _wrapper_name, _render in _PONG_REGISTRATIONS:
    # Non-rendering ids omit the "render" key entirely (wrappedPong's default applies).
    _env_kwargs = {"wrapper": _wrapper_name}
    if _render:
        _env_kwargs["render"] = True
    register(
        id=_env_id,
        entry_point=_PONG_ENTRY_POINT,
        max_episode_steps=_PONG_MAX_EPISODE_STEPS,
        kwargs=_env_kwargs,
    )
del _env_id, _wrapper_name, _render, _env_kwargs
####