import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape                         # kept from skeleton imports
from pyrep.objects.proximity_sensor import ProximitySensor    # kept from skeleton imports
import traceback

from env import setup_environment, shutdown_environment
from skill_code import *                                      # predefined skills: rotate, move, pick, pull, place
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
#  Helper utilities
# ---------------------------------------------------------------------------

def _safe_position(pos_dict, key):
    """
    Look up ``key`` in ``pos_dict`` and return its value as a float32 array.

    Parameters
    ----------
    pos_dict : mapping
        Mapping from object name to a position (any array-like).
    key : str
        Name of the object whose position is wanted.

    Returns
    -------
    np.ndarray
        The stored value converted with ``dtype=np.float32``. The length is
        not checked here; callers presumably expect an xyz triple.

    Raises
    ------
    KeyError
        If ``key`` is not present in ``pos_dict``.
    """
    if key in pos_dict:
        return np.asarray(pos_dict[key], dtype=np.float32)

    # Fail loudly with the object name so a misspelt key is easy to spot.
    raise KeyError(f"[Task] '{key}' not found in object_positions()")


def _infer_pull_params(anchor_xyz, joint_xyz):
    """
    Infer the axis and distance of a straight-line pull.

    The offset ``anchor_xyz - joint_xyz`` is computed and the component with
    the largest magnitude selects the pull axis. The sign of that component
    is folded into the axis name, so the returned distance is never negative.

    Parameters
    ----------
    anchor_xyz : array-like
        Anchor (handle) position; a list, tuple or ndarray of coordinates.
    joint_xyz : array-like
        Joint position, same layout as ``anchor_xyz``.

    Returns
    -------
    tuple[str, float]
        ``(axis_name, distance)`` where ``axis_name`` is one of
        ``'x'``, ``'y'``, ``'z'``, ``'-x'``, ``'-y'``, ``'-z'`` and
        ``distance`` is the magnitude of the offset along that axis.
    """
    # Coerce without forcing a dtype: ndarrays pass through untouched, while
    # plain lists/tuples (which do not support '-') become arrays.
    offset = np.asarray(anchor_xyz) - np.asarray(joint_xyz)

    axis_idx = int(np.argmax(np.abs(offset)))   # 0:x, 1:y, 2:z
    axis_name = 'xyz'[axis_idx]
    signed_distance = float(offset[axis_idx])

    # Keep the distance positive; a negative offset becomes '-x' / '-y' / '-z'.
    if signed_distance < 0.0:
        return f'-{axis_name}', -signed_distance

    return axis_name, signed_distance


def _rotate_target_quat_90_deg_from_current(task):
    """
    Build the target quaternion for a +90° turn about Z from the current pose.

    The current orientation is read from elements 3..6 of
    ``task.get_observation().gripper_pose`` and a 90° rotation about 'z' is
    composed on the left of it (i.e. applied after the current rotation).

    NOTE(review): assumes ``gripper_pose[3:7]`` is a quaternion in scipy's
    (x, y, z, w) order — confirm against the observation format.

    Returns
    -------
    np.ndarray
        Quaternion of the composed rotation, as given by ``Rotation.as_quat``.
    """
    observation = task.get_observation()
    current = R.from_quat(observation.gripper_pose[3:7])
    quarter_turn_z = R.from_euler('z', 90, degrees=True)

    # Left-multiplying applies the quarter turn on top of the current rotation.
    return (quarter_turn_z * current).as_quat()


# ---------------------------------------------------------------------------
#  Main routine – executes oracle plan from the specification
# ---------------------------------------------------------------------------

def _run_plan_step(action, terminated_msg, failed_msg, critical=True):
    """
    Execute one plan step and tell the caller whether to keep going.

    Parameters
    ----------
    action : callable
        Zero-argument callable that performs the step and returns the
        ``(obs, reward, done)`` triple produced by the skill call.
    terminated_msg : str
        Printed when the step reports ``done``.
    failed_msg : str
        Printed when the step raises an exception.
    critical : bool
        If True (default) an exception aborts the plan. If False the failure
        is reported and the plan continues with the next step.

    Returns
    -------
    bool
        True if the plan should continue, False if it must stop.
    """
    try:
        _, _, done = action()
    except Exception:
        print(failed_msg)
        # Always keep the traceback, even for best-effort steps, so the cause
        # of a failure is not lost.
        traceback.print_exc()
        return not critical

    if done:
        print(terminated_msg)
        return False
    return True


