# run_oracle_task.py
#
# Oracle plan (exactly matching the Specification):
#
#   1) rotate – turn gripper from zero_deg to ninety_deg
#   2) move   – gripper to side-pos-bottom
#   3) move   – gripper to anchor-pos-bottom
#   4) pick   – close gripper on the anchor handle (grasp the drawer)
#   5) pull   – pull the (unlocked) bottom drawer open
#   6) pick   – grasp tomato 1
#   7) place  – drop tomato 1 onto plate
#   8) pick   – grasp tomato 2
#   9) place  – drop tomato 2 onto plate

import traceback
import numpy as np
from scipy.spatial.transform import Rotation as R

# -----------------------------------------------------------
#  Imports required by the provided skeleton – DO NOT REMOVE
# -----------------------------------------------------------
from env import setup_environment, shutdown_environment
from skill_code import rotate, move, pull, pick, place
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions
from pyrep.objects.shape import Shape                      # kept from skeleton
from pyrep.objects.proximity_sensor import ProximitySensor  # kept from skeleton
# -----------------------------------------------------------


# -----------------------------------------------------------
#  Helper utilities
# -----------------------------------------------------------
def quaternion_from_euler(roll, pitch, yaw):
    """Convert Euler angles (radians) to a quaternion in ``[x, y, z, w]`` order.

    The angles are interpreted with SciPy's lowercase ``'xyz'`` sequence,
    i.e. extrinsic rotations about the fixed x, then y, then z axes.

    Args:
        roll: rotation about x, in radians.
        pitch: rotation about y, in radians.
        yaw: rotation about z, in radians.

    Returns:
        A length-4 NumPy array, scalar-last (SciPy's default ``as_quat`` layout).
    """
    rotation = R.from_euler('xyz', (roll, pitch, yaw))
    return rotation.as_quat()


def safe_call(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)``; on failure print diagnostics and re-raise.

    The exception is never swallowed. It is printed, with its traceback, at
    the point of failure and then propagated, so an enclosing ``finally`` (for
    example an environment shutdown) still runs.

    Args:
        fn: any callable, including objects without ``__name__`` such as
            ``functools.partial`` instances.
        *args, **kwargs: forwarded unchanged to ``fn``.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        Exception: the exact exception raised by ``fn``.
    """
    try:
        return fn(*args, **kwargs)
    except Exception as exc:
        # partial objects and some callable instances lack __name__. Reading
        # it unguarded would raise AttributeError inside this handler and
        # hide the real error, so fall back to repr().
        label = getattr(fn, "__name__", repr(fn))
        print(f"[Error] Exception during {label}: {exc}")
        traceback.print_exc()
        raise


# -----------------------------------------------------------
#  Main oracle routine
# -----------------------------------------------------------
def _query_required_positions(names):
    """Query the scene and return ``{name: float ndarray}`` for every name in *names*.

    Raises:
        RuntimeError: if ``get_object_positions()`` has no (or a ``None``)
            entry for any of *names*; all missing names are listed, in order.
    """
    positions = get_object_positions()
    missing = [name for name in names if positions.get(name) is None]
    if missing:
        raise RuntimeError(f"[Init] Missing positions for: {missing}")
    # Skills are called with np.ndarray targets, so normalise once here.
    return {name: np.asarray(positions.get(name), dtype=float) for name in names}


def _refresh_position(name, fallback):
    """Re-query *name*'s current position; return *fallback* if the query lacks it."""
    latest = get_object_positions().get(name)
    if latest is None:
        return fallback
    return np.asarray(latest, dtype=float)


def _run_plan_step(header, early_end_label, skill, *args, **kwargs):
    """Print *header*, run *skill* through ``safe_call``, and announce an early end.

    Args:
        header: line printed before the skill runs.
        early_end_label: name used in the ``[Early-End]`` notice when the
            skill reports ``done``; ``None`` suppresses the notice (final step).
        skill: the skill callable; ``*args``/``**kwargs`` are forwarded to it.

    Returns:
        The skill's ``(obs, reward, done)`` tuple. The caller decides whether
        to stop.
    """
    print(header)
    obs, reward, done = safe_call(skill, *args, **kwargs)
    if done and early_end_label is not None:
        print(f"[Early-End] Task finished unexpectedly after {early_end_label}.")
    return obs, reward, done


