import gym
import numpy as np
from gym import spaces

from .lift import make_lift_env, set_lift_task_id
from .stack import make_stack_env, set_stack_task_id


def make_robot_manipulation_env(env_name,
                                robots,
                                env_kwargs,
                                image_observation,
                                z_offset=0.0):
    """Construct the underlying manipulation env for ``env_name``.

    Args:
        env_name: Task name; must be "Lift" or "Stack".
        robots: Robot specification forwarded to the task factory.
        env_kwargs: Extra keyword arguments forwarded to the task factory.
        image_observation: Flag forwarded to the task factory (presumably
            toggles image observations — confirm in the factory).
        z_offset: Table height offset; only forwarded to the "Lift" factory.

    Returns:
        Whatever the selected task factory returns.

    Raises:
        ValueError: If ``env_name`` is not a supported task.
    """
    if env_name not in ("Lift", "Stack"):
        raise ValueError(f"{env_name} is not valid env_name.")
    if env_name == "Stack":
        # NOTE(review): z_offset is not forwarded for Stack; confirm whether
        # make_stack_env supports a table offset.
        return make_stack_env(robots, env_kwargs, image_observation)
    return make_lift_env(robots, env_kwargs, image_observation, z_offset)


def set_task_id(env_name, env, task_id):
    """Apply ``task_id`` to ``env`` using the setter for ``env_name``.

    Args:
        env_name: Task name; must be "Lift" or "Stack".
        env: Object handed to the task-specific setter.
        task_id: Task identifier handed to the task-specific setter.

    Returns:
        Whatever the task-specific setter returns.

    Raises:
        ValueError: If ``env_name`` is not a supported task.
    """
    if env_name not in ("Lift", "Stack"):
        raise ValueError(f"{env_name} is not valid env_name.")
    setter = set_lift_task_id if env_name == "Lift" else set_stack_task_id
    return setter(env, task_id)