def run_skeleton_task():
    """
    Run the fixed plan: rotate the gripper, open the drawer, bin the rubbish.

    The environment is created with ``setup_environment()`` and is always
    shut down via ``shutdown_environment(env)``, whether the plan finishes,
    stops early, or raises. Steps are executed in order; the plan stops at
    the first step that raises (except the best-effort handle release) or
    that reports ``done`` before the final step.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # ---------------------------------------------------------------
        #  Environment reset & video initialisation
        # ---------------------------------------------------------------
        _, obs = task.reset()

        init_video_writers(obs)                                       # start video capture
        task.step = recording_step(task.step)                         # wrap for video
        task.get_observation = recording_get_observation(task.get_observation)

        # ---------------------------------------------------------------
        #  Gather all object positions required for the plan
        # ---------------------------------------------------------------
        positions = get_object_positions()

        bottom_side_pos   = _safe_position(positions, 'bottom_side_pos')
        bottom_anchor_pos = _safe_position(positions, 'bottom_anchor_pos')
        bottom_joint_pos  = _safe_position(positions, 'bottom_joint_pos')
        rubbish_pos       = _safe_position(positions, 'rubbish')
        bin_pos           = _safe_position(positions, 'bin')

        # ---------------------------------------------------------------
        #  Step definitions – each returns the skill's (obs, reward, done)
        # ---------------------------------------------------------------

        def rotate_gripper():
            # The target is computed at execution time from the live pose.
            target_quat = _rotate_target_quat_90_deg_from_current(task)
            return rotate(
                env, task,
                target_quat=target_quat,
                max_steps=120, threshold=0.05, timeout=10.0
            )

        def pull_drawer():
            axis, distance = _infer_pull_params(bottom_anchor_pos, bottom_joint_pos)
            print(f"[Task] Pulling drawer along '{axis}' by {distance:.3f} m")
            return pull(
                env, task,
                pull_distance=distance,
                pull_axis=axis,
                max_steps=150, threshold=0.01, timeout=10.0
            )

        def release_handle():
            # Place at the current gripper position so the handle is let go
            # in-place (presumably this opens the gripper – see skill_code).
            cur_pos = task.get_observation().gripper_pose[:3]
            return place(
                env, task,
                target_pos=cur_pos,
                approach_distance=0.00,
                max_steps=1, threshold=0.001,
                approach_axis='z', timeout=2.0
            )

        # (action, message if terminated early, message if it raised, critical?)
        plan = (
            # Step-1 : rotate gripper by 90° about Z
            (rotate_gripper,
             "[Task] Terminated unexpectedly after rotate.",
             "[Error] rotate step failed!",
             True),
            # Step-2 : move to drawer side position
            (lambda: move(
                env, task,
                target_pos=bottom_side_pos,
                max_steps=150, threshold=0.01, timeout=10.0
             ),
             "[Task] Terminated unexpectedly after move-to-side.",
             "[Error] move-to-side step failed!",
             True),
            # Step-3 : move to drawer anchor (handle) position
            (lambda: move(
                env, task,
                target_pos=bottom_anchor_pos,
                max_steps=150, threshold=0.01, timeout=10.0
             ),
             "[Task] Terminated unexpectedly after move-to-anchor.",
             "[Error] move-to-anchor step failed!",
             True),
            # Step-4 : pick the drawer handle (grasp)
            (lambda: pick(
                env, task,
                target_pos=bottom_anchor_pos,
                approach_distance=0.10,
                max_steps=120, threshold=0.01,
                approach_axis='z', timeout=10.0
             ),
             "[Task] Terminated unexpectedly after pick-drawer.",
             "[Error] pick-drawer (handle) failed!",
             True),
            # Step-5 : pull the drawer open
            (pull_drawer,
             "[Task] Terminated unexpectedly after pull.",
             "[Error] pull step failed!",
             True),
            # Optional : release drawer handle – not critical, plan continues
            (release_handle,
             "[Task] Terminated unexpectedly after releasing handle.",
             "[Warning] Drawer release failed – continuing.",
             False),
            # Step-6 : pick rubbish from table
            (lambda: pick(
                env, task,
                target_pos=rubbish_pos,
                approach_distance=0.15,
                max_steps=150, threshold=0.01,
                approach_axis='z', timeout=10.0
             ),
             "[Task] Terminated unexpectedly after picking rubbish.",
             "[Error] pick rubbish failed!",
             True),
        )

        for action, terminated_msg, failed_msg, critical in plan:
            if not _run_plan_step(action, terminated_msg, failed_msg, critical):
                return

        # Step-7 : place rubbish into bin. Handled separately because here
        # ``done`` means success and the reward is reported.
        try:
            _, reward, done = place(
                env, task,
                target_pos=bin_pos,
                approach_distance=0.15,
                max_steps=150, threshold=0.01,
                approach_axis='z', timeout=10.0
            )
        except Exception:
            print("[Error] place rubbish failed!")
            traceback.print_exc()
            return
        else:
            if done:
                print("[Task] Task completed successfully! Reward:", reward)
            else:
                print("[Task] Plan executed but task not flagged as done.")

    finally:
        # Always release the simulator, whatever happened above.
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# Run the plan only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_skeleton_task()