# run_oracle_task.py
#
# This script instantiates the RLBench environment, then executes—step by step—
# the oracle plan described in the Specification.  All low‑level behaviour is
# delegated to the predefined skills imported from `skill_code.py`.
#
# NOTE:  No new skills are introduced; we only invoke the existing ones:
#        rotate, move, pull, pick, place
#
# Plan (mirrors the Specification):
#   1  rotate(gripper, zero_deg → ninety_deg)
#   2  move‑to‑side  (nowhere‑pos → side‑pos‑bottom)
#   3  move‑to‑anchor(side‑pos‑bottom → anchor‑pos‑bottom)
#   4  pick‑drawer   (grasp bottom drawer handle)
#   5  pull          (open the drawer)
#      (un‑numbered) move to bottom anchor position + 0.10 in z before picking
#   6  pick tomato1 from table
#   7  place tomato1 on plate
#   8  pick tomato2 from table
#   9  place tomato2 on plate

import numpy as np
from scipy.spatial.transform import Rotation as R

from env import setup_environment, shutdown_environment
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions

# === Import all predefined skill functions ===
from skill_code import rotate, move, pull, pick, place


def quat_from_single_axis(axis: str, deg: float) -> np.ndarray:
    """Build the quaternion for a rotation of `deg` degrees about a single axis.

    Args:
        axis: Axis specifier forwarded unchanged to ``Rotation.from_euler``
            (for example ``'z'``).
        deg: Rotation angle, interpreted in degrees.

    Returns:
        The quaternion produced by ``Rotation.as_quat`` in [x, y, z, w] order.
    """
    single_axis_rotation = R.from_euler(axis, deg, degrees=True)
    return single_axis_rotation.as_quat()


def fetch_pos(positions: dict, key: str):
    """Look up `key` in `positions` and return its value as a float32 array.

    Args:
        positions: Mapping from object name to a position value.
        key: Name of the object whose position is wanted.

    Returns:
        ``positions[key]`` converted with ``np.asarray(..., dtype=np.float32)``.

    Raises:
        KeyError: If `key` is not present in `positions`; the message names
            the missing object.
    """
    if key in positions:
        return np.asarray(positions[key], dtype=np.float32)
    raise KeyError(f"[run_oracle_task] Object '{key}' not found in positions dictionary.")


def run_oracle_task():
    """Set up the environment and run the oracle plan skill by skill.

    The function resets the task, wraps ``task.step`` and
    ``task.get_observation`` with the recording helpers, looks up the object
    positions it needs, and then calls the imported skills (rotate, move,
    pick, pull, place) in a fixed order.

    If any skill other than the final place reports ``done``, a RuntimeError
    is raised to abandon the rest of the plan.  Every ``Exception`` raised
    inside the plan is caught and printed rather than propagated, and the
    environment is shut down in all cases.

    Returns:
        None.
    """

    def _abort_if(terminated, message):
        # An early `done` means the remaining plan steps cannot be executed.
        if terminated:
            raise RuntimeError(message)

    print("==========  RUN ORACLE TASK  ==========")

    # --- Environment initialisation -------------------------------------
    # Done outside the try-block: if setup itself fails there is no `env`
    # for the `finally` clause to shut down.
    env, task = setup_environment()
    try:
        _, obs = task.reset()

        # Video recording: route every step / observation call through the
        # recording wrappers.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # --- Object positions -------------------------------------------
        # Looked up in a fixed order so that the first missing key is the
        # one reported.
        positions = get_object_positions()
        (bottom_side_pos,
         bottom_anchor_pos,
         tomato1_pos,
         tomato2_pos,
         plate_pos) = (
            fetch_pos(positions, name)
            for name in ('bottom_side_pos', 'bottom_anchor_pos',
                         'tomato1', 'tomato2', 'plate')
        )

        # --- Oracle plan ------------------------------------------------

        # Step 1: rotate the gripper to 90 degrees about z.
        print("\n---  Step 1 : rotate gripper ---")
        target_quat = quat_from_single_axis('z', 90.0)
        obs, reward, done = rotate(env, task, target_quat)
        _abort_if(done, "[Task] Terminated during rotate.")

        # Step 2: go to the side position of the bottom drawer.
        print("\n---  Step 2 : move to side‑pos‑bottom ---")
        obs, reward, done = move(env, task, bottom_side_pos)
        _abort_if(done, "[Task] Terminated during move‑to‑side.")

        # Step 3: go to the anchor position of the bottom drawer.
        print("\n---  Step 3 : move to anchor‑pos‑bottom ---")
        obs, reward, done = move(env, task, bottom_anchor_pos)
        _abort_if(done, "[Task] Terminated during move‑to‑anchor.")

        # Step 4: grasp at the anchor position (approach along -z).
        print("\n---  Step 4 : pick drawer handle ---")
        obs, reward, done = pick(
            env, task, bottom_anchor_pos,
            approach_distance=0.10, approach_axis='-z',
        )
        _abort_if(done, "[Task] Terminated during pick‑drawer.")

        # Step 5: pull by 0.20 along the x axis.
        print("\n---  Step 5 : pull drawer ---")
        obs, reward, done = pull(env, task, pull_distance=0.20, pull_axis='x')
        _abort_if(done, "[Task] Terminated during pull.")

        # Un-numbered retreat: move to 0.10 above the anchor position in z
        # before going for the tomatoes.
        # NOTE(review): the target is derived from the pre-pull anchor
        # position and no release skill is called after the pull — confirm
        # this is what `move` is expected to handle.
        safe_above_drawer = bottom_anchor_pos + np.array([0, 0, 0.10])
        obs, reward, done = move(env, task, safe_above_drawer)
        _abort_if(done, "[Task] Terminated while retreating.")

        # Step 6: pick the first tomato.
        print("\n---  Step 6 : pick tomato1 ---")
        obs, reward, done = pick(
            env, task, tomato1_pos,
            approach_distance=0.12, approach_axis='z',
        )
        _abort_if(done, "[Task] Terminated during pick tomato1.")

        # Step 7: place the first tomato at the plate position.
        print("\n---  Step 7 : place tomato1 ---")
        obs, reward, done = place(
            env, task, plate_pos,
            approach_distance=0.12, approach_axis='z',
        )
        _abort_if(done, "[Task] Terminated during place tomato1.")

        # Step 8: pick the second tomato.
        print("\n---  Step 8 : pick tomato2 ---")
        obs, reward, done = pick(
            env, task, tomato2_pos,
            approach_distance=0.12, approach_axis='z',
        )
        _abort_if(done, "[Task] Terminated during pick tomato2.")

        # Step 9: place the second tomato.  `done` here is the expected
        # outcome, so it is reported instead of raised.
        print("\n---  Step 9 : place tomato2 ---")
        obs, reward, done = place(
            env, task, plate_pos,
            approach_distance=0.12, approach_axis='z',
        )
        if not done:
            print("[Task] Plan executed.  Task may continue accumulating reward.")
        else:
            print("[Task] Finished after final place.  Reward:", reward)

    except Exception as err:
        # Failures are reported on stdout only; nothing is re-raised.
        print("[run_oracle_task] ERROR:", err)
    finally:
        # Always release the environment, whether the plan succeeded or not.
        shutdown_environment(env)
        print("==========  END ORACLE TASK  ==========")


# Script entry point: run the oracle plan only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_oracle_task()
