import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# Import all predefined skills
from skill_code import move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def _safe_get_position(name, cache):
    """Return the world position of a named scene object as a float array.

    Lookup order:
        1) ``cache`` -- the dictionary returned by ``get_object_positions()``
           (skipped when ``cache`` is ``None``);
        2) a direct query through the PyRep ``Shape`` API.

    Args:
        name: Scene name of the object.
        cache: Mapping from object name to position, or ``None``.

    Returns:
        ``np.ndarray`` with dtype ``float``.

    Raises:
        RuntimeError: if the position cannot be obtained from either source.
            The underlying exception is chained as ``__cause__`` so the real
            failure is visible in tracebacks.
    """
    if cache is not None and name in cache:
        return np.asarray(cache[name], dtype=float)
    try:
        # Conversion stays inside the try so callers that skip unresolvable
        # objects (by catching RuntimeError) also skip malformed positions.
        return np.asarray(Shape(name).get_position(), dtype=float)
    except Exception as e:
        raise RuntimeError(f"[Error] Cannot obtain position for object '{name}': {e}") from e


def _locate_tomatoes(positions_cache):
    """Return ``(position, name)`` pairs for the first two tomatoes found.

    Candidate names are tried in order -- 'tomato1', 'tomato2', then the
    alternative naming scheme 'item1', 'item2'. Names that cannot be resolved
    are skipped, and lookup stops as soon as two have been found.

    Raises:
        RuntimeError: if fewer than two candidates resolve.
    """
    found = []
    for name in ('tomato1', 'tomato2', 'item1', 'item2'):
        try:
            found.append((_safe_get_position(name, positions_cache), name))
        except RuntimeError:
            continue
        if len(found) == 2:
            break
    if len(found) < 2:
        raise RuntimeError("[Error] Could not locate at least two tomatoes/items in the scene.")
    return found


def _run_steps(env, task, steps):
    """Print and execute plan steps in order, stopping once a skill reports done.

    Args:
        env, task: Passed as the first two positional arguments of every skill.
        steps: Iterable of ``(message, skill, kwargs)`` tuples. Each skill is
            called as ``skill(env, task, **kwargs)`` and must return
            ``(obs, reward, done)``.

    Returns:
        ``(obs, reward, done)`` from the last skill executed, or
        ``(None, 0.0, False)`` if ``steps`` is empty.
    """
    obs, reward, done = None, 0.0, False
    for message, skill, kwargs in steps:
        print(message)
        obs, reward, done = skill(env, task, **kwargs)
        if done:
            break
    return obs, reward, done


def _open_drawer(env, task, side_pos, anchor_pos):
    """Plan steps 1-4: approach the drawer, rotate the gripper, reach the handle, pull.

    Returns ``(obs, reward, done)`` from the last skill executed.
    """
    # 90 degrees about z; scipy returns the quaternion in (x, y, z, w) order.
    target_quat = R.from_euler('z', 90.0, degrees=True).as_quat()
    return _run_steps(env, task, (
        ("[Plan] Step 1 – move to drawer side position", move, {'target_pos': side_pos}),
        ("[Plan] Step 2 – rotate gripper 90° about z", rotate, {'target_quat': target_quat}),
        ("[Plan] Step 3 – move to drawer anchor position", move, {'target_pos': anchor_pos}),
        # Open the drawer by 0.20 m along the 'x' pull axis.
        ("[Plan] Step 4 – pull drawer", pull, {'pull_distance': 0.20, 'pull_axis': 'x'}),
    ))


def _transfer_tomatoes_to_plate(env, task, tomatoes, plate_pos):
    """Plan steps 5-12: move to, pick, carry and place each tomato on the plate.

    Args:
        tomatoes: ``(position, name)`` pairs, as returned by ``_locate_tomatoes``.
        plate_pos: Target position used for both the carry move and the place.

    Returns:
        ``(obs, reward, done)`` from the last skill executed. Execution stops
        as soon as any skill reports done.
    """
    obs, reward, done = None, 0.0, False
    for idx, (tomato_pos, tomato_name) in enumerate(tomatoes, start=1):
        # Refresh the position in case the object moved since start-up.
        # If the fresh lookup fails, keep the start-up position.
        try:
            tomato_pos = _safe_get_position(tomato_name, get_object_positions())
        except RuntimeError:
            pass
        obs, reward, done = _run_steps(env, task, (
            (f"[Plan] Tomato {idx}: move above {tomato_name}", move, {'target_pos': tomato_pos}),
            (f"[Plan] Tomato {idx}: pick {tomato_name}", pick, {'target_pos': tomato_pos}),
            (f"[Plan] Tomato {idx}: move to plate", move, {'target_pos': plate_pos}),
            (f"[Plan] Tomato {idx}: place on plate", place, {'target_pos': plate_pos}),
        ))
        if done:
            break
    return obs, reward, done


def run_skeleton_task():
    """Run the oracle plan: open the bottom drawer, then move two tomatoes onto the plate.

    Sets up the environment and wraps the task for video recording. Resolves
    every position the plan needs before moving the robot, then executes the
    plan step by step, stopping as soon as a skill reports ``done``. The
    environment is always shut down, even if a step raises.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        descriptions, obs = task.reset()

        # Optional: start video recording
        init_video_writers(obs)
        task.step = recording_step(task.step)          # wrap step for recording
        task.get_observation = recording_get_observation(task.get_observation)

        # Resolve all required positions up front so a missing object fails
        # before the robot moves.
        positions_cache = get_object_positions()
        side_pos_bottom = _safe_get_position('bottom_side_pos', positions_cache)
        anchor_pos_bottom = _safe_get_position('bottom_anchor_pos', positions_cache)
        plate_pos = _safe_get_position('plate', positions_cache)
        tomatoes = _locate_tomatoes(positions_cache)

        obs, reward, done = _open_drawer(env, task, side_pos_bottom, anchor_pos_bottom)
        if not done:
            obs, reward, done = _transfer_tomatoes_to_plate(env, task, tomatoes, plate_pos)

        if done:
            # `done` may also be raised by the very last step, so do not claim
            # that the plan stopped early.
            print(f"[Task] Task reported done. Final reward: {reward}")
        else:
            print("[Task] Plan execution finished. Task may continue to evaluate success.")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the full plan once when executed directly.
    run_skeleton_task()
