# import time
# import os
import numpy as np
# from dm_control import mjcf
# import mujoco.viewer
# import gymnasium as gym
from gymnasium import spaces
from manipulator_mujoco.envs import UR5eEnv
from manipulator_mujoco.utils.transform_utils import (
    mat2quat,
    quat2mat,
    axisangle2quat,
)

def axisangle2mat(axis_angle):
    """Convert an axis-angle vector into a rotation matrix.

    The conversion goes through a quaternion: `axisangle2quat` followed by
    `quat2mat`, both imported from `manipulator_mujoco.utils.transform_utils`.

    Args:
        axis_angle: Axis-angle rotation, forwarded unchanged to
            `axisangle2quat`.

    Returns:
        Whatever `quat2mat` returns for the intermediate quaternion
        (used in this file as a 3x3 rotation matrix).
    """
    intermediate_quat = axisangle2quat(axis_angle)
    rotation = quat2mat(intermediate_quat)
    return rotation

def quat_angle_distance(q1, q2):
    """Compute the rotation angle separating two unit quaternions.

    The angle is ``2 * arccos(|<q1, q2>|)``. Taking the absolute value of the
    dot product makes ``q`` and ``-q`` (which encode the same rotation) have
    zero distance, so the result lies in ``[0, pi]``.

    Args:
        q1: Unit quaternion(s), array-like of shape ``(..., 4)``.
        q2: Unit quaternion(s), array-like broadcastable against ``q1``.

    Returns:
        Angle(s) in radians: a NumPy scalar for 1-D inputs, otherwise an
        array with the broadcast leading shape.

    Raises:
        AssertionError: If any quaternion's norm is not close to 1
            (``np.isclose`` default tolerances).
    """
    q1 = np.asarray(q1)
    q2 = np.asarray(q2)

    # Norms are taken along the last axis so stacked quaternions are each
    # checked individually instead of as one flattened vector.
    norm1 = np.linalg.norm(q1, axis=-1)
    norm2 = np.linalg.norm(q2, axis=-1)
    assert np.all(np.isclose(norm1, 1.0)), f"Quaternion not normalized (norm={norm1})"
    assert np.all(np.isclose(norm2, 1.0)), f"Quaternion not normalized (norm={norm2})"

    dot_product = np.abs(np.sum(q1 * q2, axis=-1))

    # Inputs only have to be unit-norm up to the isclose tolerance, so the dot
    # product of two (nearly) identical quaternions can land slightly above
    # 1.0. Clip instead of failing: arccos is undefined outside [-1, 1].
    dot_product = np.clip(dot_product, 0.0, 1.0)

    return 2.0 * np.arccos(dot_product)

def goal_distance_pose(goal_a, goal_b, pos_weight=1.0, rot_weight=0.19098621461):
    """
    Weighted distance between two poses.

    Each pose stores a position in its first three entries and a quaternion
    in its last four (``[x, y, z, qx, qy, qz, qw]``). The result is
    ``pos_weight * ||p_a - p_b|| + rot_weight * angle(q_a, q_b)``, where the
    angle comes from ``quat_angle_distance``.

    Args:
        goal_a: First pose, NumPy array whose last dimension is 7.
        goal_b: Second pose, same shape as ``goal_a``.
        pos_weight: Multiplier applied to the Euclidean position error.
        rot_weight: Multiplier applied to the quaternion angle error.

    Returns:
        The combined weighted distance.

    Raises:
        AssertionError: If the shapes differ or the last dimension is not 7.
    """
    assert goal_a.shape == goal_b.shape
    assert goal_a.shape[-1] == 7  # [x, y, z, qx, qy, qz, qw]

    # Translational part: Euclidean norm over the leading three components.
    translation_error = np.linalg.norm(goal_a[..., :3] - goal_b[..., :3], axis=-1)

    # Rotational part: angle between the trailing four components.
    rotation_error = quat_angle_distance(goal_a[..., 3:], goal_b[..., 3:])

    return pos_weight * translation_error + rot_weight * rotation_error

