# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""USD exporter."""

import os
from typing import List, Optional, Tuple, Union

import mujoco
import numpy as np

try:
    import pxr
    from pxr import Sdf, Usd, UsdGeom
except ImportError as err:
    # Raise instead of `assert False`: asserts are stripped under `python -O`
    # (which would surface later as a confusing NameError on `pxr`), and an
    # ImportError lets callers handle the missing optional dependency.
    raise ImportError("Please install usd-core package. You can install it using 'pip install usd-core'.") from err
import scipy
import termcolor
import tqdm
from mujoco import _enums, _functions, _structs
from PIL import Image as im
from PIL import ImageOps

import robosuite.utils.usd.camera as camera_module
import robosuite.utils.usd.lights as light_module
import robosuite.utils.usd.objects as object_module
import robosuite.utils.usd.shapes as shapes_module
from robosuite.utils.log_utils import ROBOSUITE_DEFAULT_LOGGER

PRIMARY_CAMERA_NAME = "primary_camera"


def _mujoco_major_minor(version: str) -> Tuple[int, int]:
    """Returns the numeric (major, minor) components of a MuJoCo version string.

    Multi-digit components are handled (e.g. "3.10.1" -> (3, 10)). Returns
    ``(0, 0)`` when the string does not start with ``<int>.<int>``.
    """
    import re

    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        return (0, 0)
    return int(match.group(1)), int(match.group(2))


# This exporter targets MuJoCo releases older than 3.2; newer releases ship
# their own exporter, so point users there and stop.
if _mujoco_major_minor(mujoco.__version__) >= (3, 2):
    ROBOSUITE_DEFAULT_LOGGER.warning(
        "If using later versions of mujoco, please use the exporter in the mujoco repository: "
        + "https://github.com/google-deepmind/mujoco/blob/main/python/mujoco/usd/exporter.py"
    )
    # Same exit status as the original `exit(0)`, without depending on the
    # `site`-provided helper (which also closes stdin before exiting).
    raise SystemExit(0)


