import numpy as np
from pyrep.objects.shape import Shape            # kept – may be useful for debugging
from pyrep.objects.proximity_sensor import ProximitySensor
import traceback

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# --- Pre-implemented skills (use them exactly as provided) --------------------
from skill_code import move, pick, place, rotate, pull


def _fetch_pos(positions_dict, key_aliases):
    """
    Look up a position under the first of several candidate names.

    The aliases are tried in the order given; the value stored under the
    first alias present in ``positions_dict`` is converted to a float
    NumPy array and returned.  If no alias is present, KeyError is raised.
    """
    not_found = object()  # private sentinel so any stored value (even None) is accepted
    match = next(
        (positions_dict[alias] for alias in key_aliases if alias in positions_dict),
        not_found,
    )
    if match is not_found:
        raise KeyError(f"None of the aliases {key_aliases} found in positions_dict.")
    return np.asarray(match, dtype=float)


def _quat_about_z(angle_deg):
    """
    Return the (x, y, z, w) quaternion for a rotation of *angle_deg*
    degrees about the Z axis.
    """
    half_angle = np.deg2rad(angle_deg / 2)
    return np.array([0.0, 0.0, np.sin(half_angle), np.cos(half_angle)])


def _open_bottom_drawer(env, task, bottom_side_pos, bottom_anchor_pos):
    """
    Execute plan steps 1-5 (rotate, approach, grasp, pull) plus a safety lift.

    Each skill returns ``(obs, reward, done)``.  As soon as any skill reports
    ``done`` a message is printed and True is returned so the caller can stop
    the plan; otherwise False is returned after the lift.
    """
    # STEP-1  rotate gripper -> ninety_deg about global Z
    print("[PLAN] Step-1  rotate gripper to 90° about Z")
    # NOTE(review): whether `rotate` treats this quaternion as an absolute
    # target orientation or a relative rotation is not visible here — confirm.
    _, _, done = rotate(env, task, _quat_about_z(90))
    if done:
        print("[PLAN] Finished early during rotate.")
        return True

    # STEP-2  move-to-side of bottom drawer
    print("[PLAN] Step-2  move to bottom_side_pos:", bottom_side_pos)
    _, _, done = move(env, task, bottom_side_pos)
    if done:
        print("[PLAN] Finished early during move-to-side.")
        return True

    # STEP-3  move-to-anchor (handle centre)
    print("[PLAN] Step-3  move to bottom_anchor_pos:", bottom_anchor_pos)
    _, _, done = move(env, task, bottom_anchor_pos)
    if done:
        print("[PLAN] Finished early during move-to-anchor.")
        return True

    # STEP-4  pick-drawer (close gripper on the handle)
    print("[PLAN] Step-4  grasp drawer handle (pick-drawer)")
    _, _, done = pick(env, task,
                      target_pos=bottom_anchor_pos,
                      approach_distance=0.10,
                      approach_axis='z')
    if done:
        print("[PLAN] Finished early while grasping drawer.")
        return True

    # STEP-5  pull the drawer open (approx. 12 cm along -x)
    print("[PLAN] Step-5  pull the drawer open")
    obs, _, done = pull(env, task,
                        pull_distance=0.12,
                        pull_axis='-x')
    if done:
        print("[PLAN] Finished early during pull.")
        return True

    # Lift 15 cm above the post-pull gripper position for safe travel.
    # gripper_pose[:3] is assumed to be the gripper's xyz position — TODO confirm.
    # NOTE(review): nothing visible here releases the drawer handle before this
    # lift; confirm `pull` opens the gripper, otherwise the lift drags the handle.
    print("[PLAN] Safety lift after opening drawer")
    lift_target = obs.gripper_pose[:3] + np.array([0.0, 0.0, 0.15])
    _, _, done = move(env, task, lift_target)
    if done:
        print("[PLAN] Finished early during safety lift.")
        return True
    return False


def _pick_and_place_on_plate(env, task, tomato_name, tomato_world_pos, plate_pos):
    """
    Pick one tomato, lift it 15 cm, and place it on the plate.

    Returns the ``done`` flag of the last skill that ran; the sequence stops
    at the first skill whose ``done`` is truthy.
    """
    print(f"[PLAN] Pick {tomato_name} at {tomato_world_pos}")
    obs, _, done = pick(env, task,
                        target_pos=tomato_world_pos,
                        approach_distance=0.15,
                        approach_axis='z')
    if done:
        return done

    # Lift tomato to avoid collisions en-route
    lift_pos = obs.gripper_pose[:3] + np.array([0.0, 0.0, 0.15])
    _, _, done = move(env, task, lift_pos)
    if done:
        return done

    print(f"[PLAN] Place {tomato_name} on plate at {plate_pos}")
    _, _, done = place(env, task,
                       target_pos=plate_pos,
                       approach_distance=0.15,
                       approach_axis='z')
    return done


def run_skeleton_task():
    """Execute the oracle plan that (1) opens the bottom drawer and
       (2) places both tomatoes on the plate.

    The environment is always shut down in ``finally``; any exception raised
    after setup is printed with its traceback rather than propagated.
    """
    print("===== Starting Skeleton Task =====")
    env, task = setup_environment()

    try:
        # Initial reset & video recording: wrap step/get_observation so every
        # subsequent call is recorded.
        _descriptions, obs = task.reset()
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Retrieve all known object positions (once is sufficient).
        positions = get_object_positions()
        print("[DEBUG] Available position keys:", list(positions.keys()))

        # Drawer-related positions (several naming conventions accepted).
        bottom_side_pos = _fetch_pos(positions,
                                     ['bottom_side_pos', 'side_pos_bottom',
                                      'bottom_side', 'side-pos-bottom'])
        bottom_anchor_pos = _fetch_pos(positions,
                                       ['bottom_anchor_pos', 'anchor_pos_bottom',
                                        'bottom_anchor', 'anchor-pos-bottom'])

        # Plate & tomatoes.
        plate_pos = _fetch_pos(positions, ['plate'])
        tomato1_pos = _fetch_pos(positions, ['tomato1', 'item1'])
        tomato2_pos = _fetch_pos(positions, ['tomato2', 'item2'])

        # STEP-1 .. STEP-5 (+ safety lift): open the bottom drawer.
        if _open_bottom_drawer(env, task, bottom_side_pos, bottom_anchor_pos):
            return

        # STEP-6 & 7  handle tomato1
        if _pick_and_place_on_plate(env, task, 'tomato1', tomato1_pos, plate_pos):
            print("[PLAN] Task ended after tomato1.")
            return

        # STEP-8 & 9  handle tomato2
        # NOTE(review): if `done` also signals task success (verify with the
        # env), a successful run exits here and never prints the [RESULT] line.
        if _pick_and_place_on_plate(env, task, 'tomato2', tomato2_pos, plate_pos):
            print("[PLAN] Task ended after tomato2.")
            return

        print("[RESULT] Oracle plan executed successfully – drawer opened and tomatoes on plate!")

    except Exception:
        # Top-level boundary: report the failure, then fall through to cleanup.
        print("[ERROR] Exception during task execution:")
        traceback.print_exc()
    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the full plan only when executed directly,
    # not when this module is imported.
    run_skeleton_task()