# run_oracle_plan.py
#
# Description:
#  - Executes the oracle plan given in the specification for the “combined‑domain”
#    task (open a drawer, then dispose of the rubbish).
#  - Uses ONLY the already‑implemented skills imported from skill_code.
#  - Relies on the object_positions helper to obtain world‑space positions.

import numpy as np
from math import pi
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from skill_code import rotate, move, pull, pick, place
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# --------------------------------------------------------------------------- #
#                               Helper Functions                              #
# --------------------------------------------------------------------------- #
def fetch_position(positions, *possible_keys):
    """Look up the first of several alternative keys in ``positions``.

    The candidate keys are tried in the order given; the value stored under
    the first key that is present in ``positions`` is returned.

    Raises:
        KeyError: If none of ``possible_keys`` is present in ``positions``
            (including when no keys are supplied at all).
    """
    # A private sentinel lets a legitimately stored ``None`` value be
    # distinguished from "no candidate key matched".
    not_found = object()
    hits = (positions[name] for name in possible_keys if name in positions)
    value = next(hits, not_found)
    if value is not_found:
        raise KeyError(f"None of the keys {possible_keys} found in object positions.")
    return value


def run_oracle_plan():
    """Run the "open drawer, then dispose of rubbish" oracle plan once.

    Sets up the environment, resets the task, wraps ``task.step`` and
    ``task.get_observation`` with the recording helpers, looks up the object
    positions the plan needs and then calls the skills in this order:

        1. rotate  - gripper turned 90 degrees about Z
        2. move    - to the drawer side position
        3. move    - to the drawer anchor (handle) position
        4. pick    - grasp the drawer handle
        5. pull    - open the drawer
        6. pick    - grasp the rubbish
        7. place   - drop the rubbish into the bin

    Execution stops as soon as a skill reports ``done``.  The environment is
    shut down on every exit path (normal completion, early return or error).

    Raises:
        KeyError: If a position required by the plan is missing from the
            mapping returned by ``get_object_positions()``.
        Exception: Any error raised while executing the plan is printed and
            then re-raised unchanged.
    """
    print("===== [Oracle Plan] Start =====")
    env, task = setup_environment()          # RL-Bench environment
    try:
        # ------------------------------------------------------------------- #
        #                    Environment / Recording Initialisation           #
        # ------------------------------------------------------------------- #
        _descriptions, obs = task.reset()
        init_video_writers(obs)

        # Wrap step / get_observation so that every frame is recorded
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------- #
        #                     Retrieve Relevant Object Poses                  #
        # ------------------------------------------------------------------- #
        # Only positions the plan actually uses are looked up, so a missing
        # optional entry cannot abort the run with a KeyError.
        positions = get_object_positions()

        # Drawer related positions
        bottom_side_pos = fetch_position(positions, 'bottom_side_pos', 'side-pos-bottom')
        bottom_anchor_pos = fetch_position(positions, 'bottom_anchor_pos', 'anchor-pos-bottom')

        # Objects for disposal
        rubbish_pos = fetch_position(positions, 'rubbish')
        bin_pos = fetch_position(positions, 'bin')

        # ------------------------------------------------------------------- #
        #                             Plan Execution                          #
        # ------------------------------------------------------------------- #
        # Target orientation for step 1: the initial gripper orientation
        # rotated a further 90 degrees about Z (target = rot_90_z * current).
        # NOTE(review): gripper_pose[3:7] is assumed to be an xyzw quaternion,
        # matching SciPy's scalar-last convention - confirm against the env.
        current_rotation = R.from_quat(obs.gripper_pose[3:7])
        quarter_turn_z = R.from_euler('z', 90, degrees=True)
        target_quat = (quarter_turn_z * current_rotation).as_quat()  # xyzw

        # Steps 1-6 share one shape: announce, run the skill, stop early if
        # the episode finished.  Each entry is
        # (start message, early-termination message, skill call).
        leading_steps = (
            ("[Plan] Step 1 – rotate gripper 90° about Z",
             "[Plan] Episode ended after rotate.",
             lambda: rotate(env, task, target_quat)),
            ("[Plan] Step 2 – move to bottom_side_pos",
             "[Plan] Episode ended after move‑to‑side.",
             lambda: move(env, task, bottom_side_pos)),
            ("[Plan] Step 3 – move to bottom_anchor_pos",
             "[Plan] Episode ended after move‑to‑anchor.",
             lambda: move(env, task, bottom_anchor_pos)),
            # The generic pick skill is used to close the gripper on the handle.
            ("[Plan] Step 4 – grasp the drawer handle (pick‑drawer)",
             "[Plan] Episode ended after pick‑drawer.",
             lambda: pick(env, task, bottom_anchor_pos,
                          approach_distance=0.08,
                          approach_axis='y')),
            # Pull straight along the X axis by 0.20 m (tune if necessary).
            ("[Plan] Step 5 – pull the drawer open",
             "[Plan] Episode ended after pull.",
             lambda: pull(env, task,
                          pull_distance=0.20,
                          pull_axis='x')),
            ("[Plan] Step 6 – pick rubbish on the table",
             "[Plan] Episode ended while picking rubbish.",
             lambda: pick(env, task, rubbish_pos,
                          approach_distance=0.10,
                          approach_axis='z')),
        )

        for start_msg, ended_msg, run_skill in leading_steps:
            print(start_msg)
            _obs, _reward, done = run_skill()
            if done:
                print(ended_msg)
                return

        # STEP 7 – place rubbish into the bin (final step: report outcome)
        print("[Plan] Step 7 – place rubbish into bin")
        _obs, reward, done = place(env, task, bin_pos,
                                   approach_distance=0.15,
                                   approach_axis='z')
        if done:
            print("[Plan] Task finished after placing rubbish. Reward:", reward)
        else:
            print("[Plan] Task completed (done is False but final step executed).")

    except Exception as e:
        # Surface the failure in the log, then let the caller see it too.
        print("Exception during oracle execution:", str(e))
        raise
    finally:
        # Runs on normal completion, early return and error alike.
        shutdown_environment(env)
        print("===== [Oracle Plan] End =====")


# Script entry point: execute the oracle plan only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_oracle_plan()
