# run_oracle_plan.py
#
# This script instantiates the RLBench environment, records a video (optional),
# and executes the exact oracle plan given in the specification using only the
# predefined skill functions (rotate, move, pick, pull, place).  The code is
# written defensively: if none of the candidate names for an object is found in
# the position dictionary, an informative KeyError listing the available keys is
# raised so that the problem can be fixed easily.

import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape          # required by the skeleton
from pyrep.objects.proximity_sensor import ProximitySensor   # required by the skeleton

from env import setup_environment, shutdown_environment

# Import *only* the predefined skills – no new skills are implemented here
from skill_code import rotate, move, pick, pull, place

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# -----------------------------------------------------------------------------


def _fetch_pos(positions, *candidate_names):
    """
    Convenience helper: returns the first position that exists in `positions`
    from the list of candidate names.  If none of them are present we raise
    a KeyError with a detailed message.
    """
    for name in candidate_names:
        if name in positions:
            return positions[name]
    raise KeyError(
        f"None of the candidate names {candidate_names} were found in the "
        f"positions dictionary.  Available keys: {list(positions.keys())}"
    )


def run_oracle_plan():
    """Run the fixed nine-step oracle plan in a freshly set-up environment.

    The plan calls the predefined skills in this order: rotate, move (side
    position), move (anchor position), pick (anchor position), pull, then
    pick/place for tomato 1 and pick/place for tomato 2, both placed at the
    plate position.  If any of the first eight skills reports `done`, an
    "[Aborted]" line is printed and the function returns early.  The
    environment is shut down in every case, including when an exception
    propagates.
    """
    print("\n=================  START ORACLE PLAN  =================")

    env, task = setup_environment()
    try:
        _, initial_obs = task.reset()

        # Wrap the task's step / observation callables with the recording
        # helpers so subsequent skill calls go through them.
        init_video_writers(initial_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Resolve every required position up front: a missing key raises
        # before any skill has been executed.
        positions = get_object_positions()
        side_pos = _fetch_pos(positions, 'middle_side_pos', 'side_pos_middle')
        anchor_pos = _fetch_pos(positions, 'middle_anchor_pos', 'anchor_pos_middle')
        first_tomato = _fetch_pos(positions, 'tomato1')
        second_tomato = _fetch_pos(positions, 'tomato2')
        plate_pos = _fetch_pos(positions, 'plate')

        # Steps 1-8 as (banner, deferred skill call).  Each call is wrapped in
        # a lambda so nothing runs until its banner has been printed.
        abortable_steps = [
            (
                # "ninety_deg" is taken to mean +90 degrees about the z-axis;
                # as_quat() yields the quaternion in xyzw order.
                "\n[STEP 1] rotate gripper to +90° about Z",
                lambda: rotate(
                    env, task,
                    R.from_euler('z', 90, degrees=True).as_quat(),
                ),
            ),
            (
                "\n[STEP 2] move to middle drawer – side position",
                lambda: move(env, task, target_pos=side_pos),
            ),
            (
                "\n[STEP 3] move to middle drawer – anchor/handle position",
                lambda: move(env, task, target_pos=anchor_pos),
            ),
            (
                "\n[STEP 4] pick / grasp drawer handle",
                lambda: pick(env, task, target_pos=anchor_pos),
            ),
            (
                # Pull 0.15 m along +x; flip the sign if the drawer opens the
                # other way.
                "\n[STEP 5] pull drawer by 0.15 m along +X",
                lambda: pull(env, task, pull_distance=0.15, pull_axis='x'),
            ),
            (
                "\n[STEP 6] pick tomato 1",
                lambda: pick(env, task, target_pos=first_tomato),
            ),
            (
                "\n[STEP 7] place tomato 1 on plate",
                lambda: place(env, task, target_pos=plate_pos),
            ),
            (
                "\n[STEP 8] pick tomato 2",
                lambda: pick(env, task, target_pos=second_tomato),
            ),
        ]

        for step_no, (banner, run_skill) in enumerate(abortable_steps, start=1):
            print(banner)
            _, _, finished = run_skill()
            if finished:
                print(f"[Aborted] task signalled done after STEP {step_no}")
                return

        # Final step: here a `done` flag means success rather than an abort.
        print("\n[STEP 9] place tomato 2 on plate")
        _, _, finished = place(env, task, target_pos=plate_pos)
        if not finished:
            print("[Info] Oracle plan executed – task did not set done flag.")
        else:
            print("[Success] Oracle plan completed and task signalled done.")

    finally:
        shutdown_environment(env)
        print("==================  END ORACLE PLAN  ==================\n")


# -----------------------------------------------------------------------------


# Execute the plan only when this file is run directly as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    run_oracle_plan()
