import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment

# Import the predefined skills used by this task (move, pick, place, rotate, pull)
from skill_code import move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def _build_quat_rotated_90_deg(start_quat, axis='x'):
    """Return a quaternion that is rotated +90 deg around the given axis
    relative to `start_quat`.

    Args:
        start_quat (np.ndarray): Current quaternion, xyzw format.
        axis (str): 'x', 'y', or 'z'.

    Returns:
        np.ndarray: Target quaternion, xyzw.
    """
    rot_90 = R.from_euler(axis, 90, degrees=True)
    target_quat = (rot_90 * R.from_quat(start_quat)).as_quat()
    return target_quat


def _deduce_pull_params(anchor_pos, joint_pos):
    """Given anchor & joint positions of a drawer, compute a plausible pull
    direction string and distance.

    The function picks the dominant axis of the vector (anchor – joint) and
    returns (axis_string, distance).

    Returns:
        (str, float)
    """
    vec = anchor_pos - joint_pos
    ax_idx = np.argmax(np.abs(vec))
    axes = ['x', 'y', 'z']
    axis = axes[ax_idx]
    sign = '' if vec[ax_idx] >= 0 else '-'
    axis_string = f'{sign}{axis}'
    distance = np.linalg.norm(vec)
    return axis_string, distance


def run_skeleton_task():
    """Run the drawer-then-tomatoes plan once in a freshly set-up environment.

    Plan order: rotate the gripper, move to the bottom drawer's side position
    and then its anchor position, pick at the anchor, pull the drawer, then
    pick each tomato and place it at the plate position. Every skill call is
    unpacked as ``(obs, reward, done)``; the run returns as soon as any call
    reports `done`. The environment is shut down on every exit path.

    Raises:
        KeyError: If a required entry is missing from the positions dict.
    """

    def _stopped(step_result, notice):
        """Unpack a skill result; print `notice` and return True if done."""
        _obs, _reward, done = step_result
        if done:
            print(notice)
            return True
        return False

    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        _descriptions, first_obs = task.reset()

        # Optional video recording: wrap the task's step / observation hooks.
        init_video_writers(first_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Snapshot the object positions and fail fast on the first absent key
        # so a broken scene is reported before any motion is attempted.
        scene = get_object_positions()
        for k in (
            'bottom_side_pos', 'bottom_anchor_pos', 'bottom_joint_pos',
            'tomato1', 'tomato2', 'plate',
        ):
            if k in scene:
                continue
            raise KeyError(
                f"[run_skeleton_task] Missing key '{k}' in positions dict."
            )

        # --- Step 1: rotate gripper from zero_deg to ninety_deg -------------
        gripper_quat = task.get_observation().gripper_pose[3:7]
        quat_goal = _build_quat_rotated_90_deg(gripper_quat, axis='x')
        print("[Plan-Step-1] rotate(gripper, zero_deg → ninety_deg)")
        if _stopped(
            rotate(env, task, quat_goal),
            "[Task] Terminated during rotate.",
        ):
            return

        # --- Step 2: move to the side of the bottom drawer ------------------
        side_pos = scene['bottom_side_pos']
        print("[Plan-Step-2] move-to-side → bottom_side_pos:", side_pos)
        if _stopped(
            move(env, task, side_pos),
            "[Task] Terminated during move-to-side.",
        ):
            return

        # --- Step 3: move from the side to the anchor -----------------------
        handle_pos = scene['bottom_anchor_pos']
        print("[Plan-Step-3] move-to-anchor → bottom_anchor_pos:", handle_pos)
        if _stopped(
            move(env, task, handle_pos),
            "[Task] Terminated during move-to-anchor.",
        ):
            return

        # --- Step 4: pick-drawer at the same anchor position ----------------
        print("[Plan-Step-4] pick-drawer @ bottom_anchor_pos")
        if _stopped(
            pick(env, task, handle_pos),
            "[Task] Terminated during pick-drawer.",
        ):
            return

        # --- Step 5: pull the drawer open -----------------------------------
        pull_axis, raw_dist = _deduce_pull_params(
            scene['bottom_anchor_pos'], scene['bottom_joint_pos']
        )
        # Keep the pull distance inside [0.05, 0.25] whatever was computed.
        pull_dist = np.clip(raw_dist, 0.05, 0.25)
        print(f"[Plan-Step-5] pull drawer – axis={pull_axis}, dist={pull_dist:.3f}")
        if _stopped(
            pull(env, task, pull_dist, pull_axis=pull_axis),
            "[Task] Terminated during pull.",
        ):
            return

        # Re-read positions: the drawer has moved and the tomatoes may have
        # shifted slightly.
        scene = get_object_positions()

        # --- Steps 6-9: tomato1 and tomato2 go to the plate -----------------
        for tomato in ('tomato1', 'tomato2'):
            tomato_pos = scene[tomato]
            plate_pos = scene['plate']

            print(f"[Plan] pick({tomato}) @ table:", tomato_pos)
            if _stopped(
                pick(env, task, tomato_pos),
                f"[Task] Terminated during pick of {tomato}.",
            ):
                return

            print(f"[Plan] place({tomato}) → plate:", plate_pos)
            if _stopped(
                place(env, task, plate_pos),
                f"[Task] Terminated during place of {tomato}.",
            ):
                return

        print("[Task] Finished all plan steps successfully!")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# Script entry point: run the task only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()