# Copyright (c) 2021-2023, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Factory: Class for gears task.

Inherits the gears environment class and the abstract task class (not enforced). Can be executed with
`python train.py task=FactoryTaskGears`.

Only the environment is provided; training a successful RL policy is an open research problem left to the user.
"""

import hydra
import math
import omegaconf
import os
import torch

from isaacgym import gymapi, gymtorch
from isaacgymenvs.tasks.factory.factory_env_gears import FactoryEnvGears
from isaacgymenvs.tasks.factory.factory_schema_class_task import FactoryABCTask
from isaacgymenvs.tasks.factory.factory_schema_config_task import FactorySchemaConfigTask


class FactoryTaskGears(FactoryEnvGears, FactoryABCTask):
    """Gears task built on the Factory gears environment.

    Resets the Franka and the three gears (small, medium, large). Reward,
    reset-condition, and task-tensor hooks are intentionally left as stubs.
    """

    def __init__(self, cfg, rl_device, sim_device, graphics_device_id, headless, virtual_screen_capture, force_render):
        """Initialize instance variables. Initialize task superclass.

        After the superclass is constructed, this loads the task / asset / PPO
        configs, sets the viewer camera if a viewer exists, and exports the
        scene when ``cfg_base.mode.export_scene`` is set.
        """

        super().__init__(cfg, rl_device, sim_device, graphics_device_id, headless, virtual_screen_capture, force_render)

        self.cfg = cfg
        self._get_task_yaml_params()
        if self.viewer is not None:
            self._set_viewer_params()
        if self.cfg_base.mode.export_scene:
            self.export_scene(label='factory_task_gears')

    def _get_task_yaml_params(self):
        """Initialize instance variables from YAML files.

        Sets ``cfg_task``, ``max_episode_length``, ``asset_info_gears`` and ``cfg_ppo``.
        """

        # Register the task schema so Hydra can validate the task config against it.
        cs = hydra.core.config_store.ConfigStore.instance()
        cs.store(name='factory_schema_config_task', node=FactorySchemaConfigTask)

        self.cfg_task = omegaconf.OmegaConf.create(self.cfg)
        self.max_episode_length = self.cfg_task.rl.max_episode_length  # required instance var for VecTask

        asset_info_path = '../../assets/factory/yaml/factory_asset_info_gears.yaml'  # relative to Gym's Hydra search path (cfg dir)
        self.asset_info_gears = hydra.compose(config_name=asset_info_path)
        self.asset_info_gears = self.asset_info_gears['']['']['']['']['']['']['assets']['factory']['yaml']  # strip superfluous nesting

        ppo_path = 'train/FactoryTaskGearsPPO.yaml'  # relative to Gym's Hydra search path (cfg dir)
        self.cfg_ppo = hydra.compose(config_name=ppo_path)
        self.cfg_ppo = self.cfg_ppo['train']  # strip superfluous nesting

    def _acquire_task_tensors(self):
        """Acquire tensors. (No task-specific tensors for this task.)"""
        pass

    def _refresh_task_tensors(self):
        """Refresh tensors. (No task-specific tensors for this task.)"""
        pass

    def pre_physics_step(self, actions):
        """Reset environments flagged in ``reset_buf``, then store the policy actions.

        The actions are only stored on ``self._actions``; this task does not
        apply them as targets here.
        """

        env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
        if len(env_ids) > 0:
            self.reset_idx(env_ids)

        self._actions = actions.clone().to(self.device)  # shape = (num_envs, num_actions); values = [-1, 1]

    def post_physics_step(self):
        """Step buffers. Refresh tensors. Compute observations and reward."""

        self.progress_buf[:] += 1

        self.refresh_base_tensors()
        self.refresh_env_tensors()
        self._refresh_task_tensors()
        self.compute_observations()
        self.compute_reward()

    def compute_observations(self):
        """Compute observations. Returns ``obs_buf`` unchanged."""

        return self.obs_buf  # shape = (num_envs, num_observations)

    def compute_reward(self):
        """Detect successes and failures. Update reward and reset buffers."""

        self._update_rew_buf()
        self._update_reset_buf()

    def _update_rew_buf(self):
        """Compute reward at current timestep. (Stub.)"""
        pass

    def _update_reset_buf(self):
        """Assign environments for reset if successful or failed. (Stub.)"""
        pass

    def reset_idx(self, env_ids):
        """Reset the Franka, the gears, and the episode buffers for ``env_ids``."""

        self._reset_franka(env_ids)
        self._reset_object(env_ids)

        self._reset_buffers(env_ids)

    def _reset_franka(self, env_ids):
        """Reset DOF states and DOF targets of Franka for ``env_ids`` only.

        Each joint is set to the midpoint of its limits plus a per-env uniform
        offset in [-joint_noise, joint_noise] degrees (the same offset is shared
        by all joints of that env). Velocities are zeroed. The state is pushed to
        the sim for the selected Franka actors, and the position targets are set
        to the new positions.
        """

        # shape of dof_pos = (num_envs, num_dofs)
        # shape of dof_vel = (num_envs, num_dofs)

        # Initialize Franka to middle of joint limits, plus joint noise
        franka_dof_props = self.gym.get_actor_dof_properties(self.env_ptrs[0],
                                                             self.franka_handles[0])  # same across all envs
        lower_lims = franka_dof_props['lower']
        upper_lims = franka_dof_props['upper']
        mid_dof_pos = torch.tensor((lower_lims + upper_lims) * 0.5, device=self.device)  # shape = (franka_num_dofs,)

        # Sample noise only for the envs being reset, so other envs' DOF buffers are left untouched.
        num_resets = len(env_ids)
        joint_noise = (torch.rand((num_resets, 1), device=self.device) * 2.0 - 1.0) \
            * self.cfg_task.randomize.joint_noise * math.pi / 180  # degrees -> radians; shape = (num_resets, 1)
        self.dof_pos[env_ids, 0:self.franka_num_dofs] = mid_dof_pos + joint_noise

        self.dof_vel[env_ids, 0:self.franka_num_dofs] = 0.0

        franka_actor_ids_sim_int32 = self.franka_actor_ids_sim.to(dtype=torch.int32, device=self.device)[env_ids]
        self.gym.set_dof_state_tensor_indexed(self.sim,
                                              gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(franka_actor_ids_sim_int32),
                                              len(franka_actor_ids_sim_int32))

        self.ctrl_target_dof_pos[env_ids, 0:self.franka_num_dofs] = self.dof_pos[env_ids, 0:self.franka_num_dofs]
        self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(self.ctrl_target_dof_pos))

    def _sample_random_gear_pos(self, num_resets, y_center):
        """Return a (num_resets, 3) tensor of gear root positions.

        x is uniform in [-gears_noise_xy, gears_noise_xy]. y is ``y_center`` plus
        uniform noise in the same range. z is fixed at
        table_height + gears_bias_z.
        """

        noise_xy = self.cfg_task.randomize.gears_noise_xy
        pos_x = (torch.rand((num_resets, 1), device=self.device) * 2.0 - 1.0) * noise_xy
        pos_y = y_center + (torch.rand((num_resets, 1), device=self.device) * 2.0 - 1.0) * noise_xy
        pos_z = torch.ones((num_resets, 1), device=self.device) \
            * (self.cfg_base.env.table_height + self.cfg_task.randomize.gears_bias_z)
        return torch.cat((pos_x, pos_y, pos_z), dim=1)

    def _reset_object(self, env_ids):
        """Reset root state of gears for ``env_ids``.

        In 'random' mode the positions are sampled around a y-bias. In 'goal'
        mode all gears are placed at (0, 0, table_height). Linear and angular
        velocities are zeroed, and the root state is pushed to the sim for the
        selected gear actors. Other ``initial_state`` values leave the positions
        unchanged.
        """

        # shape of root_pos = (num_envs, num_actors, 3)
        # shape of root_quat = (num_envs, num_actors, 4)
        # shape of root_linvel = (num_envs, num_actors, 3)
        # shape of root_angvel = (num_envs, num_actors, 3)

        num_resets = len(env_ids)  # row count must match root_pos[env_ids, ...], not num_envs
        gears_bias_y = self.cfg_task.randomize.gears_bias_y

        if self.cfg_task.randomize.initial_state == 'random':
            # NOTE(review): small and large gears share the -gears_bias_y region; medium uses +gears_bias_y.
            # This is preserved from the original; confirm it is intended.
            self.root_pos[env_ids, self.gear_small_actor_id_env] = self._sample_random_gear_pos(num_resets, -gears_bias_y)
            self.root_pos[env_ids, self.gear_medium_actor_id_env] = self._sample_random_gear_pos(num_resets, gears_bias_y)
            self.root_pos[env_ids, self.gear_large_actor_id_env] = self._sample_random_gear_pos(num_resets, -gears_bias_y)
        elif self.cfg_task.randomize.initial_state == 'goal':
            goal_pos = torch.tensor([0.0, 0.0, self.cfg_base.env.table_height], device=self.device)
            self.root_pos[env_ids, self.gear_small_actor_id_env] = goal_pos
            self.root_pos[env_ids, self.gear_medium_actor_id_env] = goal_pos
            self.root_pos[env_ids, self.gear_large_actor_id_env] = goal_pos

        for gear_actor_id_env in (self.gear_small_actor_id_env,
                                  self.gear_medium_actor_id_env,
                                  self.gear_large_actor_id_env):
            self.root_linvel[env_ids, gear_actor_id_env] = 0.0
            self.root_angvel[env_ids, gear_actor_id_env] = 0.0

        gear_small_actor_ids_sim_int32 = self.gear_small_actor_ids_sim.to(dtype=torch.int32, device=self.device)
        gear_medium_actor_ids_sim_int32 = self.gear_medium_actor_ids_sim.to(dtype=torch.int32, device=self.device)
        gear_large_actor_ids_sim_int32 = self.gear_large_actor_ids_sim.to(dtype=torch.int32, device=self.device)
        gears_actor_ids_sim_int32 = torch.cat((gear_small_actor_ids_sim_int32[env_ids],
                                               gear_medium_actor_ids_sim_int32[env_ids],
                                               gear_large_actor_ids_sim_int32[env_ids]))

        self.gym.set_actor_root_state_tensor_indexed(self.sim,
                                                     gymtorch.unwrap_tensor(self.root_state),
                                                     gymtorch.unwrap_tensor(gears_actor_ids_sim_int32),
                                                     len(gears_actor_ids_sim_int32))

    def _reset_buffers(self, env_ids):
        """Reset ``reset_buf`` and ``progress_buf`` to 0 for ``env_ids``."""

        self.reset_buf[env_ids] = 0
        self.progress_buf[env_ids] = 0

    def _set_viewer_params(self):
        """Set viewer parameters (camera position and look-at target)."""

        cam_pos = gymapi.Vec3(-1.0, -1.0, 1.0)
        cam_target = gymapi.Vec3(0.0, 0.0, 0.5)
        self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)
