from collections import deque

import mujoco_py
import numpy as np
import torch
from PIL import Image
from gym import Env
from gym.spaces import Box

from rlkit.envs.proxy_env import ProxyEnv


class ImageMujocoEnv(ProxyEnv, Env):
    """Environment wrapper whose observations are rendered camera images.

    On every ``reset``/``step`` the wrapped environment's simulator is
    rendered at ``imsize x imsize``; the most recent ``keep_prev + 1`` frames
    are stacked and returned as one flattened 1-D array.
    """

    def __init__(self,
                 wrapped_env,
                 imsize=32,
                 keep_prev=0,
                 init_camera=None,
                 camera_name=None,
                 transpose=False,
                 grayscale=False,
                 normalize=False,
                 ):
        """Set up rendering parameters and the frame history.

        Args:
            wrapped_env: environment to wrap; it must expose a ``sim``
                attribute providing ``render`` (and ``add_render_context``
                when ``init_camera`` is given).
            imsize: width and height, in pixels, of each rendered frame.
            keep_prev: number of previous frames kept in addition to the
                current one.
            init_camera: optional callable invoked with the camera of a newly
                created offscreen render context, to configure the viewpoint.
            camera_name: camera passed to ``sim.render``; ``None`` means the
                default camera.
            transpose: if true, reverse the axis order of each frame.
            grayscale: if true, convert each frame to single-channel ('L').
            normalize: if true, divide pixel values by 255.0.
        """
        super().__init__(wrapped_env)

        self.imsize = imsize
        # Number of scalar values in a single flattened frame.
        channels = 1 if grayscale else 3
        self.image_length = channels * self.imsize * self.imsize
        # Spatial (height, width) size of one frame; carries no channel axis.
        self.image_shape = (self.imsize, self.imsize)
        # Bounded queue holding the most recent frames, oldest first.
        self.history_length = keep_prev + 1
        self.history = deque(maxlen=self.history_length)
        # Optionally register an offscreen render context with a custom camera.
        if init_camera is not None:
            sim = self._wrapped_env.sim
            viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1)
            init_camera(viewer.cam)
            sim.add_render_context(viewer)
        self.camera_name = camera_name  # None means default camera
        self.transpose = transpose
        self.grayscale = grayscale
        self.normalize = normalize
        self._render_local = False

        # NOTE(review): bounds are [0, 1] even when normalize=False, in which
        # case rendered pixel values are presumably in [0, 255] -- confirm
        # with callers before changing.
        self.observation_space = Box(low=0.0,
                                     high=1.0,
                                     shape=(
                                         self.image_length * self.history_length,))

    def step(self, action):
        """Step the wrapped env and return ``(obs, reward, done, info)``.

        ``obs`` is the flattened frame history (see ``_get_obs``), not the
        wrapped environment's own state.
        """
        true_state, reward, done, info = super().step(action)
        return self._record_frame(true_state), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env, clear the history and return the first obs."""
        true_state = super().reset(**kwargs)
        self.history = deque(maxlen=self.history_length)
        return self._record_frame(true_state)

    def _record_frame(self, true_state):
        """Render a frame, push it onto the history and build the observation."""
        self.history.append(self._image_observation())
        return self._get_obs(self._get_history().flatten(), true_state)

    def get_image(self):
        """Render and return the current frame only.

        TODO: this should probably consider history.
        """
        return self._image_observation()

    def _get_obs(self, history_flat, true_state):
        """Return the final observation; here just the flattened history.

        Hook for subclasses that add information from ``true_state`` to the
        image observation (used in ImageWithObsEnv).
        """
        return history_flat

    def _frame_shape(self):
        """Shape of one processed frame, derived from the configuration."""
        if self.grayscale:
            shape = (self.imsize, self.imsize)
        else:
            shape = (self.imsize, self.imsize, 3)
        # ndarray.transpose() with no arguments reverses the axis order.
        return shape[::-1] if self.transpose else shape

    def _image_observation(self):
        """Render one frame and apply grayscale/normalize/transpose options."""
        image_obs = self._wrapped_env.sim.render(width=self.imsize,
                                                 height=self.imsize,
                                                 camera_name=self.camera_name)
        if self._render_local:
            # Imported lazily: OpenCV is only needed for the on-screen preview
            # and was never imported at module level.
            import cv2
            cv2.imshow('env', image_obs)
            cv2.waitKey(1)
        if self.grayscale:
            image_obs = np.array(Image.fromarray(image_obs).convert('L'))
        if self.normalize:
            image_obs = image_obs / 255.0
        if self.transpose:
            image_obs = image_obs.transpose()
        return image_obs

    def _get_history(self):
        """Return the stored frames stacked along a new leading axis.

        When fewer than ``history_length`` frames are stored, all-zero frames
        are appended after the real ones so the result always contains
        ``history_length`` frames.
        """
        frames = list(self.history)
        missing = self.history_length - len(frames)
        if missing > 0:
            # Padding must match the real frames' shape (which includes a
            # channel axis for colour images and is reversed when transposed);
            # ``self.image_shape`` alone is only correct for grayscale frames.
            pad_shape = frames[0].shape if frames else self._frame_shape()
            # np.zeros defaults to float64, same as the previous padding.
            frames.extend(np.zeros(pad_shape) for _ in range(missing))
        return np.stack(frames)

    def retrieve_images(self):
        """Return the frames in the history converted via ``torch_to_pil``.

        NOTE(review): ``torch_to_pil`` is not defined in this class;
        presumably it is provided elsewhere (base class or wrapped env) --
        TODO confirm.
        """
        return [
            self.torch_to_pil(torch.Tensor(image_obs))
            for image_obs in self.history
        ]

    def split_obs(self, obs):
        """Split a batch of observations into image and non-image parts.

        Args:
            obs: torch tensor viewable as ``(-1, observation_size)``.

        Returns:
            ``(image_obs, fc_obs)`` where ``image_obs`` holds the first
            ``image_length * history_length`` columns and ``fc_obs`` the
            remaining ones, or ``None`` when the observation is image-only.
        """
        imlength = self.image_length * self.history_length
        obs_length = self.observation_space.low.size
        obs = obs.view(-1, obs_length)
        # Positional (dim, start, length): the keyword for the first argument
        # was renamed from ``dimension`` to ``dim`` across PyTorch versions.
        image_obs = obs.narrow(1, 0, imlength)
        if obs_length == imlength:
            return image_obs, None

        fc_obs = obs.narrow(1, imlength, obs.shape[1] - imlength)
        return image_obs, fc_obs

    def enable_render(self):
        """Show every rendered frame in a local OpenCV window from now on."""
        self._render_local = True


