import copy
import itertools
import math
import os
import random
import time
from pathlib import Path

import gym
import metaworld
import mujoco_py
import numpy as np
from gym.envs.registration import register
from lxml import etree
from scipy.spatial.transform import Rotation as R


class MetaWorldEnv(gym.Env):
    """Gym wrapper around a single Meta-World ML1 training task.

    The wrapper builds the task's environment class, applies a few
    task-specific tweaks to the MuJoCo model (reset space and goal-site
    appearance), caps episodes at ``max_episode_steps`` and returns a
    reduced observation (see :meth:`get_obs`).

    Args:
        task: Meta-World task name; passed to ``metaworld.ML1`` and used as
            the key into ``train_classes``.
        image_observation: When True, :meth:`get_obs` returns a dict with
            ``"state"`` and ``"image"`` entries instead of a bare state array.
        domain_id: ``0`` renders from the ``"corner3"`` camera, any other
            value renders from ``"corner"``.
        **kwargs: Accepted for signature compatibility and ignored.
    """

    def __init__(
        self,
        task: str,
        image_observation: bool = False,
        domain_id: int = 0,
        **kwargs,
    ):
        super().__init__()
        self.task_id = task
        ml1 = metaworld.ML1(task)  # Construct the benchmark, sampling tasks
        env = ml1.train_classes[task]()

        # Use a separate name for the sampled task object so the ``task``
        # argument (a string) is not silently rebound to a different type.
        task_spec = ml1.train_tasks[0]
        env.set_task(task_spec)
        env.random_init = True
        env._freeze_rand_vec = False
        env._partially_observable = False

        env_name = task_spec.env_name
        if "window" in env_name:
            reset_bounds = None
            if "window-open" in env_name:
                reset_bounds = ((-0.1, 0.7, 0.16), (0.4, 0.8, 0.6))
            elif "window-close" in env_name:
                reset_bounds = ((-0.3, 0.7, 0.16), (0.2, 0.8, 0.6))
            # Only override the reset space when bounds are known; any other
            # "window" task keeps the space the environment came with.
            if reset_bounds is not None:
                obj_low, obj_high = reset_bounds
                env._random_reset_space = gym.spaces.Box(
                    np.array(obj_low).astype(np.float32),
                    np.array(obj_high).astype(np.float32),
                )
            # Alpha 0 makes the goal site invisible in renders.
            _id = env.sim.model.site_name2id('goal')
            env.sim.model.site_rgba[_id] = [0, 0.8, 0, 0]
        if "reach" in env_name:
            # Enlarged, opaque green goal marker; random init is disabled.
            _id = env.sim.model.site_name2id('goal')
            env.sim.model.site_size[_id] = [0.05, 0.05, 0.05]
            env.sim.model.site_rgba[_id] = [0, 0.8, 0, 1]
            env.random_init = False

        env.reset()
        self.env = env

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.prev_ret = None  # last successful (obs, rew, done, info)
        self.max_episode_steps = 200
        self.t = 0  # steps taken in the current episode
        self.info = {"success": 0}
        self.image_observation = image_observation
        self.domain_id = domain_id

    def step(self, action):
        """Advance the wrapped environment by one step.

        If the wrapped environment raises ``ValueError``, the reward and info
        of the last successful step are reused and the episode is terminated.
        When there is no previous successful step to fall back on, the
        ``ValueError`` is re-raised.

        Returns:
            ``(obs, rew, done, info)`` where ``obs`` comes from
            :meth:`get_obs` and ``done`` is also forced to True once
            ``max_episode_steps`` steps have been taken.
        """
        try:
            obs, rew, done, info = self.env.step(action)
            self.prev_ret = (obs, rew, done, info)
        except ValueError:
            if self.prev_ret is None:
                # Nothing to fall back on: surface the real error instead of
                # failing on unpacking ``None``.
                raise
            _, rew, _, info = self.prev_ret
            done = True
        self.t += 1
        if self.t >= self.max_episode_steps:
            done = True
        self.info = info
        obs = self.get_obs()
        return obs, rew, done, info

    def reset(self):
        """Reset counters, reposition the ``corner3`` camera and reset the env.

        Returns:
            The observation produced by :meth:`get_obs` after the reset.
        """
        self.t = 0
        self.info = {"success": 0}

        camera_id = self.sim.model.camera_name2id("corner3")

        # NOTE(review): SciPy's ``as_quat`` is documented as scalar-last
        # (x, y, z, w) by default, whereas MuJoCo quaternions are usually
        # scalar-first -- confirm the resulting camera pose is the intended
        # one. Left unchanged because rendered observations depend on it.
        euler = [180, 0, -45]
        quat = R.from_euler("xyz", euler, degrees=True).as_quat()

        self.sim.model.cam_pos[camera_id] = np.array([0.0, 1.5, 0.9])
        self.sim.model.cam_quat[camera_id] = np.array(quat)
        self.sim.model.cam_fovy[camera_id] = 60
        self.env.reset()
        self.env.reset_model()
        return self.get_obs()

    def render(self, mode="rgb_array", **kwargs):
        """Render the scene.

        Args:
            mode: ``"rgb_array"`` renders offscreen from the camera selected
                by ``domain_id``; any other value forwards ``**kwargs`` to the
                wrapped environment's ``render``.
            **kwargs: For ``"rgb_array"`` an optional ``resolution`` entry
                (defaults to ``(128, 128)``); otherwise passed through.

        Returns:
            Whatever the wrapped environment's ``render`` returns.
        """
        if mode == "rgb_array":
            camera_name = "corner3" if self.domain_id == 0 else "corner"
            # ``kwargs`` is a dict, so the key must be looked up with ``get``;
            # ``hasattr`` would never find it and the default always won.
            resolution = kwargs.get("resolution", (128, 128))
            return self.env.render(
                offscreen=True,
                camera_name=camera_name,
                resolution=resolution,
            )
        return self.env.render(**kwargs)

    @property
    def sim(self):
        """The wrapped environment's simulator object (``self.env.sim``)."""
        return self.env.sim

    def get_obs_dict(self):
        """Return the wrapped environment's ``_get_obs_dict()`` result."""
        return self.env._get_obs_dict()

    def setup_task(self, goal_id, start_id):
        """Sample a new goal for ``reach-v2`` and reset; no-op for other tasks.

        ``goal_id`` and ``start_id`` are not used by this base implementation.
        The goal is drawn with x in [-0.24, 0.24], y = 0.85, z = 0.15, plus
        uniform noise of up to +/-0.1 on x and z.
        """
        if self.task_id == "reach-v2":
            x = np.random.uniform(-0.24, 0.24)
            y = 0.85
            z = 0.15
            goal_pos = np.array([x, y, z]) + np.random.uniform(
                -1, 1, size=(3, )) * np.array([0.1, 0.0, 0.1])
            self.env.goal = goal_pos
            self.reset()

    def get_obs(self):
        """Build the reduced observation.

        Concatenates elements ``[0:4]`` and ``[18:22]`` of the wrapped
        environment's raw observation (presumably the current and previous
        hand state -- verify against the Meta-World observation layout).

        Returns:
            The 8-element state array, or ``{"state": ..., "image": ...}``
            with a 128x128 render when ``image_observation`` is True.
        """
        obs = self.env._get_obs()
        hand_pos = obs[:4]
        prev_hand_pos = obs[18:22]
        obs = np.concatenate((hand_pos, prev_hand_pos), axis=-1)
        if self.image_observation:
            image = self.render(resolution=(128, 128))
            return {"state": obs, "image": image}
        return obs

    def get_success(self, *args, **kwargs):
        """Return the ``"success"`` flag of the most recent ``info`` as a bool."""
        return bool(self.info["success"])


