"""
Environment wrapper for Robomimic environments with image observations.

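Episodes are not terminated early on success: `done` is passed through from the underlying env and is additionally set to True (with an automatic reset) once `args.max_episode_length` steps have elapsed.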

Modified from https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/env/robomimic/robomimic_image_wrapper.py

"""

from debug import debug_print
import numpy as np
import gym
from gym import spaces
import imageio


class RobomimicImageWrapper(gym.Env):
    """Single-agent, list-wrapped gym interface around a robomimic env with images.

    Every observation/reward/done/info is returned wrapped in a one-element list.
    ``observation`` is the flattened RGB pixels (scaled by 255) followed by the
    low-dim keys except ``"object"``; ``state`` concatenates all low-dim keys.
    When ``normalization_path`` is given, observations are scaled towards
    [-1, 1] and incoming actions (assumed in [-1, 1]) are unnormalized.
    """

    _DEFAULT_LOW_DIM_KEYS = (
        "robot0_eef_pos",
        "robot0_eef_quat",
        "robot0_gripper_qpos",
    )
    _DEFAULT_IMAGE_KEYS = (
        "agentview_image",
        "robot0_eye_in_hand_image",
    )

    def __init__(
        self,
        env,
        args,
        normalization_path=None,
        low_dim_keys=None,
        image_keys=None,
        clamp_obs=False,
        init_state=None,
        render_hw=(256, 256),
        render_camera_name="agentview",
    ):
        """
        Args:
            env: robomimic env (must provide action_dimension, get_observation,
                reset, reset_to, step, render).
            args: config object; ``args.render`` and ``args.max_episode_length``
                are read.
            normalization_path: path to an ``.../normalization.npz`` file. Stats
                for ``state`` come from this file; stats for the low-dim part of
                ``observation`` and for actions come from the sibling
                ``...-img/normalization.npz`` file.
            low_dim_keys: low-dim observation keys (defaults to eef pos/quat +
                gripper qpos).
            image_keys: image observation keys (defaults to agentview + wrist).
            clamp_obs: clip normalized observations to [-1, 1].
            init_state: if given, every reset restores this simulator state.
            render_hw: (height, width) used by ``render``.
            render_camera_name: camera used by ``render``.
        """
        if low_dim_keys is None:
            low_dim_keys = self._DEFAULT_LOW_DIM_KEYS
        if image_keys is None:
            image_keys = self._DEFAULT_IMAGE_KEYS

        self.env = env
        self.init_state = init_state
        self.cnt_episode = 0
        self.has_reset_before = False
        self.render_hw = render_hw
        self.render_camera_name = render_camera_name
        self.video_writer = None
        self.clamp_obs = clamp_obs
        self.args = args
        # Seed passed to ``seed()``; kept separate so the method is not shadowed.
        self._seed = None
        self._ok = False  # whether a positive reward was seen this episode

        # set up normalization
        self.normalize = normalization_path is not None
        if self.normalize:
            normalization = np.load(normalization_path)
            normalization_low_dim = np.load(
                normalization_path.replace("/normalization.npz", "-img/normalization.npz")
            )
            self.obs_min = normalization["obs_min"]
            self.obs_max = normalization["obs_max"]
            self.low_dim_min = normalization_low_dim["obs_min"]
            self.low_dim_max = normalization_low_dim["obs_max"]
            self.action_min = normalization_low_dim["action_min"]
            self.action_max = normalization_low_dim["action_max"]

        # Continuous actions in [-1, 1]; use a float dtype (an int dtype would make
        # the Box integer-valued). NOTE(review): the double list nesting is kept
        # as-is; presumably required by the consuming framework — confirm.
        low = np.full(env.action_dimension, fill_value=-1.0, dtype=np.float32)
        high = np.full(env.action_dimension, fill_value=1.0, dtype=np.float32)
        self.action_space = [[gym.spaces.Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=np.float32,
        )]]
        self.low_dim_keys = list(low_dim_keys)
        self.image_keys = list(image_keys)
        self.obs_keys = self.low_dim_keys + self.image_keys
        self._step = 0
        self._finish_time = 0

        # Derive space shapes from a real observation.
        obs_example_full = self.get_observation(self.env.get_observation())
        # NOTE(review): [0, 256] bounds cover the pixel part but not normalized
        # low-dim entries, which can be negative — confirm consumers ignore bounds.
        low = np.full_like(obs_example_full["observation"], fill_value=0)
        high = np.full_like(obs_example_full["observation"], fill_value=256)
        self.observation_space = [spaces.Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=np.float32,
        )]
        low = np.full_like(obs_example_full["state"], fill_value=-1)
        high = np.full_like(obs_example_full["state"], fill_value=1)
        self.share_observation_space = [spaces.Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=np.float32,
        )]

    def _to_unit_range(self, obs, lo, hi):
        """Affinely map ``obs`` from [lo, hi] to [-1, 1], optionally clipping."""
        obs = 2 * ((obs - lo) / (hi - lo + 1e-6) - 0.5)
        if self.clamp_obs:
            obs = np.clip(obs, -1, 1)
        return obs

    def normalize_obs(self, obs):
        """Normalize the full low-dim ``state`` vector to [-1, 1]."""
        return self._to_unit_range(obs, self.obs_min, self.obs_max)

    def normalize_low_dim_obs(self, obs):
        """Normalize the low-dim part of ``observation`` to [-1, 1]."""
        return self._to_unit_range(obs, self.low_dim_min, self.low_dim_max)

    def unnormalize_action(self, action):
        """Map an action from [-1, 1] back to [action_min, action_max]."""
        action = (action + 1) / 2  # [-1, 1] -> [0, 1]
        return action * (self.action_max - self.action_min) + self.action_min

    def get_observation(self, raw_obs):
        """Build the ``observation``/``state``/``rgb`` dict from a raw obs dict.

        Images are concatenated along axis 0 (assumed channel-first, C H W),
        scaled by 255 and flattened to float32. ``state`` holds every low-dim
        key; ``observation`` is the flattened pixels followed by the low-dim
        keys except ``"object"``.
        """
        obs = {"rgb": None, "state": None, "observation": None}
        for key in self.obs_keys:
            value = raw_obs[key]
            if key in self.image_keys:
                if obs["rgb"] is None:
                    obs["rgb"] = value
                else:
                    obs["rgb"] = np.concatenate([obs["rgb"], value], axis=0)  # C H W
            else:
                if obs["state"] is None:
                    obs["state"] = value
                else:
                    obs["state"] = np.concatenate([obs["state"], value], axis=-1)
                if key != "object":
                    if obs["observation"] is None:
                        obs["observation"] = value
                    else:
                        obs["observation"] = np.concatenate(
                            [obs["observation"], value], axis=-1
                        )
        if self.normalize:
            obs["state"] = self.normalize_obs(obs["state"])
            obs["observation"] = self.normalize_low_dim_obs(obs["observation"])
        # [0, 1] -> [0, 255]. Not in place: with one camera obs["rgb"] aliases
        # raw_obs[key], and an in-place multiply would corrupt the caller's data.
        obs["rgb"] = (obs["rgb"] * 255).astype(np.float32).reshape(-1)
        obs["observation"] = np.concatenate([obs["rgb"], obs["observation"]], axis=-1)
        return obs

    def seed(self, seed=None):
        """Seed numpy's global RNG (entropy-seeded when ``seed`` is None).

        The value is stored in ``self._seed`` (used for video file names);
        assigning to ``self.seed`` would shadow this method.
        """
        self._seed = seed
        if seed is not None:
            np.random.seed(seed=seed)
        else:
            np.random.seed()

    def _close_video(self):
        """Finalize and drop the current video writer, if any."""
        if self.video_writer is not None:
            self.video_writer.close()
            self.video_writer = None

    def reset(self, options=None, ok=True, **kwargs):
        """Reset the env and return ``([observation], [state], zeros(1))``.

        Extra keyword arguments are ignored. ``options`` may contain ``"seed"``
        (reseed before a random reset) and ``"video_path"`` (start recording).
        When ``ok`` is False no video is started (used for internal auto-resets).
        """
        # Copy so neither a shared default nor the caller's dict is mutated.
        options = {} if options is None else dict(options)
        self.cnt_episode += 1
        self._step = 0
        self._ok = False
        self._close_video()

        # Start video if specified
        if self.args.render and ok:
            seed_tag = "" if self._seed is None else self._seed // 1000
            options['video_path'] = f'video{seed_tag}_{self.cnt_episode}.mp4'
        if ok and "video_path" in options:
            self.video_writer = imageio.get_writer(options["video_path"], fps=30)

        new_seed = options.get("seed", None)  # used to set all environments to specified seeds
        if self.init_state is not None:
            if not self.has_reset_before:
                # the env must be fully reset at least once to ensure correct rendering
                self.env.reset()
                self.has_reset_before = True
            # always reset to the same state to be compatible with gym
            raw_obs = self.env.reset_to({"states": self.init_state})
        elif new_seed is not None:
            self.seed(seed=new_seed)
            raw_obs = self.env.reset()
        else:
            # random reset
            raw_obs = self.env.reset()
        obs = self.get_observation(raw_obs)
        return [obs["observation"]], [obs["state"]], np.zeros(1)

    def step(self, action):
        """Step the env with ``action[0]`` and return list-wrapped results.

        Returns ``([observation], [state], [reward array], [done], [info],
        [zeros(1)])``. After ``args.max_episode_length`` steps the env is reset
        automatically, ``done`` is forced to True and the returned observation
        is the first one of the new episode. When ``done``, ``info`` gains
        ``episode_length`` (step of first positive reward, else episode length),
        ``episode_return`` (1 on success, else 0) and ``episode_reward``.
        """
        self._step += 1
        action = action[0]
        if self.normalize:
            action = self.unnormalize_action(action)
        raw_obs, reward, done, info = self.env.step(action)
        obs = self.get_observation(raw_obs)

        # The first positive reward marks success; remember when it happened.
        if not self._ok and reward > 0:
            self._ok = True
            print('reward epi:', self._step)
            self._finish_time = self._step
        succeeded = self._ok

        # Record this step's frame before a possible auto-reset closes the writer.
        if self.video_writer is not None:
            self.video_writer.append_data(self.render(mode="rgb_array"))

        if self._step >= self.args.max_episode_length:
            if not self._ok:
                self._finish_time = self._step
            obs_img, obs_state, _ = self.reset(ok=False)
            obs = {"observation": obs_img[0], "state": obs_state[0]}
            done = True

        if done:
            info.update({"episode_length": self._finish_time,
                         "episode_return": 1 if succeeded else 0,
                         "episode_reward": reward})

        return [obs["observation"]], [obs["state"]], [np.array([reward])], [done], [info], [np.zeros(1)]

    def render(self, mode="rgb_array"):
        """Render the configured camera at ``render_hw`` resolution."""
        h, w = self.render_hw
        return self.env.render(
            mode=mode,
            height=h,
            width=w,
            camera_name=self.render_camera_name,
        )

    def close(self):
        """Finalize any open video file so the last episode is written out."""
        self._close_video()


if __name__ == "__main__":
    # Smoke test: build the wrapper from a DPPO image config, reset once and
    # save a rendered frame to test.png.
    import os
    import json
    from types import SimpleNamespace
    from omegaconf import OmegaConf

    os.environ["MUJOCO_GL"] = "egl"

    cfg = OmegaConf.load("cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml")

    import robomimic.utils.env_utils as EnvUtils
    import robomimic.utils.obs_utils as ObsUtils
    import matplotlib.pyplot as plt

    wrappers = cfg.env.wrappers
    use_images = "robomimic_image" in wrappers
    low_dim_keys = (
        wrappers.robomimic_image.low_dim_keys
        if use_images
        else wrappers.robomimic_lowdim.low_dim_keys
    )
    obs_modality_dict = {"low_dim": low_dim_keys}
    if use_images:
        obs_modality_dict["rgb"] = wrappers.robomimic_image.image_keys
    ObsUtils.initialize_obs_modality_mapping_from_dict(obs_modality_dict)

    with open(cfg.robomimic_env_cfg_path, "r", encoding="utf-8") as f:
        env_meta = json.load(f)
    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,
        use_image_obs=True,
    )
    env.env.hard_reset = False

    # Minimal stand-in for the training args: the wrapper reads only
    # ``render`` and ``max_episode_length``.
    args = SimpleNamespace(
        render=False,
        max_episode_length=cfg.env.get("max_episode_steps", 300),
    )
    wrapper = RobomimicImageWrapper(
        env=env,
        args=args,
        low_dim_keys=list(low_dim_keys),
        # Use the obs keys registered above (e.g. "agentview_image"), which are
        # the keys get_observation reads; camera names alone would not match.
        image_keys=list(wrappers.robomimic_image.image_keys),
    )
    wrapper.seed(0)
    obs, share_obs, _ = wrapper.reset()
    print("observation:", obs[0].shape, "state:", share_obs[0].shape)
    img = wrapper.render()
    wrapper.close()
    plt.imshow(img)
    plt.savefig("test.png")
