import numpy as np
from scipy.spatial.transform import Rotation as R

# RLBench / simulation helpers
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# Low‑level skills supplied in skill_code
from skill_code import rotate, move, pick, pull, place


def _safe_numpy(pos_like):
    """Utility: convert RLBench position/tuple/list to numpy array."""
    return np.asarray(pos_like, dtype=np.float32)


def _determine_pull_axis(anchor, joint):
    """
    Decide which world axis corresponds to the pull direction of a drawer
    by looking at the vector from its joint (hinge) toward the handle (anchor).
    Returns (axis_string, distance) where axis_string is one of
    'x', '-x', 'y', '-y', 'z', '-z'.
    """
    diff = anchor - joint
    idx = int(np.argmax(np.abs(diff)))
    axis_names = ['x', 'y', 'z']
    sign = np.sign(diff[idx])
    axis = axis_names[idx]
    axis_string = axis if sign >= 0 else f'-{axis}'
    return axis_string, float(np.linalg.norm(diff))


def _transfer_tomato(env, task, tomato_position, plate_position, step_idx):
    """
    Pick one tomato and place it slightly above the plate.

    Parameters
    ----------
    env, task :
        Handles returned by ``setup_environment``; passed straight to the skills.
    tomato_position : np.ndarray
        World position of the tomato to grasp.
    plate_position : np.ndarray
        World position of the plate; a small +z offset is added for placing.
    step_idx : int
        Plan step number logged for the pick; the place is logged as
        ``step_idx + 1``.

    Returns
    -------
    str or None
        ``'pick'`` or ``'place'`` if the skill at that stage reported
        ``done``, otherwise ``None``. Reporting the stage lets the caller tell
        a premature episode end (during the pick) from the episode ending on
        the placement itself.
    """
    print(f"\n[PLAN] Step {step_idx}/9 – Pick tomato")
    _, _, done = pick(env, task,
                      target_pos=tomato_position,
                      approach_distance=0.12,
                      approach_axis='-z')
    if done:
        return 'pick'

    print(f"[PLAN] Step {step_idx + 1}/9 – Place tomato on plate")
    offset = np.array([0.0, 0.0, 0.02])  # small height offset above the plate
    _, _, done = place(env, task,
                       target_pos=plate_position + offset,
                       approach_distance=0.12,
                       approach_axis='-z')
    return 'place' if done else None


def run_skeleton_task():
    """
    Main entry for the combined drawer-opening and tomato-placing task.

    Plan (9 steps):
      1. rotate the gripper 90 deg about Z,
      2-3. move to the bottom drawer's side position, then its anchor (handle),
      4. grasp the handle, 5. pull the drawer open,
      6-7. pick tomato1 and place it on the plate,
      8-9. pick tomato2 and place it on the plate.

    Any step whose skill reports ``done`` before the final placement aborts
    the run. The simulator is always shut down in ``finally``.
    """
    print("===== Starting Skeleton Task =====")

    # ---- Environment initialisation ------------------------------------
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        # Optional video capture: wrap step/get_observation so every call is
        # recorded.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---- Query object positions (dict: name -> [x, y, z]) ----------
        positions = get_object_positions()
        # Drawer we will manipulate: the BOTTOM drawer.
        side_pos = _safe_numpy(positions['bottom_side_pos'])
        anchor_pos = _safe_numpy(positions['bottom_anchor_pos'])
        joint_pos = _safe_numpy(positions['bottom_joint_pos'])

        tomato1_pos = _safe_numpy(positions['tomato1'])
        tomato2_pos = _safe_numpy(positions['tomato2'])
        plate_pos = _safe_numpy(positions['plate'])

        # ---- Step 1: rotate (zero -> 90 deg about Z) -------------------
        target_quat = R.from_euler('xyz', [0, 0, np.pi / 2]).as_quat()
        print("\n[PLAN] Step 1/9  – Rotate gripper 90 deg about Z")
        _, _, done = rotate(env, task, target_quat)
        if done:
            print("[TASK] Finished unexpectedly after rotate.")
            return

        # ---- Step 2: move to the drawer side position ------------------
        print("\n[PLAN] Step 2/9  – Move to drawer side position")
        _, _, done = move(env, task, side_pos)
        if done:
            print("[TASK] Finished unexpectedly after move‑to‑side.")
            return

        # ---- Step 3: move to the drawer anchor (handle) ----------------
        print("\n[PLAN] Step 3/9  – Move to drawer anchor position")
        _, _, done = move(env, task, anchor_pos)
        if done:
            print("[TASK] Finished unexpectedly after move‑to‑anchor.")
            return

        # ---- Step 4: pick the drawer handle ----------------------------
        print("\n[PLAN] Step 4/9  – Grasp drawer handle")
        _, _, done = pick(env, task, target_pos=anchor_pos,
                          approach_distance=0.10,
                          approach_axis='-z')
        if done:
            print("[TASK] Finished unexpectedly after pick‑drawer.")
            return

        # ---- Step 5: pull the drawer open ------------------------------
        print("\n[PLAN] Step 5/9  – Pull the drawer")
        pull_axis, pull_distance = _determine_pull_axis(anchor_pos, joint_pos)
        # Add some margin to ensure it opens sufficiently.
        _, _, done = pull(env, task,
                          pull_distance=pull_distance + 0.05,
                          pull_axis=pull_axis)
        if done:
            print("[TASK] Finished unexpectedly after pull.")
            return

        # ---- Step 6 & 7: tomato1 (any episode end here is premature) ---
        if _transfer_tomato(env, task, tomato1_pos, plate_pos, 6) is not None:
            print("[TASK] Finished unexpectedly during tomato1 transfer.")
            return

        # ---- Step 8 & 9: tomato2 ---------------------------------------
        # The episode ending on the *final* placement is the expected
        # outcome, not an error. Only an end during the pick is premature.
        # NOTE(review): assumes the skills' `done` mirrors the task's
        # terminate/success flag; confirm against skill_code.
        if _transfer_tomato(env, task, tomato2_pos, plate_pos, 8) == 'pick':
            print("[TASK] Finished unexpectedly during tomato2 transfer.")
            return

        print("\n===== TASK COMPLETED SUCCESSFULLY! =====")

    finally:
        # Ensure simulator shutdown even on errors or early returns.
        shutdown_environment(env)
        print("===== Environment Shutdown =====")


if __name__ == "__main__":
    # Script entry point: run the full drawer + tomato sequence once.
    run_skeleton_task()