from collections import deque
import numpy as np
import torch
import gym
from gym import spaces
import cv2
cv2.ocl.setUseOpenCL(False)


class RenderObsWrapper(gym.ObservationWrapper):
    """Replace the wrapped env's observations with its rendered RGB frames.

    The observation produced by the wrapped environment is discarded; every
    observation returned is whatever ``env.render(mode='rgb_array')`` yields.
    """

    def __init__(self, env):
        """
        :param env: environment to wrap; it is reset once here so that a frame
            can be rendered to discover the observation shape
        """
        gym.ObservationWrapper.__init__(self, env)
        # Render one frame up front purely to learn its shape.
        self.env.reset()
        first_frame = self.env.render(mode='rgb_array')
        # Bounds/dtype assume the renderer emits 8-bit images -- TODO confirm.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=first_frame.shape,
            dtype=np.uint8,
        )

    def observation(self, observation):
        """Ignore ``observation`` and return the current rendered frame."""
        rendered = self.env.render(mode='rgb_array')
        return rendered


class PreprocessObsWrapper(gym.ObservationWrapper):
    """Resize image observations, optionally convert to grayscale, and emit them channel-first.

    Incoming observations are expected to be RGB images laid out as
    (height, width, 3) -- presumably the output of a render-based wrapper;
    verify against the caller.
    """

    def __init__(self, env, img_w, img_h, grayscale):
        """
        :param env: environment to wrap
        :param img_w: target image width in pixels
        :param img_h: target image height in pixels
        :param grayscale: if true, convert frames to a single gray channel
        """
        gym.ObservationWrapper.__init__(self, env)
        self.img_w = img_w
        self.img_h = img_h
        self.grayscale = grayscale
        num_channels = 1 if grayscale else 3
        # cv2.resize takes dsize as (width, height) but returns an array with
        # `height` rows and `width` columns, so the channel-first observation
        # is (C, img_h, img_w). Declaring (C, img_w, img_h) would make the
        # space disagree with the actual observations whenever img_w != img_h.
        obs_shape = (num_channels, img_h, img_w)
        self.observation_space = spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        """Return the preprocessed frame with shape (C, img_h, img_w)."""
        if self.grayscale:
            observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        observation = cv2.resize(observation, (self.img_w, self.img_h), interpolation=cv2.INTER_AREA)
        if self.grayscale:
            # Grayscale frames are 2-D; add the leading channel axis.
            return observation[None, ...]
        # (H, W, C) -> (C, H, W)
        return observation.transpose(2, 0, 1)


# taken from `https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/atari_wrappers.py#L229`
class LazyFrames(object):
    """Lazily concatenated stack of frames.

    This object ensures that common frames between the observations are only stored once.
    It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
    buffers.
    This object should only be converted to numpy array before being passed to the model.

    Frames are concatenated along axis 0 the first time the data is needed;
    the result is cached and the reference to the frame list is dropped.
    """

    def __init__(self, frames):
        """
        :param frames: sequence of arrays to be concatenated along axis 0
        """
        self._frames = frames
        self._out = None

    def _force(self):
        """Materialise (once) and return the concatenated array."""
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=0)
            self._frames = None
        return self._out

    def __array__(self, dtype=None, copy=None):
        """NumPy array-protocol hook.

        :param dtype: optional dtype to convert to
        :param copy: NumPy 2 protocol flag -- ``True`` forces a copy,
            ``False`` forbids one, ``None`` copies only when required
        :raises ValueError: if ``copy=False`` but a dtype conversion is needed
        """
        out = self._force()
        if dtype is not None:
            if copy is False and np.dtype(dtype) != out.dtype:
                raise ValueError('unable to avoid copy while converting LazyFrames to the requested dtype')
            if copy is not False:
                # astype always returns a fresh array, as before.
                return out.astype(dtype)
            return out
        if copy:
            return out.copy()
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]

    def count(self):
        """Return the number of slices along the stacking axis (axis 0).

        This equals the number of stacked frames when every frame has a
        leading dimension of 1 (e.g. single-channel, channel-first images).
        """
        # Frames are concatenated on axis 0, so that is the axis to count;
        # the last axis is a spatial dimension here.
        return self._force().shape[0]

    def frame(self, i):
        """Return the ``i``-th slice along the stacking axis (axis 0)."""
        return self._force()[i]