class UR5eReachEnv(UR5eEnv):
    """
    UR5e reach environment with relative actions in end-effector frame.
    Action space: 6-dim (x, y, z translation + axis-angle rotation) in EE frame
    Observation space: Goal-aware observation with achieved_goal and desired_goal

    Poses are 7-vectors ``[x, y, z, qx, qy, qz, qw]``.

    Supports both dense and sparse reward:
    - Dense: negative distance to goal (continuous feedback)
    - Sparse: +1 for success, 0 otherwise (only at goal achievement)
    """

    def __init__(self, render_mode=None, n_substeps=20, distance_threshold=0.05,
                 rotation_threshold=0.261799, max_pos_action=0.05, max_rot_action=0.261799,
                 target_range=0.15, pos_weight=1.0, rot_weight=0.19098621461, reward_type='dense'):
        """
        Initialize UR5e reach environment.

        Args:
            render_mode (str): Rendering mode ('human', 'rgb_array', or None)
            n_substeps (int): Number of physics substeps per environment step (default: 20 -> control freq 25Hz)
            distance_threshold (float): Position threshold for success in meters (default: 0.05)
            rotation_threshold (float): Rotation threshold for success in radians (default: 0.261799)
            max_pos_action (float): Maximum position action magnitude per step in meters (default: 0.05)
            max_rot_action (float): Maximum rotation action magnitude per step in radians (default: 0.261799 = 15deg)
            target_range (float): Range for goal sampling in meters (default: 0.15)
            pos_weight (float): Weight for position in distance calculation (default: 1.0)
            rot_weight (float): Weight for rotation in distance calculation (default: 0.19098621461 = 0.05 / 0.261799(15deg))
            reward_type (str): Type of reward ('dense' or 'sparse', default: 'dense')

        Raises:
            ValueError: If ``reward_type`` is neither 'dense' nor 'sparse'.
        """
        # Reject bad configuration before the parent builds the simulation,
        # and with a real exception so the check survives `python -O`.
        if reward_type not in ('dense', 'sparse'):
            raise ValueError(f"reward_type must be 'dense' or 'sparse', got {reward_type}")

        # Initialize parent class
        super().__init__(render_mode)

        # Environment parameters
        self.n_substeps = n_substeps
        self.distance_threshold = distance_threshold  # Distance threshold in meters (default 5cm)
        self.rotation_threshold = rotation_threshold  # Rotation threshold in radians (default 15 deg)
        self.max_pos_action = max_pos_action  # Maximum position action magnitude per step (default 5cm)
        self.max_rot_action = max_rot_action  # Maximum rotation action magnitude per step (default 15 deg)
        self.target_range = target_range  # Range for goal sampling
        self.pos_weight = pos_weight  # Weight for position in distance calculation
        self.rot_weight = rot_weight  # Weight for rotation in distance calculation
        self.reward_type = reward_type  # 'dense' or 'sparse'

        # Action space: 6-dim relative actions in EE frame (position + axis-angle)
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(6,), dtype=np.float32
        )

        # Placeholder goal until reset() samples a real one. The orientation
        # is the identity quaternion (scalar last) rather than all zeros: a
        # zero quaternion is not a rotation and would fail the unit-norm check
        # in quat_angle_distance if info or rendering is requested early.
        self.goal = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])  # [x, y, z, qx, qy, qz, qw]

        # Observation space: one unbounded float64 Box per entry, with shapes
        # taken from a sample observation.
        sample_obs = self._get_obs()
        self.observation_space = spaces.Dict({
            key: spaces.Box(
                low=-np.inf, high=np.inf,
                shape=sample_obs[key].shape, dtype=np.float64
            )
            for key in ('observation', 'achieved_goal', 'desired_goal')
        })

    def _get_obs(self):
        """
        Get goal-aware observations.

        Returns:
            dict with
            - 'observation': EE position (3), EE quaternion (4) and the
              scaled EE velocity from `_get_ee_velocity` (6), concatenated.
            - 'achieved_goal': current EE pose (position + quaternion).
            - 'desired_goal': copy of the current goal pose.
        """
        # Get current end-effector pose
        eef = self._physics.bind(self._arm.eef_site)
        ee_pos = eef.xpos.copy()
        ee_quat = mat2quat(eef.xmat.reshape(3, 3))

        # Get end-effector velocity using Jacobian
        ee_vel = self._get_ee_velocity()

        # Achieved goal is current EE pose (position + orientation)
        achieved_goal = np.concatenate([ee_pos, ee_quat])

        # Construct observation vector
        obs = np.concatenate([
            ee_pos,           # EE position (3)
            ee_quat,          # EE orientation (4)
            ee_vel,           # EE velocities (6)
        ])

        # np.concatenate already allocates fresh arrays; only the goal needs
        # an explicit copy so callers cannot mutate self.goal.
        return {
            'observation': obs,
            'achieved_goal': achieved_goal,
            'desired_goal': self.goal.copy(),
        }

    def _get_ee_velocity(self):
        """
        Compute end-effector velocity using Jacobian and joint velocities.

        Returns:
            6-vector ``J @ qvel`` (translational rows first, rotational rows
            second) multiplied by the control period
            ``n_substeps * model.opt.timestep``.
        """
        # Imported here because the module-level mujoco import is disabled.
        import mujoco

        model = self._physics.model
        data = self._physics.data

        # Get site ID for end-effector
        site_id = self._physics.bind(self._arm.eef_site).element_id

        # Fill the translational and rotational Jacobians (3 x nv each) in a
        # single mj_jacSite call instead of one call per half.
        jacp = np.zeros((3, model.nv))
        jacr = np.zeros((3, model.nv))
        mujoco.mj_jacSite(model.ptr, data.ptr, jacp, jacr, site_id)

        # Combine translational and rotational Jacobians
        jac_full = np.vstack([jacp, jacr])  # 6 x nv

        # Keep only the columns belonging to the arm's joints
        arm_joints = self._physics.bind(self._arm.joints)
        jac_controlled = jac_full[:, arm_joints.dofadr]

        # Compute end-effector velocity: v = J * q_dot
        ee_vel = jac_controlled @ arm_joints.qvel.copy()

        # Scale by timestep for consistency with Fetch environment
        dt = self.n_substeps * model.opt.timestep
        ee_vel *= dt

        return ee_vel

    def _pose_errors(self, achieved_goal, desired_goal):
        """
        Return the position and orientation errors between two poses.

        Args:
            achieved_goal: 7-vector pose [pos(3), quat(4)].
            desired_goal: 7-vector pose [pos(3), quat(4)].

        Returns:
            Tuple ``(pos_distance, rot_distance)``: Euclidean distance between
            the positions and quaternion angle between the orientations.
        """
        pos_distance = np.linalg.norm(achieved_goal[:3] - desired_goal[:3])
        rot_distance = quat_angle_distance(achieved_goal[3:], desired_goal[3:])
        return pos_distance, rot_distance

    def _get_info(self, obs=None):
        """
        Get info dictionary.

        Args:
            obs: Optional observation dict as returned by `_get_obs`. When
                omitted, a fresh observation is computed. Passing the one the
                caller already has avoids rebuilding it.

        Returns:
            dict of floats: 'is_success', 'total_distance', 'pos_distance',
            'rot_distance'.
        """
        if obs is None:
            obs = self._get_obs()
        achieved_goal = obs['achieved_goal']
        desired_goal = obs['desired_goal']

        is_success = self._is_success(achieved_goal, desired_goal)
        total_distance = goal_distance_pose(achieved_goal, desired_goal,
                                            self.pos_weight, self.rot_weight)
        pos_distance, rot_distance = self._pose_errors(achieved_goal, desired_goal)

        return {
            'is_success': float(is_success),
            'total_distance': float(total_distance),
            'pos_distance': float(pos_distance),
            'rot_distance': float(rot_distance)
        }

    def _is_success(self, achieved_goal, desired_goal):
        """Check if the goal has been achieved (both position and orientation)."""
        pos_distance, rot_distance = self._pose_errors(achieved_goal, desired_goal)

        pos_success = pos_distance < self.distance_threshold
        rot_success = rot_distance < self.rotation_threshold

        return pos_success and rot_success

    def _sample_goal(self):
        """
        Sample a random SE(3) goal pose within reach.

        Returns:
            7-vector [pos(3), quat(4)]: position drawn uniformly from a cube of
            half-width `target_range` around a fixed reference point,
            orientation drawn uniformly over rotations.
        """
        # Sample goal position around initial EE position with some offset
        initial_ee_pos = np.array([4.92026489e-01, 1.34014449e-01, 4.37990367e-01])
        goal_pos = initial_ee_pos + self.np_random.uniform(
            -self.target_range, self.target_range, size=3
        )

        # Sample random orientation (quaternion)
        # Generate three uniform random numbers in [0, 1]
        u1, u2, u3 = self.np_random.uniform(0, 1, size=3)

        # This formula converts the three random numbers into a uniform random quaternion
        # This correctly weights the angles to ensure uniform sampling over SO(3)
        goal_quat = np.array([
            np.sqrt(1 - u1) * np.sin(2 * np.pi * u2),  # x
            np.sqrt(1 - u1) * np.cos(2 * np.pi * u2),  # y
            np.sqrt(u1) * np.sin(2 * np.pi * u3),      # z
            np.sqrt(u1) * np.cos(2 * np.pi * u3)       # w (scalar part)
        ])
        # Internal sanity check: the construction above is unit-norm by design.
        assert np.isclose(np.linalg.norm(goal_quat), 1.0), "Quaternion not normalized"

        # Combine position and orientation
        return np.concatenate([goal_pos, goal_quat])

    def _set_action(self, action):
        """
        Convert relative action in EE frame to absolute target pose for OSC controller.

        Args:
            action: 6-vector; entries are clipped to [-1, 1]. The first three
                are scaled by `max_pos_action`, the last three (axis-angle)
                by `max_rot_action`.

        Returns:
            7-vector target pose [pos(3), quat(4)] in the world frame.
        """
        # Clip action to valid range
        action = np.clip(action, -1.0, 1.0)

        # Get current EE pose
        eef = self._physics.bind(self._arm.eef_site)
        current_ee_pos = eef.xpos.copy()
        current_ee_mat = eef.xmat.reshape(3, 3).copy()
        current_ee_quat = mat2quat(current_ee_mat)

        # Extract position and rotation actions with separate scaling
        pos_action = action[:3] * self.max_pos_action  # Scale to max position action
        rot_action = action[3:6] * self.max_rot_action  # Scale to max rotation action

        # Apply position action in EE frame:
        # rotate the EE-frame offset into the world frame, then add it.
        target_pos = current_ee_pos + current_ee_mat @ pos_action

        # Apply rotation action in EE frame. Rotations with negligible
        # magnitude are skipped and the current orientation is kept.
        if np.linalg.norm(rot_action) > 1e-6:
            # Convert axis-angle to rotation matrix
            rot_delta_mat = axisangle2mat(rot_action)
            # Right-multiplying applies the increment in the EE frame
            target_quat = mat2quat(current_ee_mat @ rot_delta_mat)
        else:
            target_quat = current_ee_quat

        # Create target pose (7D: position + quaternion)
        return np.concatenate([target_pos, target_quat])

    def step(self, action):
        """
        Execute one step with the given action.

        Args:
            action: 6-vector relative action, see `_set_action`.

        Returns:
            Tuple ``(obs, reward, terminated, truncated, info)``. Both
            `terminated` and `truncated` are always False here.
        """
        # Convert relative action to absolute target pose
        target_pose = self._set_action(action)

        # Apply the same action for n_substeps (like Fetch environment)
        for _ in range(self.n_substeps):
            # Run OSC controller with the target pose directly
            self._controller.run(target_pose)

            # Step physics
            self._physics.step()

        # Render if needed
        if self._render_mode == "human":
            self._render_frame()

        # Build the observation once and reuse it for the info dict
        obs = self._get_obs()
        info = self._get_info(obs)

        # Compute reward
        reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], info)

        # Check termination (reach environments typically don't terminate)
        terminated = False
        truncated = False

        return obs, reward, terminated, truncated, info

    def compute_reward(self, achieved_goal, desired_goal, info):
        """
        Compute reward based on SE(3) pose distance to goal.

        Args:
            achieved_goal: 7-vector pose [pos(3), quat(4)].
            desired_goal: 7-vector pose [pos(3), quat(4)].
            info: Info dict; not read by this implementation.

        Returns:
            Sparse mode: 1.0 on success, 0.0 otherwise (Python float).
            Dense mode: negative weighted pose distance as float32.
        """
        if self.reward_type == 'sparse':
            # Sparse reward: +1 for success, 0 otherwise
            return float(self._is_success(achieved_goal, desired_goal))

        # Dense reward: negative distance
        distance = goal_distance_pose(achieved_goal, desired_goal,
                                      self.pos_weight, self.rot_weight)
        return -distance.astype(np.float32)

    def reset(self, seed=None, options=None):
        """
        Reset the environment.

        Args:
            seed: Forwarded to the grandparent class's ``reset``.
            options: Accepted for API compatibility; not used here.

        Returns:
            Tuple ``(obs, info)`` for the freshly reset state.
        """
        # Call parent's parent reset (skip UR5eEnv's reset)
        super(UR5eEnv, self).reset(seed=seed)

        # Reset physics
        with self._physics.reset_context():
            # Set arm to initial position
            self._physics.bind(self._arm.joints).qpos = [
                0.0, -1.5707, 1.5707, -1.5707, -1.5707, 0.0
            ]

        # Sample new goal
        self.goal = self._sample_goal()

        # Build the observation once and reuse it for the info dict
        obs = self._get_obs()
        info = self._get_info(obs)

        return obs, info

    def render(self):
        """Renders the current frame and returns it as an RGB array if the render mode is set to "rgb_array"."""
        if self._render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """
        Renders the current frame and updates the viewer.

        Returns:
            None in 'human' mode; otherwise whatever the parent's
            `_render_frame` returns.
        """
        # Update goal visualization
        if hasattr(self, '_target') and hasattr(self, 'goal'):
            # Use the target mocap for goal visualization with goal orientation
            self._target.set_mocap_pose(
                self._physics,
                position=self.goal[:3],
                quaternion=self.goal[3:]
            )

        frame = super()._render_frame()
        if self._render_mode == 'human':
            # The viewer was updated as a side effect; nothing to hand back.
            return None
        return frame  # rgb_array