import gym
import numpy as np
from furniture_bench.perception.image_utils import resize, resize_crop
from ml_collections import ConfigDict

from .core import Env


class FurnitureBench(Env):
    """Wrapper around a FurnitureBench gym environment.

    Resizes the two camera images (``color_image1`` / ``color_image2``) returned by
    the wrapped environment, supports action repeat, and, for every
    ``config.record_every``-th episode, collects ``color_image2`` frames that are
    returned as a stacked array in ``info["vid"]`` on the terminal step.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default configuration, optionally overridden by ``updates``.

        Args:
            updates: Optional mapping / ``ConfigDict`` of overrides. References in it
                are resolved before being merged into the defaults.

        Returns:
            A ``ConfigDict`` holding the environment settings.
        """
        config = ConfigDict()

        config.action_repeat = 1  # env steps executed per call to ``step``
        config.randomness = "low"
        # Shape advertised by ``obs_space``. NOTE(review): resize/resize_crop are
        # called without a size argument — confirm they produce this size.
        config.image_size = (224, 224)
        config.max_env_steps = 600
        config.record_video = True
        config.record_every = 2  # record every N-th episode (by episode index)
        config.gpu_id = 0

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(
        self, name, furniture_name, update=None,
    ):
        """Create the wrapped gym environment.

        Args:
            name: Gym environment id passed to ``gym.make`` (e.g. ``"FurnitureSim-v0"``).
            furniture_name: Forwarded to the environment as ``furniture=``.
            update: Optional config overrides, see ``get_default_config``.
        """
        # Presumably imported for its side effect of registering the FurnitureBench
        # gym environment ids; the module name itself is not used here.
        import furniture_bench  # noqa: F401

        self._config = self.get_default_config(update)
        self._env = gym.make(
            name,
            furniture=furniture_name,
            headless=True,
            resize_img=False,
            np_step_out=True,
            concat_robot_state=True,
            randomness=self._config.randomness,
            max_env_steps=self._config.max_env_steps,
            compute_device_id=self._config.gpu_id,
            graphics_device_id=self._config.gpu_id,
        )

        # Last successfully processed observation; used as a fallback when the
        # wrapped env raises ValueError during ``step``.
        self._prev_obs = None
        self._max_episode_steps = self._env.furniture.max_env_steps

        self._episode_index = 0
        self._i = 0  # env steps taken since the last reset

        # Camera whose frames are collected into the episode video.
        self._record_image_key = "color_image2"
        self._recorded_images = []

        self._record_current_episode = self._should_record(self._episode_index)

    def _should_record(self, episode_index):
        """Return whether frames of episode ``episode_index`` should be recorded."""
        return self._config.record_video and episode_index % self._config.record_every == 0

    @staticmethod
    def _process_obs(obs):
        """Resize both camera images of a raw env observation in place and return it."""
        obs["color_image1"] = resize(obs["color_image1"].squeeze())
        obs["color_image2"] = resize_crop(obs["color_image2"].squeeze())
        return obs

    @property
    def obs_space(self):
        """Space of the two resized camera images, each ``(*config.image_size, 3)``."""
        image_shape = (*self._config.image_size, 3)
        return gym.spaces.Dict(
            {
                "color_image1": gym.spaces.Box(low=0, high=255, shape=image_shape),
                "color_image2": gym.spaces.Box(low=0, high=255, shape=image_shape),
            }
        )

    @property
    def observation_space(self):
        """Observation space of the wrapped environment, unchanged."""
        return self._env.observation_space

    @property
    def action_space(self):
        """Action space of the wrapped environment, unchanged."""
        return self._env.action_space

    @property
    def act_space(self):
        """Box in ``[-1, 1]`` of size ``pose_dim + 1`` as reported by the wrapped env."""
        return gym.spaces.Box(low=-1.0, high=1.0, shape=(self._env.pose_dim + 1,))

    def step(self, action):
        """Apply ``action`` for up to ``config.action_repeat`` env steps.

        Stepping stops early as soon as the wrapped env reports ``done``. If the
        wrapped env raises ``ValueError``, the episode is ended and the last
        processed observation is returned with zero reward.

        Args:
            action: Action forwarded to the wrapped env; must be all-finite.

        Returns:
            ``(obs, reward, done, info)``: ``obs`` has resized camera images,
            ``reward`` is summed over the repeated steps, and ``info`` is the
            wrapped env's info extended with ``"vid"`` (stacked recorded frames on
            the terminal step of a recorded episode, else ``None``) and
            ``"episode_len"`` (env steps since the last reset).

        Raises:
            ValueError: if the wrapped env raises it before any observation exists
                (i.e. ``step`` was called without a prior ``reset``).
        """
        assert np.isfinite(action).all(), action
        try:
            reward = 0.0
            for _ in range(self._config.action_repeat):
                self._i += 1
                obs, rew, done, info = self._env.step(action)
                reward += rew or 0.0
                if done:
                    # Never keep stepping an environment that has terminated.
                    break
        except ValueError:
            if self._prev_obs is None:
                raise
            done = True
            # Already processed; must not be resized/cropped a second time.
            obs = dict(self._prev_obs)
            reward = 0.0
            info = {"success": False, "action_success": False}
        else:
            obs = self._process_obs(obs)
            self._prev_obs = obs

        if self._record_current_episode:
            self.record(obs)
        vid = np.array(self._recorded_images) if done and self._record_current_episode else None

        info = dict(info)
        info["vid"] = vid
        info["episode_len"] = self._i
        return obs, reward, done, info

    def record(self, obs):
        """Append the recorded camera's image from ``obs`` to the episode frames."""
        self._recorded_images.append(obs[self._record_image_key])

    def reset(self):
        """Reset the wrapped env and start a new (possibly recorded) episode.

        Returns:
            The initial observation with resized camera images.
        """
        self._i = 0
        self._episode_index += 1
        self._record_current_episode = self._should_record(self._episode_index)

        obs = self._process_obs(self._env.reset())

        self._recorded_images.clear()
        # Keep the reset observation so a failing first ``step`` has a fallback.
        self._prev_obs = obs
        if self._record_current_episode:
            self.record(obs)
        return obs


if __name__ == "__main__":
    # Smoke test: drive the one-leg task with random actions until the episode
    # ends or 630 steps have been taken, printing the 1-based step index.
    env = FurnitureBench("FurnitureSim-v0", "one_leg", None)
    env.reset()
    for timestep in range(1, 631):
        _, _, done, _ = env.step(env.action_space.sample())
        print(timestep)
        if done:
            break