# a modification of `https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/atari_wrappers.py#L188`
class FrameStack(gym.Wrapper):
    """Stack the ``num_stack`` most recent observations along axis 0.

    Observations are returned as ``LazyFrames`` so that frames shared between
    consecutive observations are stored only once.
    """

    def __init__(self, env, num_stack):
        """
        :param env: environment to wrap; its observation space must be 3-dimensional
        :param num_stack: number of most recent frames kept in each observation
        """
        gym.Wrapper.__init__(self, env)
        assert len(self.observation_space.shape) == 3
        self.num_stack = num_stack
        self.frames = deque([], maxlen=num_stack)
        # Observations are whole frames concatenated on axis 0, so the bounds
        # are built the same way. (np.repeat would interleave the channels of
        # a multi-channel frame and mis-align per-channel bounds.)
        low = np.concatenate([self.observation_space.low] * num_stack, axis=0)
        high = np.concatenate([self.observation_space.high] * num_stack, axis=0)
        self.observation_space = spaces.Box(low=low, high=high, dtype=self.observation_space.dtype)

    def reset(self, **kwargs):
        """Reset the env and fill the whole stack with the first observation.

        Keyword arguments are forwarded to the wrapped env's ``reset``, in
        line with the other wrappers in this module.
        """
        ob = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        """Step the env and push the new observation; the oldest frame drops out."""
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        """Snapshot the current stack as a ``LazyFrames`` object."""
        assert len(self.frames) == self.num_stack
        return LazyFrames(list(self.frames))


class DiscretizeActionsWrapper(gym.Wrapper):
    """Expose a ``Box`` action space as a ``Discrete`` one.

    At construction time ``num_actions`` actions are sampled from the original
    action space; discrete action ``k`` is then executed as the ``k``-th sample.
    """

    def __init__(self, env, num_actions, verbose=True):
        """
        :param env: environment to wrap; its action space must be a ``spaces.Box``
        :param num_actions: how many actions to sample from the original space
        :param verbose: if true, print the original space and the sampled actions
        """
        gym.Wrapper.__init__(self, env)
        assert isinstance(env.action_space, spaces.Box), \
            'expected Box action space, got {}'.format(type(env.action_space))
        continuous_space = self.action_space
        self.old_action_space = continuous_space
        self.action_space = spaces.Discrete(num_actions)
        self.num_actions = num_actions
        # Draw the fixed action table once; it is never re-sampled afterwards.
        sampled_actions = []
        for _ in range(num_actions):
            sampled_actions.append(continuous_space.sample())
        self.action_list = sampled_actions
        if verbose:
            print('[DiscretizeActionsWrapper]', self.old_action_space, '-->', self.action_list)

    def step(self, a):
        """Look up the sampled action for index ``a`` and step the wrapped env with it."""
        assert 0 <= a < self.num_actions
        return self.env.step(self.action_list[a])


class Buffer:
    """Fixed-capacity store that overwrites its entries cyclically once full.

    ``list`` holds the stored items and ``i`` counts how many overwrites have
    happened since the buffer filled up.
    """

    def __init__(self, max_size):
        """
        :param max_size: maximum number of items kept
        """
        self.i = 0
        self.list = []
        self.max_size = max_size

    def push(self, x):
        """Store ``x``, appending while there is room and overwriting afterwards."""
        has_room = len(self.list) < self.max_size
        if has_room:
            self.list.append(x)
            return
        # Full: reuse slots in order 0, 1, ..., max_size - 1, 0, ...
        slot = self.i % self.max_size
        self.list[slot] = x
        self.i += 1

    def tolist(self):
        """Return the underlying list itself (not a copy)."""
        return self.list


class ReplayBufferWrapper(gym.Wrapper):
    """Record ``[state, next_state]`` pairs in one ``Buffer`` per action.

    The dictionary of buffers (keyed by action) is exposed through
    ``info['buffers']`` on every step.
    """

    def __init__(self, env, buffer_size):
        """
        :param env: environment to wrap
        :param buffer_size: capacity of each per-action buffer
        """
        gym.Wrapper.__init__(self, env)
        self.buffer_size = buffer_size
        self.buffers = {}
        self.s = None

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the initial observation."""
        initial_obs = self.env.reset(**kwargs)
        self.s = initial_obs
        return initial_obs

    def step(self, a):
        """Step the env and store the transition under action ``a``.

        ``a`` is used as a dictionary key, so it must be hashable.
        """
        next_obs, reward, done, info = self.env.step(a)
        action_buffer = self.buffers.get(a)
        if action_buffer is None:
            # First time this action is seen: create its buffer lazily.
            action_buffer = self.buffers[a] = Buffer(self.buffer_size)
        action_buffer.push([self.s, next_obs])
        self.s = next_obs
        info['buffers'] = self.buffers
        return next_obs, reward, done, info

class ReplayBufferAutoencoderWrapper(gym.Wrapper):
    """Collect every observation returned by ``step`` in a single ``Buffer``.

    The buffer is exposed through ``info['buffers']`` on every step.
    Observations returned by ``reset`` are not stored in the buffer.
    """

    def __init__(self, env, buffer_size):
        """
        :param env: environment to wrap
        :param buffer_size: capacity of the observation buffer
        """
        gym.Wrapper.__init__(self, env)
        self.buffer = Buffer(buffer_size)
        self.s = None

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the initial observation in ``self.s``."""
        initial_obs = self.env.reset(**kwargs)
        self.s = initial_obs
        return initial_obs

    def step(self, a):
        """Step the env, push the new observation and attach the buffer to ``info``."""
        next_obs, reward, done, info = self.env.step(a)
        self.buffer.push(next_obs)
        info['buffers'] = self.buffer
        return next_obs, reward, done, info

