# run_oracle_plan.py

import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# Import only the skills we actually use
from skill_code import pick, place, move, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def safe_skill_call(skill_fn, *args, **kwargs):
    """
    Call an RLBench skill and shut the environment down if the skill raises.

    The environment is taken from ``args[0]``; callers must pass it as the
    first positional argument. On success the skill's return value is
    passed through unchanged.

    If the skill raises any ``Exception``, ``shutdown_environment(env)`` is
    attempted and the skill's ORIGINAL exception is then re-raised with its
    traceback intact. A failure during that shutdown attempt is reported
    but never replaces the skill's exception, so the real cause of the
    failure is what the caller sees.

    NOTE(review): if the caller also shuts the environment down in its own
    ``finally`` block, the environment will be shut down twice on failure.
    """
    env = args[0]  # env is always the first positional argument
    try:
        return skill_fn(*args, **kwargs)
    except Exception:
        try:
            # Uses the module-level import; no need to re-import here.
            shutdown_environment(env)
        except Exception as shutdown_err:
            # Report, but do not let a cleanup failure mask the skill error.
            print(f"[safe_skill_call] shutdown after skill failure also failed: {shutdown_err!r}")
        # Bare ``raise`` re-raises the skill's exception (the inner handler
        # has completed), preserving the original traceback.
        raise


def run_skeleton_task():
    """
    Execute the oracle plan: open the bottom drawer, then move both
    tomatoes onto the plate.

    The environment created by ``setup_environment`` is shut down exactly
    once, by the ``finally`` clause. This happens whether the plan completes,
    stops early because a skill reports ``done``, or a skill raises.

    Skills are called directly rather than through ``safe_skill_call``.
    That wrapper shuts the environment down on failure, and the ``finally``
    below would then shut it down a second time. A second shutdown raising
    inside ``finally`` would hide the skill's real error.
    """
    print("===== Starting Oracle Plan Execution =====")

    # ------------------------------------------------------------------
    #  Environment setup
    # ------------------------------------------------------------------
    env, task = setup_environment()
    try:
        descriptions, obs = task.reset()

        # Optional video recording helpers: wrap step/get_observation so
        # the frames produced by the skills below go through the recorder.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------
        #  Retrieve useful object positions from the scene
        # ------------------------------------------------------------------
        positions = get_object_positions()

        # Mandatory objects according to the provided object list
        bottom_side_pos = np.array(positions['bottom_side_pos'])
        bottom_anchor_pos = np.array(positions['bottom_anchor_pos'])
        tomato1_pos = np.array(positions['tomato1'])
        tomato2_pos = np.array(positions['tomato2'])
        plate_pos = np.array(positions['plate'])

        # Target orientation for step 1 (zero_deg -> ninety_deg). The
        # current orientation is right-multiplied by a 90° Z rotation, so the
        # turn is about the gripper's own Z axis.
        # NOTE(review): assumes gripper_pose[3:7] is (qx, qy, qz, qw), which
        # matches scipy's scalar-last quaternion convention — confirm.
        obs = task.get_observation()
        current_quat = obs.gripper_pose[3:7]
        rot_90_z = R.from_euler('z', 90, degrees=True)
        target_quat = (R.from_quat(current_quat) * rot_90_z).as_quat()

        # Shared settings for every tomato pick/place step.
        tomato_kwargs = dict(
            approach_distance=0.15,
            max_steps=120,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0,
        )

        # ------------------------------------------------------------------
        #  Oracle plan (Specification.steps). Each entry is:
        #  (log line, skill, extra positional args after (env, task), kwargs)
        # ------------------------------------------------------------------
        plan = [
            ("[Plan] Step 1 – rotate gripper 90° about Z",
             rotate, (target_quat,), {}),
            # move-to-side (bottom drawer side position)
            ("[Plan] Step 2 – move gripper to bottom_side_pos",
             move, (bottom_side_pos,), {}),
            # move-to-anchor (handle position)
            ("[Plan] Step 3 – move gripper to bottom_anchor_pos",
             move, (bottom_anchor_pos,), {}),
            # pick-drawer: close the gripper on the drawer handle
            ("[Plan] Step 4 – grasp drawer handle (pick‑drawer)",
             pick, (), dict(
                 target_pos=bottom_anchor_pos,
                 approach_distance=0.10,
                 max_steps=80,
                 threshold=0.008,
                 approach_axis='z',
                 timeout=8.0,
             )),
            # Pull the drawer outward (assumes +x is the pull axis, 0.20 m)
            ("[Plan] Step 5 – pull drawer outward",
             pull, (), dict(
                 pull_distance=0.20,
                 pull_axis='x',
                 max_steps=120,
                 threshold=0.008,
                 timeout=10.0,
             )),
            # Transfer tomatoes to the plate
            ("[Plan] Step 6 – pick tomato1",
             pick, (), dict(target_pos=tomato1_pos, **tomato_kwargs)),
            ("[Plan] Step 7 – place tomato1 on plate",
             place, (), dict(target_pos=plate_pos, **tomato_kwargs)),
            ("[Plan] Step 8 – pick tomato2",
             pick, (), dict(target_pos=tomato2_pos, **tomato_kwargs)),
            ("[Plan] Step 9 – place tomato2 on plate",
             place, (), dict(target_pos=plate_pos, **tomato_kwargs)),
        ]

        for message, skill_fn, extra_args, skill_kwargs in plan:
            print(message)
            obs, _reward, done = skill_fn(env, task, *extra_args, **skill_kwargs)
            if done:
                # The task signalled termination; remaining steps are skipped.
                print("[Plan] Task reported done – stopping plan early")
                return

        print("===== Oracle Plan finished successfully! =====")

    finally:
        # Always ensure the environment is properly shut down (exactly once).
        shutdown_environment(env)


if __name__ == "__main__":
    run_skeleton_task()