class USDExporter:
    """MuJoCo to USD exporter for porting scenes to external renderers.

    Call :meth:`update_scene` once per simulation step to record the current
    geoms, lights, and cameras as a new time sample. Then call
    :meth:`save_scene` to export the accumulated stage into the ``frames``
    subdirectory of the output directory.
    """

    def __init__(
        self,
        model: mujoco.MjModel,
        max_geom: int = 10000,
        output_directory_name: str = "robosuite_usd",
        light_intensity: int = 10000,
        framerate: int = 60,
        shareable: bool = False,
        online: bool = False,
        camera_names: Optional[List[str]] = None,
        stage: Optional[pxr.Usd.Stage] = None,
        verbose: bool = True,
    ) -> None:
        """Initializes a new USD Exporter.

        Args:
            model: an MjModel instance.
            max_geom: maximum number of geoms that can be rendered in the same
              scene (passed to ``MjvScene`` as ``maxgeom``).
            output_directory_name: name of root directory to store outputted frames
              and assets generated by the USD renderer. A leading ``~`` is expanded.
            light_intensity: default intensity of the lights in the external renderer.
            framerate: framerate of the exported scene when rendered. It is only
              applied when the exporter creates its own stage.
            shareable: use relative paths to assets instead of absolute paths to allow
              files to be shared across users.
            online: set to true if using USD exporter for online rendering. This value
              is set to true when rendering with Isaac Sim. If online is set to true,
              shareable must be false.
            camera_names: list of fixed cameras defined in the mujoco model to render.
            stage: predefined stage to add objects in the scene to. When omitted, an
              in-memory stage is created.
            verbose: decides whether to print updates.

        Raises:
            ValueError: if ``online`` and ``shareable`` are both true.
        """
        # Validate arguments before doing any work.
        if online and shareable:
            raise ValueError(
                "\nArguments online and shareable cannot both be set to true. If rendering online,\n"
                "please set shareable to be false.    \n"
            )

        self.model = model
        self.max_geom = max_geom
        self.output_directory_name = output_directory_name
        self.light_intensity = light_intensity
        self.framerate = framerate
        self.shareable = shareable
        self.online = online
        self.camera_names = camera_names
        self.stage = stage
        self.verbose = verbose

        # Names of every camera in the model, used in error messages.
        self.valid_camera_names = [
            mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_CAMERA, cam_id) for cam_id in range(model.ncam)
        ]

        # Number of frames recorded by update_scene. save_scene uses it as the
        # stage end time code and in the exported file name.
        self.frame_count = 0
        # Number of completed update_scene calls. It is the time code used for
        # the next recorded sample.
        self.updates = 0

        # Registry of USD objects created so far, keyed by generated geom name.
        self.geom_names = set()
        self.geom_refs = {}

        # initializing mujoco scene objects
        self._scene = _structs.MjvScene(model=model, maxgeom=max_geom)
        self._scene_option = mujoco.MjvOption()  # using default scene option

        self._initialize_usd_stage()

        # initializing output_directories
        self._initialize_output_directories()

        # loading required textures for the scene
        self._load_assets()

    @property
    def usd(self) -> str:
        """Returns the generated USD as a string."""
        return self.stage.GetRootLayer().ExportToString()

    @property
    def scene(self):
        """Returns the underlying MuJoCo ``MjvScene``."""
        return self._scene

    @property
    def output_dir(self):
        """Returns the absolute path of the root output directory."""
        return os.path.abspath(self.output_directory_path)

    def _initialize_usd_stage(self) -> None:
        """Creates an in-memory, Z-up USD stage with a ``/World`` default prim.

        Does nothing if a stage was provided or already created.
        """
        if self.stage:
            return
        self.stage = Usd.Stage.CreateInMemory()
        UsdGeom.SetStageUpAxis(self.stage, UsdGeom.Tokens.z)
        self.stage.SetStartTimeCode(0)
        self.stage.SetTimeCodesPerSecond(self.framerate)

        default_prim = UsdGeom.Xform.Define(self.stage, Sdf.Path("/World")).GetPrim()
        self.stage.SetDefaultPrim(default_prim)

    def _initialize_output_directories(self) -> None:
        """Creates the root, ``frames`` and ``assets`` output directories if missing."""
        self.output_directory_path = os.path.expanduser(self.output_directory_name)
        ROBOSUITE_DEFAULT_LOGGER.info(f"Outputting USD to {self.output_directory_path}")

        self.frames_directory = os.path.join(self.output_directory_path, "frames")
        self.assets_directory = os.path.join(self.output_directory_path, "assets")
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        for directory in (self.output_directory_path, self.frames_directory, self.assets_directory):
            os.makedirs(directory, exist_ok=True)

        if self.verbose:
            print(
                termcolor.colored(
                    f"Writing output frames and assets to {self.output_directory_path}",
                    "green",
                )
            )

    def _load_assets(self) -> None:
        """Writes every model texture to ``assets/texture_<id>.png``.

        Populates ``self.texture_files`` (indexed by texture id) with the path
        each USD object should reference.
        """
        self.texture_files = []
        # USD files are exported into the frames directory, so shareable
        # texture paths are expressed relative to it.
        relative_assets_dir = os.path.relpath(self.assets_directory, self.frames_directory)

        for texture_id in tqdm.tqdm(range(self.model.ntex)):
            height = self.model.tex_height[texture_id]
            width = self.model.tex_width[texture_id]
            # Read each texture from its own start address instead of relying on
            # textures being packed back to back. Assumes 3-channel RGB data.
            start = self.model.tex_adr[texture_id]
            num_values = 3 * height * width
            pixels = self.model.tex_rgb[start : start + num_values].reshape(height, width, 3)
            img = ImageOps.flip(im.fromarray(pixels))

            texture_file_name = f"texture_{texture_id}.png"
            texture_path = os.path.join(self.assets_directory, texture_file_name)
            img.save(texture_path)

            if self.shareable:
                self.texture_files.append(os.path.join(relative_assets_dir, texture_file_name))
            else:
                self.texture_files.append(os.path.abspath(texture_path))

        if self.verbose:
            print(
                termcolor.colored(
                    f"Completed writing {self.model.ntex} textures to {self.assets_directory}",
                    "green",
                )
            )

    def _update_scene(
        self,
        data: mujoco.MjData,
        camera: Union[int, str, _structs.MjvCamera] = 0,
        scene_option: Optional[mujoco.MjvOption] = None,
    ) -> None:
        """Runs ``mjv_updateScene`` on ``self._scene`` for the given camera.

        Args:
            data: an ``MjData`` instance with the current simulation state.
            camera: an ``MjvCamera`` used as-is, or the name/id of a fixed
              camera defined in the model.
            scene_option: ``MjvOption`` to use; defaults to ``self._scene_option``.

        Raises:
            ValueError: if the model has no cameras, the named camera does not
              exist, ``camera`` is -1 (free camera), or the id is out of range.
        """
        if not isinstance(camera, _structs.MjvCamera):
            if self.model.ncam == 0:
                raise ValueError("No fixed cameras defined in mujoco model.")
            camera_id = camera
            if isinstance(camera_id, str):
                camera_id = _functions.mj_name2id(self.model, _enums.mjtObj.mjOBJ_CAMERA.value, camera_id)
                if camera_id == -1:
                    raise ValueError(f'The camera "{camera}" does not exist. Valid cameras {self.valid_camera_names}')
            if camera_id == -1:
                raise ValueError("Free cameras are not supported during USD export.")
            if camera_id < 0 or camera_id >= self.model.ncam:
                raise ValueError(f"The camera id {camera_id} is out of range [0, {self.model.ncam}).")

            # Wrap the fixed camera id in an MjvCamera for mjv_updateScene.
            camera = _structs.MjvCamera()
            camera.fixedcamid = camera_id
            camera.type = _enums.mjtCamera.mjCAMERA_FIXED

        _functions.mjv_updateScene(
            self.model,
            data,
            scene_option or self._scene_option,
            None,  # no perturbation
            camera,
            _enums.mjtCatBit.mjCAT_ALL.value,
            self._scene,
        )

    def update_scene(
        self,
        data: mujoco.MjData,
        camera: Union[int, str, _structs.MjvCamera] = 0,
        scene_option: Optional[mujoco.MjvOption] = None,
    ) -> None:
        """Updates the scene with latest sim data and records it as a new frame.

        Args:
            data: An instance of `MjData`
            camera: An instance of `MjvCamera`, a string, or an integer
            scene_option: A custom `MjvOption` instance to use to render
              the scene instead of the default
        """
        scene_option = scene_option or self._scene_option

        # update the mujoco renderer
        self._update_scene(data, camera=camera, scene_option=scene_option)

        if self.updates == 0:
            # First recorded frame: create the USD lights and cameras.
            self._initialize_usd_stage()
            self._load_lights()
            self._load_cameras()

        self._update_geoms()
        self._update_lights()
        self._update_cameras(data, camera=camera, scene_option=scene_option)

        # Count the frame only once it has been fully recorded, keeping
        # frame_count in step with updates even if an update above raises.
        self.frame_count += 1
        self.updates += 1

    def _load_geom(self, geom: mujoco.MjvGeom) -> None:
        """Creates and registers the USD object for a geom not yet in the stage.

        Meshes become ``USDMesh``. Tendon segments become ``USDTendon``, built
        at unit size and scaled by ``geom.size`` in ``_update_geoms``. All other
        geoms become ``USDPrimitiveMesh`` sized by ``geom.size``.
        """
        geom_name = self._get_geom_name(geom)
        assert geom_name not in self.geom_names

        texture_file = self.texture_files[geom.texid] if geom.texid != -1 else None

        if geom.type == mujoco.mjtGeom.mjGEOM_MESH:
            usd_geom = object_module.USDMesh(
                stage=self.stage,
                model=self.model,
                geom=geom,
                obj_name=geom_name,
                dataid=self.model.geom_dataid[geom.objid],
                rgba=geom.rgba,
                texture_file=texture_file,
            )
        elif geom.objtype == mujoco.mjtObj.mjOBJ_TENDON:
            mesh_config = shapes_module.mesh_config_generator(
                name=geom_name, geom_type=geom.type, size=np.array([1.0, 1.0, 1.0]), decouple=True
            )
            usd_geom = object_module.USDTendon(
                mesh_config=mesh_config,
                stage=self.stage,
                model=self.model,
                geom=geom,
                obj_name=geom_name,
                rgba=geom.rgba,
                texture_file=texture_file,
            )
        else:
            mesh_config = shapes_module.mesh_config_generator(name=geom_name, geom_type=geom.type, size=geom.size)
            usd_geom = object_module.USDPrimitiveMesh(
                mesh_config=mesh_config,
                stage=self.stage,
                model=self.model,
                geom=geom,
                obj_name=geom_name,
                rgba=geom.rgba,
                texture_file=texture_file,
            )

        self.geom_names.add(geom_name)
        self.geom_refs[geom_name] = usd_geom

    def _update_geoms(self) -> None:
        """Writes the pose and visibility of every scene geom for the current frame.

        Geoms seen for the first time are loaded first. Tendons also receive
        their per-frame scale.
        """
        # Online rendering passes frame=None instead of a time code.
        frame = None if self.online else self.updates
        for i in range(self.scene.ngeom):
            geom = self.scene.geoms[i]
            geom_name = self._get_geom_name(geom)

            if geom_name not in self.geom_names:
                self._load_geom(geom)

            usd_geom = self.geom_refs[geom_name]
            visible = geom.rgba[3] > 0
            if geom.objtype == mujoco.mjtObj.mjOBJ_TENDON:
                usd_geom.update(pos=geom.pos, mat=geom.mat, scale=geom.size, visible=visible, frame=frame)
            else:
                usd_geom.update(pos=geom.pos, mat=geom.mat, visible=visible, frame=frame)

    def _load_lights(self) -> None:
        """Creates a USD sphere light for each scene light not located at the origin.

        ``self.usd_lights`` stays index-aligned with ``self._scene.lights``.
        Lights at the origin get a ``None`` placeholder.
        """
        self.usd_lights = []
        for i in range(self._scene.nlight):
            light = self._scene.lights[i]
            if np.allclose(light.pos, [0, 0, 0]):
                self.usd_lights.append(None)
            else:
                self.usd_lights.append(light_module.USDSphereLight(stage=self.stage, light_name=str(i)))

    def _update_lights(self) -> None:
        """Writes position, intensity and color of each loaded light for the current frame."""
        for i in range(self._scene.nlight):
            light = self._scene.lights[i]

            # Skip lights at the origin and lights without a USD counterpart.
            if np.allclose(light.pos, [0, 0, 0]):
                continue
            if i >= len(self.usd_lights) or self.usd_lights[i] is None:
                continue

            self.usd_lights[i].update(
                pos=light.pos,
                intensity=self.light_intensity,
                color=light.diffuse,
                frame=self.updates,
            )

    def _load_cameras(self) -> None:
        """Creates the primary USD camera plus one per requested fixed camera."""
        self.usd_cameras = {}

        # add a primary camera for which the scene will be rendered
        self.usd_cameras[PRIMARY_CAMERA_NAME] = camera_module.USDCamera(
            stage=self.stage, camera_name=PRIMARY_CAMERA_NAME
        )
        for camera_name in self.camera_names or []:
            self.usd_cameras[camera_name] = camera_module.USDCamera(stage=self.stage, camera_name=camera_name)

    def _get_camera_orientation(self) -> Tuple[_structs.MjvGLCamera, np.ndarray]:
        """Averages the scene's two GL cameras and builds their rotation matrix.

        Returns:
            The averaged ``MjvGLCamera`` and a 3x3 matrix whose columns are
            (right, up, -forward). The camera therefore looks along its local
            -Z axis.
        """
        avg_camera = mujoco.mjv_averageCamera(self._scene.camera[0], self._scene.camera[1])

        forward = avg_camera.forward
        up = avg_camera.up
        right = np.cross(forward, up)

        R = np.eye(3)
        R[:, 0] = right
        R[:, 1] = up
        R[:, 2] = -forward

        return avg_camera, R

    def _update_cameras(
        self,
        data: mujoco.MjData,
        camera: Union[int, str, _structs.MjvCamera] = 0,
        scene_option: Optional[mujoco.MjvOption] = None,
    ) -> None:
        """Writes the pose of the primary camera and every named fixed camera.

        Each camera is obtained by re-running ``mjv_updateScene`` with that
        camera. ``self._scene`` is therefore left configured for the last
        camera processed.
        """
        # first, update the primary camera given the new scene
        self._update_scene(data, camera=camera, scene_option=scene_option)
        avg_camera, R = self._get_camera_orientation()
        self.usd_cameras[PRIMARY_CAMERA_NAME].update(cam_pos=avg_camera.pos, cam_mat=R, frame=self.updates)

        for camera_name in self.camera_names or []:
            self._update_scene(data, camera=camera_name, scene_option=scene_option)
            avg_camera, R = self._get_camera_orientation()
            self.usd_cameras[camera_name].update(cam_pos=avg_camera.pos, cam_mat=R, frame=self.updates)

    def add_light(
        self,
        pos: List[float],
        intensity: int,
        color: Union[List[float], Tuple[float, ...]] = (0.3, 0.3, 0.3),
        light_type: Optional[str] = "sphere",
        light_name: Optional[str] = "light_1",
        **light_params,
    ) -> None:
        """Adds a light posthoc (i.e., light not defined in the mujoco model).

        Args:
          pos: position of the light in 3D space (ignored for dome lights)
          intensity: intensity of the light
          color: color of the light
          light_type: type of light source: "sphere", "dome", "rect"/"rectangle",
            or "cylinder"
          light_name: name of the light to be stored in USD
          **light_params: shape parameters required by the light type:
            ``radius`` (sphere), ``width`` and ``height`` (rect),
            ``length`` and ``radius`` (cylinder)

        Raises:
          ValueError: if a required shape parameter is missing or the light
            type is not supported.
        """

        def require(key: str, light_label: str):
            # Raise (rather than assert) so validation also runs under `python -O`.
            if key not in light_params:
                raise ValueError(f"Please provide a {key} for the {light_label} light.")
            return light_params[key]

        pos = np.array(pos)
        color = np.array(color)
        if light_type == "sphere":
            radius = require("radius", "sphere")
            new_light = light_module.USDSphereLight(stage=self.stage, light_name=light_name, radius=radius)
            new_light.update(pos=pos, intensity=intensity, color=color, frame=0)
        elif light_type == "dome":
            new_light = light_module.USDDomeLight(stage=self.stage, light_name=light_name)
            new_light.update(intensity=intensity, color=color)
        elif light_type in ("rect", "rectangle"):
            width = require("width", "rect")
            height = require("height", "rect")
            new_light = light_module.USDRectLight(stage=self.stage, light_name=light_name, width=width, height=height)
            new_light.update(pos=pos, intensity=intensity, color=color, frame=0)
        elif light_type == "cylinder":
            length = require("length", "cylinder")
            radius = require("radius", "cylinder")
            new_light = light_module.USDCylinderLight(
                stage=self.stage, light_name=light_name, length=length, radius=radius
            )
            new_light.update(pos=pos, intensity=intensity, color=color, frame=0)
        else:
            raise ValueError(f"Light type {light_type} is not supported.")

    def add_camera(
        self,
        pos: List[float],
        rotation_xyz: List[float],
        camera_name: Optional[str] = "camera_1",
    ) -> None:
        """Adds a camera posthoc (i.e., camera not defined in the mujoco model).

        Args:
          pos: position of the camera in 3D space
          rotation_xyz: extrinsic "xyz" euler rotation of the camera, in degrees
          camera_name: name of the camera to be stored in USD
        """
        # Import the submodule explicitly: a bare `import scipy` does not load
        # `scipy.spatial` on older SciPy releases.
        from scipy.spatial.transform import Rotation

        rotation = Rotation.from_euler("xyz", rotation_xyz, degrees=True)
        new_camera = camera_module.USDCamera(stage=self.stage, camera_name=camera_name)
        new_camera.update(cam_pos=np.array(pos), cam_mat=rotation.as_matrix(), frame=0)

    def save_scene(self, filetype: str = "usd") -> None:
        """Saves the current scene as a USD trajectory.

        Writes ``frames/frame_<frame_count>.<filetype>`` under the output directory.

        Args:
          filetype: type of USD file to save the scene as (options include
            "usd", "usda", and "usdc")
        """
        self.stage.SetEndTimeCode(self.frame_count)

        # Mark each geom invisible starting one frame after the last frame in
        # which it was visible.
        for geom_ref in self.geom_refs.values():
            geom_ref.update_visibility(False, geom_ref.last_visible_frame + 1)

        file_name = f"frame_{self.frame_count}.{filetype}"
        self.stage.Export(os.path.join(self.frames_directory, file_name))
        if self.verbose:
            print(termcolor.colored(f"Completed writing {file_name}", "green"))

    def _get_geom_name(self, geom) -> str:
        """Builds a unique name for a scene geom.

        The name combines the MuJoCo object name ("None" if unnamed), with
        '-'/'+' replaced by 'm_'/'p_'. It then appends the object id and a
        suffix that distinguishes geoms from tendon segments. The result is
        presumably used as a USD prim name, where '-' and '+' are not allowed.
        """
        geom_name = mujoco.mj_id2name(self.model, geom.objtype, geom.objid) or "None"
        geom_name = geom_name.replace("-", "m_").replace("+", "p_")
        geom_name += f"_id{geom.objid}"

        # adding additional naming information to differentiate
        # between geoms and tendons
        if geom.objtype == mujoco.mjtObj.mjOBJ_GEOM:
            geom_name += "_geom"
        elif geom.objtype == mujoco.mjtObj.mjOBJ_TENDON:
            geom_name += f"_tendon_segid{geom.segid}"

        return geom_name