class TransitionBufferWrapper(gym.Wrapper):
    """Record ``[state, action, next_state]`` triples in a single ``Buffer``.

    The buffer is exposed through ``info['buffers']`` on every step.
    """

    def __init__(self, env, buffer_size):
        """
        :param env: environment to wrap
        :param buffer_size: capacity of the transition buffer
        """
        gym.Wrapper.__init__(self, env)
        self.buffer = Buffer(buffer_size)
        self.s = None

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the initial observation."""
        initial_obs = self.env.reset(**kwargs)
        self.s = initial_obs
        return initial_obs

    def step(self, a):
        """Step the env and store the transition that was just observed."""
        next_obs, reward, done, info = self.env.step(a)
        transition = [self.s, a, next_obs]
        self.buffer.push(transition)
        # The new observation becomes the "previous state" of the next transition.
        self.s = next_obs
        info['buffers'] = self.buffer
        return next_obs, reward, done, info


class EncodeObsWrapper(gym.ObservationWrapper):
    """Pass observations through an encoder network and return its output.

    Observations are scaled by 1/255 before encoding, so they are presumably
    8-bit images -- verify against the caller.
    """

    def __init__(self, env, enc, code_size):
        """
        :param env: environment to wrap
        :param enc: encoder module; the device of its first parameter is used for inference
        :param code_size: size of the code declared in the observation space
        """
        gym.ObservationWrapper.__init__(self, env)
        self.enc = enc
        self.code_size = code_size
        encoder_params = list(enc.parameters())
        self.device = encoder_params[0].device
        # TODO: can this be improved? 100 is arbitrary...
        bound = 100 * np.ones(code_size)
        self.observation_space = spaces.Box(low=-bound, high=+bound, dtype=np.float32)

    def observation(self, observation):
        """Encode a single observation.

        A leading batch axis is added before the encoder call and is not
        removed from the returned numpy array.
        """
        self.enc.eval()
        with torch.no_grad():
            batch = torch.Tensor(observation).to(self.device)[None, ...]
            scaled = batch.float() / 255.0
            code = self.enc(scaled)
            return code.cpu().numpy()


# taken from `https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/running_mean_std.py`
class RunningMeanStd:
    """Running per-element mean and variance of a data stream.

    Batches are merged with the parallel variance algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    """

    def __init__(self, shape, epsilon=1e-4):
        """
        :param shape: the shape of the data stream's output
        :param epsilon: initial value of ``count``; helps with arithmetic issues
        """
        self.count = epsilon
        self.var = np.ones(shape, np.float64)
        self.mean = np.zeros(shape, np.float64)

    def update(self, arr):
        """Merge a batch of samples; axis 0 of ``arr`` is the sample axis."""
        self.update_from_moments(
            np.mean(arr, axis=0),
            np.var(arr, axis=0),
            arr.shape[0],
        )

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge a batch that is described by its mean, variance and sample count."""
        prev_count = self.count
        tot_count = prev_count + batch_count
        delta = batch_mean - self.mean

        # Combined sum of squared deviations of the old data and the batch,
        # plus the correction term for the gap between the two means.
        m_2 = (
            self.var * prev_count
            + batch_var * batch_count
            + np.square(delta) * prev_count * batch_count / tot_count
        )

        self.mean = self.mean + delta * batch_count / tot_count
        self.var = m_2 / tot_count
        self.count = tot_count


class NormalizeObsWrapper(gym.ObservationWrapper):
    """Standardise observations with running mean/variance statistics.

    The statistics are updated with every observation that passes through,
    before that observation is normalised.
    """

    def __init__(self, env):
        """
        :param env: environment to wrap; the statistics take the shape of its
            observation space bounds
        """
        gym.ObservationWrapper.__init__(self, env)
        stats_shape = self.observation_space.low.shape
        self.run_avg = RunningMeanStd(shape=stats_shape)

    def observation(self, observation):
        """Update the running statistics and return the normalised observation.

        NOTE(review): the statistics update reduces over axis 0, so the
        observation presumably carries a leading batch axis -- confirm
        against the wrapped env.
        """
        stats = self.run_avg
        stats.update(observation)
        centered = observation - stats.mean
        return centered / np.sqrt(stats.var)