class RobotManipulationEnv(gym.Env):
    """Gym-style wrapper around a "Lift" or "Stack" manipulation task.

    Observations are either a flat concatenation of the observation-dict
    entries selected by ``keys``, or, when ``image_observation`` is True, a
    dict with an ``"image"`` entry (agent-view and side-view camera images
    concatenated along axis 2) and a ``"state"`` entry
    (``robot0_proprio-state``). ``step`` follows the 4-tuple
    ``(obs, reward, done, info)`` API.
    """

    # Render modes supported by ``render``.
    _RENDER_MODES = ("rgb_array",)

    def __init__(
        self,
        env_name,
        robots,
        keys=None,
        image_observation=False,
        env_kwargs=None,
        z_offset=0.0,
        **kwargs,
    ):
        """Build the wrapped env and derive observation/action spaces.

        Args:
            env_name: Task name; "Lift" or "Stack".
            robots: Robot specification forwarded to the env factory.
            keys: Observation-dict keys concatenated into the flat
                observation. If None, defaults to object state (if the env
                enables it), one image per camera (if camera observations
                are enabled), and each robot's proprio state.
            image_observation: If True, observations are dicts with
                "image" and "state" entries instead of flat arrays.
            env_kwargs: Extra keyword arguments forwarded to the env factory.
            z_offset: Offset added to the base table height of 0.8.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        self.env = make_robot_manipulation_env(
            env_name=env_name,
            robots=robots,
            image_observation=image_observation,
            env_kwargs=env_kwargs,
            z_offset=z_offset,
        )
        # Apply the table offset and reload the model so it takes effect.
        self.set_table_offset(z_offset)

        self.env_name = env_name
        self.image_observation = image_observation

        # Name for gym: robot model class names followed by the env class.
        robot_names = "".join(
            type(robot.robot_model).__name__ for robot in self.env.robots)
        self.name = f"{robot_names}_{type(self.env).__name__}"

        self.keys = self._default_keys() if keys is None else keys

        # Gym specific attributes. ``metadata`` must be a dict: gym wrappers
        # and env checkers call ``env.metadata.get(...)``. Both the newer and
        # older key spellings are provided for cross-version compatibility.
        self.env.spec = None
        self.metadata = {
            "render_modes": list(self._RENDER_MODES),
            "render.modes": list(self._RENDER_MODES),
        }

        # Derive observation/action spaces from an initial reset. Keep the
        # resulting observation so render()/get_obs() work before the first
        # explicit reset().
        obs_dict = self.env.reset()
        self.obs_dict = obs_dict
        self.modality_dims = {key: obs_dict[key].shape for key in self.keys}
        obs = self._get_obs_from_obs_dict(obs_dict)
        # In image mode the Box space only describes the "state" entry.
        self.obs_dim = obs["state"].shape if image_observation else obs.shape
        high = np.full(self.obs_dim, np.inf, dtype=np.float32)
        self.observation_space = spaces.Box(low=-high, high=high)
        action_low, action_high = self.env.action_spec
        self.action_space = spaces.Box(low=action_low, high=action_high)

        # Not read inside this class; presumably consumed by callers — confirm.
        self.position_noise_range = np.array([-0.001, 0.001])

    def _default_keys(self):
        """Return the default observation keys for the wrapped env."""
        keys = []
        if self.env.use_object_obs:
            keys.append("object-state")
        if self.env.use_camera_obs:
            keys.extend(
                f"{cam_name}_image" for cam_name in self.env.camera_names)
        # One proprio-state entry per robot.
        keys.extend(
            f"robot{idx}_proprio-state"
            for idx in range(len(self.env.robots)))
        return keys

    def step(self, action):
        """Advance the wrapped env by one step.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is formatted by
            ``_get_obs_from_obs_dict``.
        """
        obs_dict, reward, done, info = self.env.step(action)
        self.obs_dict = obs_dict
        obs = self._get_obs_from_obs_dict(obs_dict)
        return obs, reward, done, info

    def reset(self, seed=None, return_info=False, options=None):
        """Reset the wrapped env and return the initial observation.

        Args:
            seed: Accepted for gym API compatibility; not forwarded to the
                wrapped env.
            return_info: If True, return ``(obs, info)`` with an empty info
                dict, as gym's ``reset(return_info=True)`` contract expects.
            options: Accepted for gym API compatibility; ignored.
        """
        obs_dict = self.env.reset()
        self.obs_dict = obs_dict
        obs = self._get_obs_from_obs_dict(obs_dict)
        if return_info:
            return obs, {}
        return obs

    def _get_obs_from_obs_dict(self, obs_dict):
        """Convert a raw observation dict into this wrapper's observation.

        Returns a dict with "image" and "state" entries in image mode,
        otherwise the concatenation of ``obs_dict[key]`` for ``self.keys``.
        """
        if self.image_observation:
            return {
                "image": self.get_image_obs(obs_dict),
                "state": obs_dict["robot0_proprio-state"],
            }
        return np.concatenate([obs_dict[key] for key in self.keys])

    def get_image_obs(self, obs_dict):
        """Concatenate agent-view and side-view images along axis 2."""
        agentview_image = obs_dict["agentview_image"]
        sideview_image = obs_dict["sideview_image"]
        return np.concatenate([agentview_image, sideview_image], axis=2)

    def setup_task(self, goal_id: int, start_id: int):
        """Configure the task for ``goal_id``.

        ``start_id`` is accepted but currently unused. The wrapper itself
        (not ``self.env``) is passed to the task-specific setter.
        """
        return set_task_id(self.env_name, self, goal_id)

    @staticmethod
    def task_id_to_pos(task_id):
        """Map a task id to an (x, y, z) offset on a regular grid.

        ``task_id // 9`` selects the z level (spacing 0.1) and
        ``task_id % 9`` selects a cell of a 3x3 x/y grid (spacing 0.13).
        Index 1 on each axis maps to 0.
        """
        z_i, xy_i = divmod(task_id, 9)
        x_i, y_i = divmod(xy_i, 3)
        x = x_i * 0.13 - 0.13
        y = y_i * 0.13 - 0.13
        z = (z_i - 1) * 0.1
        return np.array([x, y, z])

    def set_table_offset(self, table_offset):
        """Set the table position to ``[0, 0, 0.8 + table_offset]``.

        Also reloads the model via the wrapped env's private ``_load_model``
        so the new offset is applied.
        """
        self.env.table_offset = np.array([0, 0, 0.8 + table_offset])
        self.env._load_model()

    def render(self, mode="rgb_array", **kwargs):
        """Return agent-view and side-view frames stacked along axis 0.

        Each frame is flipped along its first axis (presumably to undo
        bottom-up image rows from the renderer — confirm).

        Raises:
            ValueError: If ``mode`` is not "rgb_array", or if no observation
                is available.
        """
        if mode not in self._RENDER_MODES:
            raise ValueError(
                f"Unsupported render mode {mode!r}; "
                f"supported modes: {list(self._RENDER_MODES)}")
        if self.obs_dict is None:
            raise ValueError("No observation available; call reset() first.")
        agentview_image = self.obs_dict["agentview_image"][::-1]
        sideview_image = self.obs_dict["sideview_image"][::-1]
        return np.concatenate([agentview_image, sideview_image], axis=0)

    def get_obs(self):
        """Return the formatted observation for the latest raw observation.

        Raises:
            ValueError: If no observation is available.
        """
        if self.obs_dict is None:
            raise ValueError("No observation available; call reset() first.")
        return self._get_obs_from_obs_dict(self.obs_dict)

    def goal_id_to_task_id_list(self, goal_id: int):
        """Return the task ids for ``goal_id`` (a single-element list)."""
        return [goal_id]

    def get_success(self, *args, **kwargs):
        """Return the wrapped env's ``_check_success()``; arguments ignored."""
        return self.env._check_success()

    @property
    def sim(self):
        """The wrapped env's ``sim`` object."""
        return self.env.sim