class ReachGoalEnv(MetaWorldEnv):
    """``reach-v2`` wrapper whose goal is picked from four discrete slots.

    Success is decided by comparing the reach distance reported by the
    wrapped environment's ``compute_reward`` with ``success_thresh``.

    Args:
        image_observation: Forwarded to :class:`MetaWorldEnv`.
        domain_id: Forwarded to :class:`MetaWorldEnv`.
        success_thresh: Reach distance below which the task counts as solved.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        image_observation: bool = False,
        domain_id: int = 0,
        success_thresh: float = 0.05,
        **kwargs,
    ):
        super().__init__(
            task="reach-v2",
            image_observation=image_observation,
            domain_id=domain_id,
        )
        self.success_thresh = success_thresh

    def reset(self):
        """Reset with a randomised hand start while keeping the current goal.

        Returns:
            The observation from ``get_obs()`` after the reset.
        """
        self.t = 0
        self.info = {"success": 0}

        # Hand start: nominal pose plus per-axis uniform jitter.
        nominal_hand = np.array([0., 0.6, 0.2])
        hand_jitter = np.random.uniform(-1, 1, size=(3, )) * np.array(
            [0.3, 0.05, 0.05])
        self.env.hand_init_pos = nominal_hand + hand_jitter

        # Same "corner3" camera placement as the base wrapper uses.
        mj_model = self.sim.model
        cam = mj_model.camera_name2id("corner3")
        cam_rotation = R.from_euler("xyz", [180, 0, -45], degrees=True)
        mj_model.cam_pos[cam] = np.array([0.0, 1.5, 0.9])
        mj_model.cam_quat[cam] = np.array(cam_rotation.as_quat())
        mj_model.cam_fovy[cam] = 60

        # Remember the goal and write it back after resetting the wrapped
        # env, so the reset cannot replace it.
        kept_goal = self.env.goal
        self.env.reset()
        self.env.reset_model()
        self.env.goal = kept_goal
        self.env._target_pos = kept_goal

        return self.get_obs()

    def setup_task(self, goal_id, start_id):
        """Place the goal for slot ``goal_id`` and reset (``start_id`` unused)."""
        self.env.goal = self.goal_id_to_pos(goal_id)
        self.reset()

    def goal_id_to_pos(self, goal_id):
        """Return a noisy goal position for slot ``goal_id`` (0..3).

        Slot centres sit at x = ``goal_id * 0.16 - 0.24``, y = 0.85,
        z = 0.15; uniform noise of up to +/-0.1 is added on x and z.
        """
        assert 0 <= goal_id < 4
        slot_centre = np.array([goal_id * 0.16 - 0.24, 0.85, 0.15])
        noise = np.random.uniform(-1, 1, size=(3, )) * np.array(
            [0.1, 0.0, 0.1])
        return slot_centre + noise

    def get_success(self, *args, **kwargs):
        """Return whether the current reach distance is below the threshold."""
        current_obs = self.env._get_obs()
        _, reach_dist, _ = self.env.compute_reward(0, current_obs)
        return reach_dist < self.success_thresh


def convert_path_to_absolute(root, model_dir):
    """Rewrite relative ``file`` attributes under ``<asset>`` to absolute paths.

    Every direct child of every top-level ``<asset>`` element of ``root``
    that carries a non-empty ``file`` attribute gets that attribute replaced
    by ``abspath(join(model_dir, file))``. The tree is modified in place.

    Args:
        root: Parsed XML root element (lxml or ``xml.etree`` compatible).
        model_dir: Directory the relative asset paths are resolved against.

    Returns:
        None.
    """
    # ``findall`` yields an empty list when there is no <asset> section, and
    # iterating an element visits its children without the deprecated
    # ``getchildren()`` call.
    for asset_section in root.findall("asset"):
        for asset in asset_section:
            if relative_path := asset.get("file"):
                absolute_path = os.path.abspath(
                    os.path.join(model_dir, relative_path))
                asset.set("file", absolute_path)


class ReachColorEnv(MetaWorldEnv):
    """``reach-v2`` wrapper where the goal ball is identified by its colour.

    The MuJoCo model is rebuilt with ``num_task_ids - 1`` extra sphere sites
    (``dummy_goal_<i>``) that act as distractors. ``setup_task`` colours the
    real goal site according to ``goal_id`` and arranges goal and
    distractors on a row of slots according to a permutation of colour ids.
    An episode ends as soon as the gripper centre comes close to any ball.

    Args:
        image_observation: Forwarded to :class:`MetaWorldEnv`.
        domain_id: Forwarded to :class:`MetaWorldEnv`.
        success_thresh: Reach distance below which the task counts as solved;
            80% of it is the radius used for the "hit any ball" termination.
        num_task_ids: Number of colours/slots; must be at most 4 (layouts
            exist for 2, 3 and 4).
        easy_task: When True the sites use size ``0.1`` instead of ``0.05``.
        **kwargs: Accepted and ignored.
    """

    # Ball colour for each goal id.
    _GOAL_RGBA = (
        (0, 0.8, 0, 1),
        (0, 0, 0.8, 1),
        (0.8, 0, 0, 1),
        (0.8, 0.8, 0, 1),
    )
    # num_task_ids -> (x spacing between slots, x of slot 0).
    _SLOT_LAYOUT = {
        2: (0.48, -0.24),
        3: (0.36, -0.36),
        4: (0.24, -0.45),
    }

    def __init__(
        self,
        image_observation: bool = False,
        domain_id: int = 0,
        success_thresh: float = 0.05,
        num_task_ids: int = 4,
        easy_task: bool = False,
        **kwargs,
    ):
        # Check first: the permutation table below has ``num_task_ids!``
        # rows, so an oversized value must fail before it is built.
        assert num_task_ids <= 4

        super().__init__(task="reach-v2",
                         image_observation=image_observation,
                         domain_id=domain_id)
        self.success_thresh = success_thresh
        self.num_task_ids = num_task_ids
        self.easy_task = easy_task
        self.task_id = -1  # goal colour id; set by setup_task
        self.permutations = np.array(
            list(itertools.permutations(range(num_task_ids))))
        self.permutation_id = -1  # row of ``permutations``; set by setup_task

        # Kept as a string because it is written into the XML below.
        self.site_size = "0.1" if self.easy_task else "0.05"

        root = etree.fromstring(self.env.sim.model.get_xml())
        model_dir = Path(
            metaworld.__file__).parent / "envs/assets_v2/sawyer_xyz"
        convert_path_to_absolute(root, model_dir)

        # Add the distractor sites; they start out at "0 0 -1" until
        # setup_task positions them.
        worldbody = root.find(".//worldbody")
        for i in range(num_task_ids - 1):
            new_site = etree.Element("site")
            new_site.attrib["name"] = f"dummy_goal_{i}"
            new_site.attrib["pos"] = "0 0 -1"
            new_site.attrib["size"] = self.site_size
            new_site.attrib["rgba"] = "1 0 0 1"
            new_site.attrib["type"] = "sphere"
            worldbody.append(new_site)

        # Generate new XML
        new_xml = etree.tostring(root, pretty_print=True).decode("ascii")

        # Load the modified model and re-point the wrapped env at it.
        self.env.model = mujoco_py.load_model_from_xml(new_xml)
        self.env.sim = mujoco_py.MjSim(self.env.model)
        self.env.data = self.env.sim.data
        self.env.init_qpos = self.env.sim.data.qpos.ravel().copy()
        self.env.init_qvel = self.env.sim.data.qvel.ravel().copy()
        self.env.random_init = False
        self.env.hand_init_pos = np.array([0., 0.3, 0.2])

    def reset(self):
        """Reset with a randomised hand start while keeping the current goal.

        Returns:
            The observation from ``get_obs()`` after the reset.
        """
        self.t = 0
        self.info = {"success": 0}
        self.env.hand_init_pos = np.array([
            0., 0.4, 0.3
        ]) + np.random.uniform(-1, 1, size=(3, )) * np.array([0.1, 0.05, 0.05])

        camera_id = self.sim.model.camera_name2id("corner3")

        euler = [180, 0, -45]
        quat = R.from_euler("xyz", euler, degrees=True).as_quat()

        self.sim.model.cam_pos[camera_id] = np.array([0.0, 1.5, 0.9])
        self.sim.model.cam_quat[camera_id] = np.array(quat)
        self.sim.model.cam_fovy[camera_id] = 60

        # Save the goal and restore it after the wrapped env's reset so the
        # reset cannot replace it.
        target = self.env.goal

        self.env.reset()
        self.env.reset_model()

        self.env.goal = target
        self.env._target_pos = target

        return self.get_obs()

    @staticmethod
    def _distance(pos1, pos2):
        """Euclidean distance between two positions."""
        return np.linalg.norm(pos1 - pos2)

    def _hit_any_ball(self, pos: np.ndarray) -> bool:
        """Return True if ``pos`` is within 0.8 * success_thresh of any slot."""
        radius = self.success_thresh * 0.8
        return any(
            self._distance(pos, self.pos_id_to_goal_pos(i)) < radius
            for i in range(self.num_task_ids))

    def step(self, action):
        """Step the env; additionally end the episode on touching any ball."""
        obs, rew, done, info = super().step(action)
        if self._hit_any_ball(pos=self.env.tcp_center):
            done = True

        return obs, rew, done, info

    def setup_task(self, goal_id, start_id, permutation_id=None):
        """Colour and place the goal and distractor balls, then reset.

        Args:
            goal_id: Colour id of the goal, ``0 <= goal_id < num_task_ids``.
            start_id: Unused.
            permutation_id: Row of ``self.permutations`` mapping slot index
                to colour id; drawn at random when None.
        """
        assert 0 <= goal_id < self.num_task_ids
        self.task_id = goal_id
        if permutation_id is None:
            self.permutation_id = np.random.randint(len(self.permutations))
        else:
            self.permutation_id = permutation_id

        # permutation[slot] is the colour id shown at that slot.
        permutation = self.permutations[self.permutation_id]
        pos_id = np.where(permutation == goal_id)[0][0]

        model = self.env.sim.model
        site_id = model.site_name2id("goal")
        model.site_rgba[site_id] = self.goal_id_to_rgba(goal_id)
        model.site_size[site_id] = float(self.site_size)
        self.env.goal = self.pos_id_to_goal_pos(pos_id)
        model.site_pos[site_id] = self.env.goal

        # Remaining slots, in order, are filled by dummy_goal_0, 1, ...
        # each taking the colour the permutation assigns to its slot.
        distractor_slots = [
            i for i in range(self.num_task_ids) if i != pos_id
        ]
        for dummy_idx, slot in enumerate(distractor_slots):
            site_id = model.site_name2id(f"dummy_goal_{dummy_idx}")
            model.site_rgba[site_id] = self.goal_id_to_rgba(permutation[slot])
            model.site_pos[site_id] = self.pos_id_to_goal_pos(slot)

        self.reset()

    def goal_id_to_rgba(self, goal_id):
        """Return the RGBA colour (a fresh list) associated with ``goal_id``."""
        return list(self._GOAL_RGBA[goal_id])

    def pos_id_to_goal_pos(self, pos_id):
        """Return the centre of slot ``pos_id`` for the configured slot count.

        Raises:
            NotImplementedError: If ``num_task_ids`` is not 2, 3 or 4.
        """
        layout = self._SLOT_LAYOUT.get(self.num_task_ids)
        if layout is None:
            raise NotImplementedError
        width, offset = layout
        return np.array([pos_id * width + offset, 0.85, 0.15])

    def get_success(self, *args, **kwargs):
        """Return whether the current reach distance is below the threshold."""
        obs = self.env._get_obs()
        _, reach_dist, _ = self.env.compute_reward(0, obs)
        return reach_dist < self.success_thresh

    def get_permutation(self):
        """Return the active permutation id; asserts that one has been set."""
        assert self.permutation_id != -1
        return self.permutation_id

    def set_permutation(self, permutation_id):
        """Re-run ``setup_task`` for the current goal with ``permutation_id``."""
        self.permutation_id = permutation_id
        self.setup_task(self.task_id, 0, permutation_id=permutation_id)


class WindowCloseTaskEnv(MetaWorldEnv):
    """``window-close-v2`` wrapper with goal-dependent object reset regions.

    Args:
        image_observation: Forwarded to :class:`MetaWorldEnv`.
        domain_id: Forwarded to :class:`MetaWorldEnv`.
        num_task_ids: Number of valid goal ids accepted by ``setup_task``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        image_observation: bool = False,
        domain_id: int = 0,
        num_task_ids: int = 4,
        **kwargs,
    ):
        super().__init__(task="window-close-v2",
                         image_observation=image_observation,
                         domain_id=domain_id)
        self.num_task_ids = num_task_ids
        # NOTE(review): the attribute is spelled ``task_ids`` here while
        # setup_task writes ``task_id`` -- possibly a typo; confirm before
        # renaming, since ``task_id`` currently keeps the base-class value.
        self.task_ids = -1

    def setup_task(self, goal_id, start_id):
        """Select the reset region for ``goal_id`` (``start_id`` is unused).

        The x interval is shifted by 0.3 for odd goal ids and the z interval
        by 0.2 per ``goal_id // 2``; y always spans [0.7, 0.8]. Only the
        wrapped environment's ``_random_reset_space`` is replaced here; no
        reset is triggered.
        """
        assert 0 <= goal_id < self.num_task_ids
        self.task_id = goal_id

        x_lo, x_hi = np.array([-0.5, 0.2]) + (goal_id % 2) * 0.3
        z_lo, z_hi = np.array([0.2, 0.4]) + (goal_id // 2) * 0.2

        lower = np.array([x_lo, 0.7, z_lo]).astype(np.float32)
        upper = np.array([x_hi, 0.8, z_hi]).astype(np.float32)
        self.env._random_reset_space = gym.spaces.Box(lower, upper)


# Stock Meta-World tasks, all served by the generic wrapper.
MW_TASKS = [
    "push-v2",
    "push-wall-v2",
    "reach-v2",
    "reach-wall-v2",
    "sweep-v2",
    "window-open-v2",
    "window-close-v2",
]
for task in MW_TASKS:
    register(
        id=task,
        entry_point=MetaWorldEnv,
        max_episode_steps=200,
        kwargs={"task": task, "image_observation": False},
    )

# Custom variants: (gym id, entry point, kwargs beyond image_observation).
# Each one is registered and then appended to MW_TASKS, in this order.
_CUSTOM_ENVS = (
    ("reach-goal-v2", ReachGoalEnv, {}),
    ("reach-color-v2", ReachColorEnv, {}),
    (
        "reach-color_simple_2-v2",
        ReachColorEnv,
        {"num_task_ids": 2, "easy_task": True, "success_thresh": 0.1},
    ),
    (
        "reach-color_simple_3-v2",
        ReachColorEnv,
        {"num_task_ids": 3, "easy_task": True, "success_thresh": 0.1},
    ),
    ("window-close_4-v2", WindowCloseTaskEnv, {"num_task_ids": 4}),
)
for _env_id, _entry_point, _extra_kwargs in _CUSTOM_ENVS:
    register(
        id=_env_id,
        entry_point=_entry_point,
        max_episode_steps=200,
        kwargs={"image_observation": False, **_extra_kwargs},
    )
    MW_TASKS.append(_env_id)