def run_oracle_task():
    """Run the 9-step oracle plan: open the bottom drawer, then put both tomatoes on the plate.

    The simulator is shut down in ``finally`` in every case: after a normal
    finish, after an early return (a skill reported ``done`` before the last
    step), and when a skill raises (the exception then propagates).

    Raises:
        RuntimeError: if any required object position is unavailable.
    """
    print("=====  Oracle Task – Open Drawer & Put Tomatoes on Plate  =====")

    # ----------------------------------------------------------
    #  0) Environment initialisation
    # ----------------------------------------------------------
    env, task = setup_environment()
    try:
        descriptions, obs = task.reset()

        # --- optional video recording (kept from skeleton) ---
        init_video_writers(obs)
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # ------------------------------------------------------
        #  1) Query and validate all required object positions
        # ------------------------------------------------------
        pos = _query_required_positions(
            ('bottom_side_pos', 'bottom_anchor_pos', 'plate', 'item1', 'item2'))
        bottom_side_pos   = pos['bottom_side_pos']
        bottom_anchor_pos = pos['bottom_anchor_pos']
        plate_pos         = pos['plate']
        tomato1_pos       = pos['item1']   # tomato 1
        tomato2_pos       = pos['item2']   # tomato 2

        # ------------------------------------------------------
        #  2) Execute the 9-step oracle plan
        # ------------------------------------------------------
        # Step-1  rotate gripper 90° about z-axis
        target_quat = quaternion_from_euler(0.0, 0.0, np.pi / 2.0)   # 90°
        obs, reward, done = _run_plan_step(
            "\n[Plan-1] rotate gripper to 90° about z-axis", "rotate",
            rotate, env, task, target_quat)
        if done:
            return

        # Step-2  move to side-handle position
        obs, reward, done = _run_plan_step(
            f"\n[Plan-2] move to bottom_side_pos {bottom_side_pos}", "move-1",
            move, env, task, target_pos=bottom_side_pos)
        if done:
            return

        # Step-3  move to anchor-handle position
        obs, reward, done = _run_plan_step(
            f"\n[Plan-3] move to bottom_anchor_pos {bottom_anchor_pos}", "move-2",
            move, env, task, target_pos=bottom_anchor_pos)
        if done:
            return

        # Step-4  pick the drawer handle (close gripper on anchor-pos)
        obs, reward, done = _run_plan_step(
            f"\n[Plan-4] pick drawer handle at {bottom_anchor_pos}", "pick-handle",
            pick, env, task, target_pos=bottom_anchor_pos)
        if done:
            return

        # Step-5  pull drawer – along +x by 0.15 m
        obs, reward, done = _run_plan_step(
            "\n[Plan-5] pull bottom drawer 0.15 m along +x", "pull",
            pull, env, task, pull_distance=0.15, pull_axis='x')
        if done:
            return

        # Positions were sampled before the drawer moved. Re-read each tomato
        # right before picking it in case the pull (or the first place)
        # displaced it. NOTE(review): matters if the tomatoes rest in/on the
        # drawer; confirm scene layout.
        # Step-6  pick tomato 1
        tomato1_pos = _refresh_position('item1', tomato1_pos)
        obs, reward, done = _run_plan_step(
            f"\n[Plan-6] pick tomato1 (item1) at {tomato1_pos}", "pick-1",
            pick, env, task, target_pos=tomato1_pos)
        if done:
            return

        # Step-7  place tomato 1 onto plate
        obs, reward, done = _run_plan_step(
            f"\n[Plan-7] place tomato1 on plate at {plate_pos}", "place-1",
            place, env, task, target_pos=plate_pos)
        if done:
            return

        # Step-8  pick tomato 2
        tomato2_pos = _refresh_position('item2', tomato2_pos)
        obs, reward, done = _run_plan_step(
            f"\n[Plan-8] pick tomato2 (item2) at {tomato2_pos}", "pick-2",
            pick, env, task, target_pos=tomato2_pos)
        if done:
            return

        # Step-9  place tomato 2 onto plate (last step: no early-end notice)
        obs, reward, done = _run_plan_step(
            f"\n[Plan-9] place tomato2 on plate at {plate_pos}", None,
            place, env, task, target_pos=plate_pos)

        # ------------------------------------------------------
        #  3) Final report
        # ------------------------------------------------------
        if done:
            print(f"\n[Success] Goal accomplished! Final reward: {reward}")
        else:
            print("\n[Info] Plan executed – environment reports done=False. "
                  "Depending on task definition this may simply mean the "
                  "episode continues even after the goal is satisfied.")

    finally:
        # Always shut down the simulator
        shutdown_environment(env)
        print("=====  Oracle Task finished – environment shut down  =====")


# -----------------------------------------------------------
#  Entry point
# -----------------------------------------------------------
if __name__ == '__main__':
    # Executed only when run as a script, not on import.
    run_oracle_task()