# run_skeleton_task.py
# ------------------------------------------------------------
# Executable task script that drives the robot through the
# oracle plan described in the specification.
# It relies exclusively on the predefined skills provided in
# `skill_code.py` and never re-implements low-level primitives.
# ------------------------------------------------------------

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor          # kept for completeness

from env import setup_environment, shutdown_environment

# === Import all predefined skills (DO NOT redefine them) ==========
from skill_code import rotate, move, pick, pull, place

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

from scipy.spatial.transform import Rotation as R


def _safe_position_lookup(name, cached):
    """Resolve a world-frame 3-D position for `name`.

    Priority:
        1) cached   – dictionary returned by `get_object_positions`
        2) direct   – query via `pyrep.objects.Shape`

    Returns
    -------
    np.ndarray, shape (3,)
    """
    if name in cached:
        return np.asarray(cached[name], dtype=float)

    try:
        return np.asarray(Shape(name).get_position(), dtype=float)
    except Exception:
        raise RuntimeError(f"[Task] Cannot resolve position for object '{name}'.")


def run_skeleton_task():
    """Execute the oracle plan (open bottom drawer, move two tomatoes to the
    plate) using only the predefined skills from `skill_code`.

    The environment is always shut down in `finally`, including when a step
    reports ``done`` early or an exception propagates. Exceptions are
    printed and then re-raised.
    """
    print("===== Starting Skeleton Task =====")

    # ------------------------------------------------------------------
    #  Environment initialisation
    # ------------------------------------------------------------------
    env, task = setup_environment()
    try:
        # Reset to the initial state (task descriptions are not used here)
        _descriptions, obs = task.reset()

        # Optional: start recording by wrapping step / get_observation
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------
        #  Fetch positions of all relevant objects / way-points
        # ------------------------------------------------------------------
        cached_positions = get_object_positions()

        # Drawer related points (bottom drawer)
        bottom_side_pos = _safe_position_lookup('bottom_side_pos', cached_positions)
        bottom_anchor_pos = _safe_position_lookup('bottom_anchor_pos', cached_positions)

        # Tomatoes (item1 -> tomato1, item2 -> tomato2) and the serving plate
        tomato1_pos = _safe_position_lookup('item1', cached_positions)
        tomato2_pos = _safe_position_lookup('item2', cached_positions)
        plate_pos = _safe_position_lookup('plate', cached_positions)

        # ------------------------------------------------------------------
        #  Angle names -> quaternion (scipy `as_quat` returns x, y, z, w).
        #  'zero_deg' is the plan's starting orientation; only 'ninety_deg'
        #  is used as a rotation target.
        # ------------------------------------------------------------------
        angle_to_quat = {
            'zero_deg': R.from_euler('z', 0, degrees=True).as_quat(),
            'ninety_deg': R.from_euler('z', 90, degrees=True).as_quat(),
        }

        # ------------------------------------------------------------------
        #  Oracle plan. Every entry except the last is
        #  (banner, skill, skill kwargs, message printed if env reports done);
        #  a `done` on any of these steps ends the run early.
        #
        #  NOTE(review): approach_axis='z' is passed through to pick/place;
        #  the direction convention is defined in skill_code — confirm.
        #  NOTE(review): the gripper is not explicitly released after the
        #  pull; this assumes pick() opens the gripper before grasping the
        #  next object — confirm in skill_code.pick.
        # ------------------------------------------------------------------
        plan = [
            ("\n[Plan] Step-1  : rotate → ninety_deg",
             rotate, {'target_quat': angle_to_quat['ninety_deg']},
             "[Task] Terminated after rotate."),
            ("\n[Plan] Step-2  : move → bottom_side_pos",
             move, {'target_pos': bottom_side_pos},
             "[Task] Terminated after move-to-side."),
            ("\n[Plan] Step-3  : move → bottom_anchor_pos",
             move, {'target_pos': bottom_anchor_pos},
             "[Task] Terminated after move-to-anchor."),
            ("\n[Plan] Step-4  : pick (drawer handle)",
             pick, {'target_pos': bottom_anchor_pos, 'approach_axis': 'z'},
             "[Task] Terminated after pick-drawer."),
            ("\n[Plan] Step-5  : pull drawer 0.10 m along +x",
             pull, {'pull_distance': 0.10, 'pull_axis': 'x'},
             "[Task] Terminated after pull."),
            ("\n[Plan] Step-6  : pick tomato1",
             pick, {'target_pos': tomato1_pos, 'approach_axis': 'z'},
             "[Task] Terminated after picking tomato1."),
            ("\n[Plan] Step-7  : place tomato1 onto plate",
             place, {'target_pos': plate_pos, 'approach_axis': 'z'},
             "[Task] Terminated after placing tomato1."),
            ("\n[Plan] Step-8  : pick tomato2",
             pick, {'target_pos': tomato2_pos, 'approach_axis': 'z'},
             "[Task] Terminated after picking tomato2."),
        ]

        for banner, skill, skill_kwargs, early_exit_msg in plan:
            print(banner)
            obs, reward, done = skill(env, task, **skill_kwargs)
            if done:
                print(early_exit_msg)
                return  # `finally` still shuts the environment down

        # Final step: here `done` is the expected outcome, not an early exit.
        print("\n[Plan] Step-9  : place tomato2 onto plate")
        obs, reward, done = place(
            env, task,
            target_pos=plate_pos,
            approach_axis='z'
        )
        if done:
            print(f"[Task] Completed! Final reward: {reward}")
        else:
            print("[Task] Plan finished, but environment reports done=False. "
                  "Task might need additional checking.")

    except Exception as exc:
        print(f"[Task] Exception encountered: {exc}")
        raise
    finally:
        # Always perform environment shutdown
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the oracle plan only when executed as a script, not on import.
    run_skeleton_task()