# Pos_Ori.py
# -*- coding: utf-8 -*-
"""
Collection of pose and coordinate transformation tool functions.

Features:
- Euler angles (deg) <-> Quaternion conversion
- Convert contact point world coordinates measured initially to board local coordinates
- Convert board local contact point coordinates to current world coordinates
- Single arm: Calculate new contact point world coordinates and gripper Euler angles based on board initial/current pose and initial gripper pose
- High-level wrapper: Get target_position and target_orientation directly from "initial world coordinate contact point and board initial/current pose"

Conventions:
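- Euler angles are in degrees; the default rotation order is 'xyz'. If your simulation uses a different order (e.g. 'zyx'), pass that same order consistently to every function.
- Quaternion layout is not uniform: euler_degrees_to_quaternion returns [w, x, y, z], while quaternion_to_euler_degrees expects SciPy's [x, y, z, w] (quaternion_to_euler_degrees2 accepts either via input_format).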
"""

from typing import Tuple
import numpy as np
from scipy.spatial.transform import Rotation as R


from typing import Tuple
import numpy as np
from scipy.spatial.transform import Rotation as R
def euler_degrees_to_quaternion(roll_deg: float, pitch_deg: float, yaw_deg: float, order: str = 'xyz') -> np.ndarray:
    """
    Convert Euler angles given in degrees to a scalar-first quaternion.

    Args:
        roll_deg, pitch_deg, yaw_deg: Euler angles in degrees. They are handed
            to SciPy in this order for the axis sequence named by ``order``.
        order: SciPy Euler axis sequence, default 'xyz'.

    Returns:
        np.ndarray(4,) laid out as [w, x, y, z].

    Note:
        SciPy natively uses the scalar-last layout [x, y, z, w]; this helper
        moves the scalar part to the front. ``quaternion_to_euler_degrees``
        expects the scalar-last layout, so reorder before round-tripping.
    """
    angles_rad = [np.radians(roll_deg), np.radians(pitch_deg), np.radians(yaw_deg)]
    x, y, z, w = R.from_euler(order, angles_rad).as_quat()
    # SciPy [x, y, z, w] -> scalar-first [w, x, y, z]
    return np.array([w, x, y, z])




def quaternion_to_euler_degrees(quaternion: np.ndarray, order: str = 'xyz') -> np.ndarray:
    """
    Convert a scalar-last quaternion to Euler angles in degrees.

    Args:
        quaternion: Quaternion laid out as [x, y, z, w] (SciPy convention).
            This is NOT the [w, x, y, z] layout returned by
            ``euler_degrees_to_quaternion``.
        order: SciPy Euler axis sequence, default 'xyz'.

    Returns:
        np.ndarray(3,) of Euler angles in degrees.
    """
    q_xyzw = np.asarray(quaternion)
    return R.from_quat(q_xyzw).as_euler(order, degrees=True)


def quaternion_to_euler_degrees2(quaternion: np.ndarray, order: str = 'xyz', input_format: str = 'xyzw') -> np.ndarray:
    """
    Convert a quaternion in either component layout to Euler angles (degrees).

    Args:
        quaternion: A single quaternion of shape (4,) or a stack of shape (N, 4),
            with components ordered as given by ``input_format``.
        order: SciPy Euler axis sequence, default 'xyz'.
        input_format: 'xyzw' (scalar-last, SciPy native) or 'wxyz'
            (scalar-first, as returned by ``euler_degrees_to_quaternion``).
            Matching is case-insensitive.

    Returns:
        Euler angles in degrees: shape (3,) for a single quaternion,
        shape (N, 3) for a stack.

    Raises:
        ValueError: If ``input_format`` is neither 'xyzw' nor 'wxyz'.
    """
    q = np.asarray(quaternion, dtype=float)
    fmt = str(input_format).lower()
    if fmt == 'wxyz':
        # Move the scalar component to the end, as SciPy expects [x, y, z, w].
        # Indexing the last axis works for a single quaternion and for a stack.
        q = q[..., [1, 2, 3, 0]]
    elif fmt != 'xyzw':
        # Reject unknown layouts instead of silently treating them as 'xyzw',
        # which would return plausible-looking but wrong angles.
        raise ValueError(f"input_format must be 'xyzw' or 'wxyz', received: {input_format}")

    return R.from_quat(q).as_euler(order, degrees=True)


def normalize_euler_degrees(euler_deg: np.ndarray) -> np.ndarray:
    """
    Wrap Euler angles (degrees) into the half-open interval [-180, 180).

    Each component is shifted by a multiple of 360 degrees, so the angle it
    represents is unchanged; exactly +180 wraps to -180.

    Args:
        euler_deg: Euler angles in degrees, any array-like (typically shape (3,)).

    Returns:
        Float array with the same shape as the input.
    """
    shifted = np.asarray(euler_deg, dtype=float) + 180.0
    return np.mod(shifted, 360.0) - 180.0


def world_contact_to_local(
    contact_world: np.ndarray,
    initial_board_position: np.ndarray,
    initial_board_euler: np.ndarray,
    order: str = 'xyz'
) -> np.ndarray:
    """
    Express a contact point measured in world coordinates (at the initial
    board pose) in the board's own local frame.

    Intended to be computed once at start-up and then reused, assuming the
    contact point is rigidly attached to the board.

    Computation: local = R_B0^T @ (p_world - t0)

    Args:
        contact_world: Contact point in world coordinates at the initial state, shape (3,).
        initial_board_position: Initial board position t0, shape (3,).
        initial_board_euler: Initial board Euler angles (degrees), shape (3,).
        order: SciPy Euler axis sequence.

    Returns:
        Contact point in the board's local frame, shape (3,).
    """
    board_to_world = R.from_euler(order, np.radians(initial_board_euler)).as_matrix()
    offset_world = np.asarray(contact_world) - np.asarray(initial_board_position)
    # A rotation matrix's transpose is its inverse: this maps world -> board.
    return board_to_world.T @ offset_world


def local_to_world_contact(
    local: np.ndarray,
    current_board_position: np.ndarray,
    current_board_euler: np.ndarray,
    order: str = 'xyz'
) -> np.ndarray:
    """
    Map a contact point from the board's local frame to current world coordinates.

    Computation: p_world = R_Bc @ local + t_c

    Args:
        local: Contact point in board-local coordinates, shape (3,).
        current_board_position: Current board position t_c, shape (3,).
        current_board_euler: Current board Euler angles (degrees), shape (3,).
        order: SciPy Euler axis sequence.

    Returns:
        Contact point in world coordinates, shape (3,).
    """
    board_rotation = R.from_euler(order, np.radians(current_board_euler)).as_matrix()
    rotated = board_rotation @ np.asarray(local)
    return rotated + np.asarray(current_board_position)


def calculate_contact_point_and_orientation(
    local: np.ndarray,
    initial_gripper_orientation_euler: np.ndarray,  # Initial gripper Euler angles [roll, pitch, yaw], unit: degrees
    current_board_position: np.ndarray,
    current_board_euler: np.ndarray,               # Current board Euler angles [roll, pitch, yaw], unit: degrees
    initial_board_euler: "np.ndarray | None" = None,  # Initial board Euler angles (degrees); None means [0, 0, 0]
    order: str = 'xyz',
    normalize_angles: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Single-arm target: move a board-attached contact point and its gripper
    orientation along with the board.

    Given the contact point in board-local coordinates and the gripper Euler
    angles recorded at the initial board pose, compute:
    - new_contact_world: contact point in current world coordinates
    - new_gripper_euler: gripper target Euler angles (degrees)

    Calculation:
    1) Position:    p_world = R_Bc @ local + t_c
    2) Orientation: R_Gc = (R_Bc * R_B0^{-1}) * R_G0
       where R_B0 / R_Bc are the initial / current board rotations and R_G0
       is the initial gripper rotation. The gripper keeps its orientation
       relative to the board.

    Args:
        local: Contact point in board-local coordinates, shape (3,).
        initial_gripper_orientation_euler: Initial gripper Euler angles (degrees), shape (3,).
        current_board_position: Current board position, shape (3,).
        current_board_euler: Current board Euler angles (degrees), shape (3,).
        initial_board_euler: Initial board Euler angles (degrees), shape (3,).
            Defaults to [0, 0, 0] when omitted or None.
        order: SciPy Euler axis sequence, default 'xyz'.
        normalize_angles: Whether to wrap the output Euler angles to [-180, 180).

    Returns:
        new_contact_world: np.ndarray(3,)
        new_gripper_euler: np.ndarray(3,), degrees
    """
    if initial_board_euler is None:
        # A fresh array per call; a module-level np.array default would be a
        # single mutable object shared by every call.
        initial_board_euler = np.zeros(3)

    # Initial and current board rotations
    R_B0 = R.from_euler(order, np.radians(initial_board_euler))
    R_Bc = R.from_euler(order, np.radians(current_board_euler))

    # Rotation the board underwent since the initial pose (world frame)
    R_delta = R_Bc * R_B0.inv()

    # New contact point in world coordinates
    new_contact_world = R_Bc.as_matrix() @ np.asarray(local) + np.asarray(current_board_position)

    # Carry the gripper along with the board's relative rotation
    R_G0 = R.from_euler(order, np.radians(initial_gripper_orientation_euler))
    R_Gc = R_delta * R_G0
    new_gripper_euler = R_Gc.as_euler(order, degrees=True)

    if normalize_angles:
        new_gripper_euler = normalize_euler_degrees(new_gripper_euler)

    return new_contact_world, new_gripper_euler

# Recommended high-level entry point for computing controller targets.
def calculate_target_from_world_contact(
    contact_world_init: np.ndarray,                  # Contact point in world coordinates, measured with cube at the initial board pose
    initial_board_position: np.ndarray,              # Initial board position
    initial_board_euler: np.ndarray,                 # Initial board Euler angles (degrees)
    initial_gripper_orientation_euler: np.ndarray,   # Initial gripper Euler angles (degrees)
    current_board_position: np.ndarray,              # Current board position
    current_board_euler: np.ndarray,                 # Current board Euler angles (degrees)
    order: str = 'xyz',
    normalize_angles: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the controller target for a contact point that moves with the board.

    Pipeline:
        1) Map the initially measured world contact point into the board frame.
        2) Map it back to world coordinates using the current board pose
           -> target_position (used as the gripper target position).
        3) Rotate the initial gripper orientation by the board's relative
           rotation R_Bc * R_B0^{-1} -> target_orientation (degrees).

    Args:
        contact_world_init: Contact point in world coordinates at the initial pose, shape (3,).
        initial_board_position: Initial board position, shape (3,).
        initial_board_euler: Initial board Euler angles (degrees), shape (3,).
        initial_gripper_orientation_euler: Initial gripper Euler angles (degrees), shape (3,).
        current_board_position: Current board position, shape (3,).
        current_board_euler: Current board Euler angles (degrees), shape (3,).
        order: SciPy Euler axis sequence, default 'xyz'.
        normalize_angles: Whether to wrap the output Euler angles.

    Returns:
        (target_position, target_orientation): world position, shape (3,),
        and gripper Euler angles in degrees, shape (3,).
    """
    # Step 1: express the initial world contact point in board coordinates.
    contact_local = world_contact_to_local(
        contact_world_init,
        initial_board_position,
        initial_board_euler,
        order=order,
    )

    # Steps 2 and 3: current world position plus the board-relative gripper rotation.
    return calculate_contact_point_and_orientation(
        local=contact_local,
        initial_gripper_orientation_euler=initial_gripper_orientation_euler,
        current_board_position=current_board_position,
        current_board_euler=current_board_euler,
        initial_board_euler=initial_board_euler,
        order=order,
        normalize_angles=normalize_angles,
    )




def apply_incremental_euler(
    current_euler_deg: np.ndarray,
    delta_euler_deg: np.ndarray,
    order: str = 'xyz',
    frame: str = 'world',           # 'world': rotate about the fixed world axes (left multiply); 'local': rotate about the body's own axes (right multiply);
                                    # 'tilt' / 'table': axis-angle tilt relative to the table normal (see docstring)
    normalize_angles: bool = True
) -> np.ndarray:
    """
    Apply a rotation increment (degrees) on top of the current Euler angles (degrees).

    Args:
        current_euler_deg: Current Euler angles (degrees) [roll, pitch, yaw].
        delta_euler_deg:   Increment; its meaning depends on ``frame``.
        order:             SciPy Euler axis sequence used for both input and output, default 'xyz'.
        frame:             Case-insensitive, one of:
            'world' (aliases 'global', 'extrinsic', 'outer'):
                delta_euler_deg are Euler angles of an increment about the
                fixed world axes: R_new = R_inc * R_cur.
            'local' (aliases 'body', 'intrinsic', 'inner'):
                delta_euler_deg are Euler angles of an increment about the
                body's current axes: R_new = R_cur * R_inc.
            'tilt' (alias 'table'): see "Tilt mode" below.
        normalize_angles:  Whether to wrap the result to [-180, 180), default True.

    Returns:
        new_euler_deg: np.ndarray(3,), Euler angles (degrees) after the increment.

    Raises:
        ValueError: If ``frame`` is not recognised, or if fewer than two
            components are given in tilt mode.

    Tilt mode (frame='tilt' or 'table'):
        Intended for a gripper that starts perpendicular to the table and is
        then tilted in a chosen horizontal direction. A single axis-angle
        rotation avoids issues such as the left and right hands ending up
        non-parallel, which direct Euler-angle addition can cause.

        - The table normal is assumed to be world +Z.
        - delta_euler_deg[0]: tilt angle (degrees).
        - delta_euler_deg[1]: azimuth of the tilt direction d in the table (XY)
          plane (degrees); 0 deg is world +X, counter-clockwise positive.
        - delta_euler_deg[2:]: ignored.

        The increment is a world-frame rotation by the tilt angle about
        axis = d x Z, left-multiplied onto the current pose. For a positive
        tilt angle this carries world +Z away from d (towards -d), and
        world -Z towards +d. For example, a tool axis pointing straight down
        at the table swings towards d.
    """
    R_cur = R.from_euler(order, np.asarray(current_euler_deg, dtype=float), degrees=True)

    f = str(frame).lower()
    world_aliases = ('world', 'global', 'extrinsic', 'outer')
    local_aliases = ('local', 'body', 'intrinsic', 'inner')

    if f in world_aliases or f in local_aliases:
        # Classic Euler-increment mode (kept for compatibility)
        R_inc = R.from_euler(order, np.asarray(delta_euler_deg, dtype=float), degrees=True)

        if f in world_aliases:
            # Left multiply: R_cur acts first, then R_inc about the fixed world axes.
            R_new = R_inc * R_cur
        else:
            # Right multiply: R_inc acts about the body's current axes.
            R_new = R_cur * R_inc

    elif f in ('tilt', 'table'):
        delta = np.asarray(delta_euler_deg, dtype=float).reshape(-1)
        if delta.size < 2:
            raise ValueError("In frame='tilt' mode, delta_euler_deg needs at least 2 components: [tilt_angle_deg, inplane_dir_deg]")

        tilt_angle_deg = float(delta[0])   # Tilt angle
        inplane_dir_deg = float(delta[1])  # Azimuth of the tilt direction on the table

        # Table normal (assumed to be world +Z)
        n = np.array([0.0, 0.0, 1.0], dtype=float)

        # Horizontal tilt direction d in the XY plane. [cos, sin, 0] is already
        # unit length, so no degenerate-direction fallback is needed.
        theta = np.deg2rad(inplane_dir_deg)
        d = np.array([np.cos(theta), np.sin(theta), 0.0], dtype=float)
        d /= np.linalg.norm(d)

        # Rotation axis d x n. d lies in the XY plane and n is +Z, so they are
        # never collinear and the axis is always well defined. A positive
        # rotation about this axis moves +Z towards -d (and -Z towards +d).
        axis = np.cross(d, n)
        axis /= np.linalg.norm(axis)

        # Axis-angle increment applied in the world frame (left multiply)
        R_inc = R.from_rotvec(np.deg2rad(tilt_angle_deg) * axis)
        R_new = R_inc * R_cur

    else:
        raise ValueError(
            f"frame must be 'world' / 'local' or 'tilt'/'table', received: {frame}"
        )

    new_euler_deg = R_new.as_euler(order, degrees=True)

    if normalize_angles:
        new_euler_deg = normalize_euler_degrees(new_euler_deg)

    return new_euler_deg


def _quat_angle_error_deg(q_cur, q_goal):
    """
    Return the smallest rotation angle (degrees) that takes q_cur to q_goal.

    Args:
        q_cur, q_goal: Quaternions in SciPy scalar-last layout [x, y, z, w].
            Each may be a single quaternion of shape (4,) or a stack of
            shape (N, 4); pairing follows SciPy's Rotation composition rules.

    Returns:
        Angle in degrees within [0, 180]: a scalar for single quaternions,
        or an array of shape (N,) for stacks.
    """
    R_cur = R.from_quat(np.asarray(q_cur, dtype=float))
    R_goal = R.from_quat(np.asarray(q_goal, dtype=float))
    R_err = R_goal * R_cur.inv()
    # Take the norm over the last axis only. For a stack this yields one angle
    # per rotation instead of a single norm over the whole (N, 3) array.
    angle_rad = np.linalg.norm(R_err.as_rotvec(), axis=-1)
    return np.degrees(angle_rad)




def get_xy_move_from_angle(angle_deg: float, distance: float) -> Tuple[float, float]:
    """
    Split a planar move of length ``distance`` along heading ``angle_deg``
    into its X and Y components.

    Standard math convention: 0 degrees points along +X and angles grow
    counter-clockwise.

    Args:
        angle_deg: Heading in degrees.
        distance: Length of the move (the hypotenuse).

    Returns:
        (delta_x, delta_y) = (distance * cos(angle), distance * sin(angle)).
    """
    theta = np.deg2rad(angle_deg)
    return (distance * np.cos(theta), distance * np.sin(theta))

def get_y_move_from_angle(angle_deg: float, x_move: float) -> float:
    """
    Compute the Y displacement that keeps a move on a line at ``angle_deg``
    from the +X axis, given the X displacement: y_move = tan(angle) * x_move.

    Args:
        angle_deg: Angle of the line from the +X axis (degrees).
        x_move: Distance moved along the x axis (meters).

    Returns:
        y_move: Distance to move along the y axis (meters).

    Raises:
        ValueError: If the angle is (numerically) an odd multiple of 90 degrees.
            Such a line is parallel to the y axis and no finite y_move exists.
            Without this check, tan() returns a huge value (about 1.6e16 at
            90 degrees) instead of failing.
    """
    angle_rad = np.deg2rad(angle_deg)  # Degrees to radians
    if np.any(np.abs(np.cos(angle_rad)) < 1e-12):
        raise ValueError(f"angle_deg={angle_deg} is parallel to the y axis; tan(angle) is undefined")
    return np.tan(angle_rad) * x_move

# Public API of this module. The underscore-prefixed helper is listed
# explicitly so that `from ... import *` still exports it.
__all__ = [
    # Euler angle / quaternion conversion
    'euler_degrees_to_quaternion',
    'quaternion_to_euler_degrees',
    'quaternion_to_euler_degrees2',
    'normalize_euler_degrees',
    # Board-attached contact point and gripper pose
    'world_contact_to_local',
    'local_to_world_contact',
    'calculate_contact_point_and_orientation',
    'calculate_target_from_world_contact',
    # Incremental orientation, orientation error, planar moves
    'apply_incremental_euler',
    '_quat_angle_error_deg',
    'get_xy_move_from_angle',
    'get_y_move_from_angle',
]
