import traceback

import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape        # retained – sometimes helpful for debugging
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *                     # gives rotate, move, pull, pick, place

from video import (init_video_writers,
                   recording_step,
                   recording_get_observation)

from object_positions import get_object_positions


def _safe_lookup(name: str, table: dict):
    """Utility – raise a clear error if the required key is missing."""
    if name not in table or table[name] is None:
        raise RuntimeError(f"[Task] Cannot find position for object '{name}'.")
    return table[name]


def run_skeleton_task():
    """Run the oracle plan: open the bottom drawer, then put both tomatoes on the plate.

    The environment is always shut down on exit. This includes early
    termination (any skill reporting ``done``) and errors raised while running
    the plan. Such errors are reported with their full traceback but are not
    re-raised.
    """
    print("===== Starting Skeleton Task =====")

    # -------------------------------------------------
    #  1) Environment setup
    # -------------------------------------------------
    env, task = setup_environment()
    try:
        # Reset the task; the task descriptions are not used by this plan.
        _descriptions, obs = task.reset()

        # Optional: video recording helpers (wrap step/get_observation).
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # -------------------------------------------------
        #  2) Query all relevant object positions from helper
        # -------------------------------------------------
        # Positions are read once, before any motion, and are not refreshed
        # after the drawer is pulled.
        positions = get_object_positions()          # dict(name -> np.ndarray)

        bottom_side_pos   = _safe_lookup('bottom_side_pos',   positions)
        bottom_anchor_pos = _safe_lookup('bottom_anchor_pos', positions)
        tomato1_pos       = _safe_lookup('tomato1',           positions)
        tomato2_pos       = _safe_lookup('tomato2',           positions)
        plate_pos         = _safe_lookup('plate',             positions)

        # -------------------------------------------------
        #  3) Execute the oracle plan (Specification)
        # -------------------------------------------------
        # STEP-1  rotate the gripper by +90° about its Z axis.
        # The turn is relative to the current orientation, so the spec's
        # "0° -> 90°" holds only if the gripper starts at 0°.
        obs = task.get_observation()
        current_quat = obs.gripper_pose[3:7]
        # Right-multiplying composes the +90° turn in the gripper's own
        # (body) frame. scipy quaternions are scalar-last (x, y, z, w);
        # presumably gripper_pose uses the same order — TODO confirm.
        target_quat = (R.from_quat(current_quat) *
                       R.from_euler('z', 90, degrees=True)).as_quat()

        # [Frozen Code Start]
        obs, reward, done = rotate(env, task, target_quat)
        # [Frozen Code End]

        if done:
            print("[Task] Finished early at Step-1.")
            return

        # Steps 2-9 as (label, action) pairs. Each action returns
        # (obs, reward, done), and a truthy ``done`` ends the plan early.
        plan = [
            # STEP-2  move-to-side  -> bottom_side_pos
            ("Step-2", lambda: move(env, task, bottom_side_pos)),
            # STEP-3  move-to-anchor -> bottom_anchor_pos
            ("Step-3", lambda: move(env, task, bottom_anchor_pos)),
            # STEP-4  pick-drawer (close the gripper on the drawer handle).
            # The drawer-specific skill is approximated with the generic
            # pick(), using a short 5 cm approach along the Z axis.
            ("Step-4", lambda: pick(env,
                                    task,
                                    target_pos=bottom_anchor_pos,
                                    approach_distance=0.05,
                                    approach_axis='z')),
            # STEP-5  pull the drawer 20 cm along +X
            ("Step-5", lambda: pull(env,
                                    task,
                                    pull_distance=0.20,
                                    pull_axis='x')),
            # STEP-6  pick tomato1 from the table
            ("Step-6", lambda: pick(env,
                                    task,
                                    target_pos=tomato1_pos,
                                    approach_distance=0.15,
                                    approach_axis='z')),
            # STEP-7  place tomato1 on the plate
            ("Step-7", lambda: place(env,
                                     task,
                                     target_pos=plate_pos,
                                     approach_distance=0.15,
                                     approach_axis='z')),
            # STEP-8  pick tomato2
            ("Step-8", lambda: pick(env,
                                    task,
                                    target_pos=tomato2_pos,
                                    approach_distance=0.15,
                                    approach_axis='z')),
            # STEP-9  place tomato2 on the plate
            ("Step-9", lambda: place(env,
                                     task,
                                     target_pos=plate_pos,
                                     approach_distance=0.15,
                                     approach_axis='z')),
        ]

        for label, action in plan:
            _obs, _reward, done = action()
            if done:
                print(f"[Task] Finished early at {label}.")
                return

        # -------------------------------------------------
        print("[Task] Oracle plan executed successfully!")

    except Exception as e:
        # Top-level boundary: report the error together with its traceback
        # (str(e) alone is often empty or uninformative). Execution then falls
        # through, so the environment is still shut down.
        print("[Task] Exception during execution:", e)
        traceback.print_exc()

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# Script entry point: run the task only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    run_skeleton_task()